Skip to content

Commit 64b2368

Browse files
committed
Implemented LinearFlow for both local and prefect runner (and everywhere else)
1 parent 6bbfecb commit 64b2368

File tree

15 files changed

+448
-38
lines changed

15 files changed

+448
-38
lines changed

src/unifair/compute/flow.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
66

77
from unifair.compute.job import CallableDecoratingJobTemplateMixin, Job, JobConfig, JobTemplate
8-
from unifair.engine.protocols import IsDagFlowRunnerEngine, IsFuncFlowRunnerEngine, IsTaskTemplate
8+
from unifair.engine.protocols import (IsDagFlowRunnerEngine,
9+
IsFuncFlowRunnerEngine,
10+
IsLinearFlowRunnerEngine,
11+
IsTaskTemplate)
912
from unifair.util.callable_decorator_cls import callable_decorator_cls
1013

1114

@@ -50,6 +53,73 @@ def _call_func(self, *args: Any, **kwargs: Any) -> Any:
5053
pass
5154

5255

56+
class LinearFlowConfig(FlowConfig):
57+
def __init__(
58+
self,
59+
linear_flow_func: Callable,
60+
*task_templates: IsTaskTemplate,
61+
name: Optional[str] = None,
62+
**kwargs: Any,
63+
):
64+
self._linear_flow_func = linear_flow_func
65+
self._linear_flow_func_signature = inspect.signature(self._linear_flow_func)
66+
self._task_templates: Tuple[IsTaskTemplate] = task_templates
67+
name = name if name is not None else self._linear_flow_func.__name__
68+
super().__init__(name=name, **kwargs)
69+
70+
def _get_init_arg_values(self) -> Union[Tuple[()], Tuple[Any, ...]]:
71+
return self._linear_flow_func, *self._task_templates
72+
73+
def _get_init_kwarg_public_property_keys(self) -> Tuple[str, ...]:
74+
return ()
75+
76+
@property
77+
def task_templates(self) -> Tuple[IsTaskTemplate]:
78+
return self._task_templates
79+
80+
def has_coroutine_func(self) -> bool:
81+
return asyncio.iscoroutinefunction(self._linear_flow_func)
82+
83+
@property
84+
def param_signatures(self) -> MappingProxyType:
85+
return self._linear_flow_func_signature.parameters
86+
87+
@property
88+
def return_type(self) -> Type[Any]:
89+
return self._linear_flow_func_signature.return_annotation
90+
91+
92+
@callable_decorator_cls
93+
class LinearFlowTemplate(CallableDecoratingJobTemplateMixin['LinearFlowTemplate'],
94+
FlowTemplate,
95+
LinearFlowConfig):
96+
@classmethod
97+
def _get_job_subcls_for_apply(cls) -> Type[Job]:
98+
return LinearFlow
99+
100+
def _apply_engine_decorator(self, flow: 'LinearFlow') -> 'LinearFlow':
101+
if self.engine is not None and isinstance(self.engine, IsLinearFlowRunnerEngine):
102+
return self.engine.linear_flow_decorator(flow) # noqa # Pycharm static type checker bug
103+
else:
104+
raise RuntimeError(f'Engine "{self.engine}" does not support DAG flows')
105+
106+
107+
class LinearFlow(Flow, LinearFlowConfig):
108+
@classmethod
109+
def _get_job_config_subcls_for_init(cls) -> Type[JobConfig]:
110+
return LinearFlowConfig
111+
112+
@classmethod
113+
def _get_job_template_subcls_for_revise(cls) -> Type[JobTemplate]:
114+
return LinearFlowTemplate # noqa # Pycharm static type checker bug
115+
116+
def _call_func(self, *args: Any, **kwargs: Any) -> Any:
117+
raise NotImplementedError
118+
119+
def get_call_args(self, *args: object, **kwargs: object) -> Dict[str, object]:
120+
return inspect.signature(self._linear_flow_func).bind(*args, **kwargs).arguments
121+
122+
53123
class DagFlowConfig(FlowConfig):
54124
def __init__(
55125
self,

src/unifair/engine/job_runner.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import inspect
33
import sys
44
from types import AsyncGeneratorType, GeneratorType
5-
from typing import Any, Awaitable, Callable, cast
5+
from typing import Any, Awaitable, Callable, cast, Iterable
66

77
from unifair.engine.base import Engine
88
from unifair.engine.constants import RunState
9-
from unifair.engine.protocols import IsDagFlow, IsFuncFlow, IsJob, IsTask
9+
from unifair.engine.protocols import IsDagFlow, IsFuncFlow, IsJob, IsLinearFlow, IsTask
1010

1111

1212
class JobRunnerEngine(Engine, ABC):
@@ -81,6 +81,48 @@ def _run_task(self, state: Any, task: IsTask, call_func: Callable, *args, **kwar
8181
...
8282

8383

84+
class LinearFlowRunnerEngine(JobRunnerEngine):
85+
def linear_flow_decorator(self, linear_flow: IsLinearFlow) -> IsLinearFlow:
86+
# prev_call_func = flow._call_func # Only raises error anyway
87+
88+
self._register_job_state(linear_flow, RunState.INITIALIZED)
89+
state = self._init_linear_flow(linear_flow)
90+
91+
def _call_func(*args: object, **kwargs: object) -> Any:
92+
self._register_job_state(linear_flow, RunState.RUNNING)
93+
flow_result = self._run_linear_flow(state, linear_flow, *args, **kwargs)
94+
return self._decorate_result_with_job_finalization_detector(linear_flow, flow_result)
95+
96+
setattr(linear_flow, '_call_func', _call_func)
97+
return linear_flow
98+
99+
@staticmethod
100+
def default_linear_flow_run_decorator(linear_flow: IsLinearFlow) -> Any:
101+
def _inner_run_linear_flow(*args: object, **kwargs: object):
102+
103+
result = None
104+
for i, job in enumerate(linear_flow.task_templates):
105+
with linear_flow.flow_context:
106+
# TODO: Better handling or kwargs
107+
if i == 0:
108+
result = job(*args, **kwargs)
109+
else:
110+
result = job(*args)
111+
112+
args = result if isinstance(result, Iterable) else [result]
113+
return result
114+
115+
return _inner_run_linear_flow
116+
117+
@abstractmethod
118+
def _init_linear_flow(self, linear_flow: IsLinearFlow) -> Any:
119+
...
120+
121+
@abstractmethod
122+
def _run_linear_flow(self, state: Any, linear_flow: IsLinearFlow, *args, **kwargs) -> Any:
123+
...
124+
125+
84126
class DagFlowRunnerEngine(JobRunnerEngine):
85127
def dag_flow_decorator(self, dag_flow: IsDagFlow) -> IsDagFlow:
86128
# prev_call_func = flow._call_func # Only raises error anyway

src/unifair/engine/local.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
from typing import Any, Callable, Type
22

33
from unifair.config.engine import LocalRunnerConfig
4-
from unifair.engine.job_runner import DagFlowRunnerEngine, FuncFlowRunnerEngine, TaskRunnerEngine
5-
from unifair.engine.protocols import IsDagFlow, IsFuncFlow, IsLocalRunnerConfig, IsTask
6-
7-
8-
class LocalRunner(TaskRunnerEngine, DagFlowRunnerEngine, FuncFlowRunnerEngine):
4+
from unifair.engine.job_runner import (DagFlowRunnerEngine,
5+
FuncFlowRunnerEngine,
6+
LinearFlowRunnerEngine,
7+
TaskRunnerEngine)
8+
from unifair.engine.protocols import (IsDagFlow,
9+
IsFuncFlow,
10+
IsLinearFlow,
11+
IsLocalRunnerConfig,
12+
IsTask)
13+
14+
15+
class LocalRunner(TaskRunnerEngine,
16+
LinearFlowRunnerEngine,
17+
DagFlowRunnerEngine,
18+
FuncFlowRunnerEngine):
919
def _init_engine(self) -> None:
1020
...
1121

@@ -22,6 +32,12 @@ def _init_task(self, task: IsTask, call_func: Callable) -> Any:
2232
def _run_task(self, state: Any, task: IsTask, call_func: Callable, *args, **kwargs) -> Any:
2333
return call_func(*args, **kwargs)
2434

35+
def _init_linear_flow(self, flow: IsLinearFlow) -> Any:
36+
...
37+
38+
def _run_linear_flow(self, state: Any, flow: IsLinearFlow, *args, **kwargs) -> Any:
39+
return self.default_linear_flow_run_decorator(flow)(*args, **kwargs)
40+
2541
def _init_dag_flow(self, flow: IsDagFlow) -> Any:
2642
...
2743

src/unifair/engine/prefect.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,22 @@
99
from prefect.tasks import task_input_hash
1010

1111
from unifair.config.engine import PrefectEngineConfig
12-
from unifair.engine.job_runner import DagFlowRunnerEngine, FuncFlowRunnerEngine, TaskRunnerEngine
13-
from unifair.engine.protocols import IsDagFlow, IsFuncFlow, IsPrefectEngineConfig, IsTask
12+
from unifair.engine.job_runner import (DagFlowRunnerEngine,
13+
FuncFlowRunnerEngine,
14+
LinearFlowRunnerEngine,
15+
TaskRunnerEngine)
16+
from unifair.engine.protocols import (IsDagFlow,
17+
IsFuncFlow,
18+
IsLinearFlow,
19+
IsPrefectEngineConfig,
20+
IsTask)
1421
from unifair.util.helpers import resolve
1522

1623

17-
class PrefectEngine(TaskRunnerEngine, DagFlowRunnerEngine, FuncFlowRunnerEngine):
24+
class PrefectEngine(TaskRunnerEngine,
25+
LinearFlowRunnerEngine,
26+
DagFlowRunnerEngine,
27+
FuncFlowRunnerEngine):
1828
def _init_engine(self) -> None:
1929
...
2030

@@ -71,6 +81,33 @@ def task_flow(*inner_args, **inner_kwargs):
7181

7282
return task_flow(*args, **kwargs)
7383

84+
# LinearFlowRunnerEngine
85+
def _init_linear_flow(self, linear_flow: IsLinearFlow) -> Any:
86+
assert isinstance(self._config, PrefectEngineConfig)
87+
flow_kwargs = dict(name=linear_flow.name,)
88+
call_func = self.default_linear_flow_run_decorator(linear_flow)
89+
90+
if linear_flow.has_coroutine_func():
91+
92+
@prefect_flow(**flow_kwargs)
93+
async def run_linear_flow(*inner_args, **inner_kwargs):
94+
with linear_flow.flow_context:
95+
return await resolve(call_func(*inner_args, **inner_kwargs))
96+
else:
97+
98+
@prefect_flow(**flow_kwargs)
99+
def run_linear_flow(*inner_args, **inner_kwargs):
100+
with linear_flow.flow_context:
101+
return call_func(*inner_args, **inner_kwargs)
102+
103+
return run_linear_flow
104+
105+
def _run_linear_flow(self, state: Any, linear_flow: IsLinearFlow, *args, **kwargs) -> Any:
106+
107+
_prefect_flow = state
108+
109+
return _prefect_flow(*args, **kwargs)
110+
74111
# DagFlowRunnerEngine
75112

76113
def _init_dag_flow(self, dag_flow: IsDagFlow) -> Any:

src/unifair/engine/protocols.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,20 @@ def apply(self) -> IsFlow:
136136
...
137137

138138

139+
class IsLinearFlow(IsFlow, Protocol):
140+
task_templates: Tuple[IsTaskTemplate]
141+
142+
def has_coroutine_func(self) -> bool:
143+
...
144+
145+
def get_call_args(self, *args, **kwargs) -> Dict[str, object]:
146+
...
147+
148+
149+
class IsLinearFlowTemplate(IsLinearFlow, IsFlowTemplate, Protocol):
150+
...
151+
152+
139153
class IsDagFlow(IsFlow, Protocol):
140154
task_templates: Tuple[IsTaskTemplate]
141155

@@ -178,6 +192,12 @@ def task_decorator(self, task: IsTask) -> IsTask:
178192
...
179193

180194

195+
@runtime_checkable
196+
class IsLinearFlowRunnerEngine(IsEngine, Protocol):
197+
def linear_flow_decorator(self, linear_flow: IsLinearFlow) -> IsLinearFlow:
198+
...
199+
200+
181201
@runtime_checkable
182202
class IsDagFlowRunnerEngine(IsEngine, Protocol):
183203
def dag_flow_decorator(self, dag_flow: IsDagFlow) -> IsDagFlow:

tests/compute/cases/decorators.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest_cases as pc
22

3-
from unifair.compute.flow import DagFlowTemplate, FuncFlowTemplate
3+
from unifair.compute.flow import DagFlowTemplate, FuncFlowTemplate, LinearFlowTemplate
44
from unifair.compute.task import TaskTemplate
55

66

@@ -31,6 +31,34 @@ def plus_other(number: int, other: int) -> int:
3131
return plus_other # noqa # Pycharm static type checker bug
3232

3333

34+
@pc.case(
35+
id='linear_flow-plus_five(number)',
36+
tags=['sync', 'function', 'linear_flow', 'plain'],
37+
)
38+
@pc.parametrize_with_cases('plus_one_template', cases='.', has_tag='task')
39+
def case_linear_flow_number_plus_five_template(plus_one_template) -> LinearFlowTemplate:
40+
@LinearFlowTemplate(*((plus_one_template,) * 5))
41+
def plus_five(number: int) -> int: # noqa
42+
...
43+
44+
return plus_five # noqa # Pycharm static type checker bug
45+
46+
47+
@pc.case(
48+
id='linear_flow-plus_five(x)',
49+
tags=['sync', 'function', 'linear_flow', 'with_kw_params'],
50+
)
51+
@pc.parametrize_with_cases('plus_one_template', cases='.', has_tag='task')
52+
def case_linear_flow_x_plus_five_template(plus_one_template) -> FuncFlowTemplate:
53+
iterative_x_plus_one_template = plus_one_template.refine(param_key_map=dict(number='x'),)
54+
55+
@LinearFlowTemplate(*((iterative_x_plus_one_template,) * 5))
56+
def plus_five(x: int) -> int: # noqa
57+
...
58+
59+
return plus_five # noqa # Pycharm static type checker bug
60+
61+
3462
@pc.case(
3563
id='dag_flow-plus_five(number)',
3664
tags=['sync', 'function', 'dag_flow', 'plain'],

tests/compute/cases/flows.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import pytest_cases as pc
77

8-
from unifair.compute.flow import DagFlowTemplate, FlowTemplate, FuncFlowTemplate
8+
from unifair.compute.flow import DagFlowTemplate, FlowTemplate, FuncFlowTemplate, LinearFlowTemplate
99
from unifair.compute.task import TaskTemplate
1010

1111
from .tasks import TaskCase
@@ -28,6 +28,26 @@ class FlowCase(Generic[ArgT, ReturnT]):
2828
# TODO: Add assert_signature_and_return_type_func
2929

3030

31+
@pc.case(
32+
id='sync-linearflow-single-task',
33+
tags=['sync', 'linearflow', 'singlethread', 'success'],
34+
)
35+
@pc.parametrize_with_cases('task_case', cases='.tasks')
36+
def case_sync_linearflow_single_task(task_case: TaskCase) -> FlowCase[[], None]:
37+
task_template = TaskTemplate(task_case.task_func)
38+
linear_flow = LinearFlowTemplate(task_template)(task_case.task_func)
39+
40+
return FlowCase(
41+
name=task_case.name,
42+
flow_func=task_case.task_func,
43+
flow_template=linear_flow,
44+
args=task_case.args,
45+
kwargs=task_case.kwargs,
46+
assert_results_func=task_case.assert_results_func,
47+
# assert_signature_and_return_type_func=task_case.assert_signature_and_return_type_func,
48+
)
49+
50+
3151
@pc.case(
3252
id='sync-dagflow-single-task',
3353
tags=['sync', 'dagflow', 'singlethread', 'success'],

tests/compute/helpers/mocks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union
33

44
from unifair.compute.job import CallableDecoratingJobTemplateMixin, Job, JobConfig, JobTemplate
5-
from unifair.engine.job_runner import DagFlowRunnerEngine
5+
from unifair.engine.job_runner import DagFlowRunnerEngine, LinearFlowRunnerEngine
66
from unifair.engine.protocols import (IsDagFlow,
77
IsEngineConfig,
88
IsFuncFlow,
9+
IsLinearFlow,
910
IsRunStateRegistry,
1011
IsTask)
1112
from unifair.util.callable_decorator_cls import callable_decorator_cls
@@ -213,6 +214,10 @@ def _call_func(*args: Any, **kwargs: Any) -> Any:
213214
setattr(task, '_call_func', _call_func)
214215
return task
215216

217+
def linear_flow_decorator(self, flow: IsLinearFlow) -> IsLinearFlow: # noqa
218+
setattr(flow, '_call_func', LinearFlowRunnerEngine.default_linear_flow_run_decorator(flow))
219+
return flow
220+
216221
def dag_flow_decorator(self, flow: IsDagFlow) -> IsDagFlow: # noqa
217222
setattr(flow, '_call_func', DagFlowRunnerEngine.default_dag_flow_run_decorator(flow))
218223
return flow

0 commit comments

Comments
 (0)