|
5 | 5 | from typing import Any, Callable, Dict, Optional, Tuple, Type, Union |
6 | 6 |
|
7 | 7 | from unifair.compute.job import CallableDecoratingJobTemplateMixin, Job, JobConfig, JobTemplate |
8 | | -from unifair.engine.protocols import IsDagFlowRunnerEngine, IsFuncFlowRunnerEngine, IsTaskTemplate |
| 8 | +from unifair.engine.protocols import (IsDagFlowRunnerEngine, |
| 9 | + IsFuncFlowRunnerEngine, |
| 10 | + IsLinearFlowRunnerEngine, |
| 11 | + IsTaskTemplate) |
9 | 12 | from unifair.util.callable_decorator_cls import callable_decorator_cls |
10 | 13 |
|
11 | 14 |
|
@@ -50,6 +53,73 @@ def _call_func(self, *args: Any, **kwargs: Any) -> Any: |
50 | 53 | pass |
51 | 54 |
|
52 | 55 |
|
| 56 | +class LinearFlowConfig(FlowConfig): |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + linear_flow_func: Callable, |
| 60 | + *task_templates: IsTaskTemplate, |
| 61 | + name: Optional[str] = None, |
| 62 | + **kwargs: Any, |
| 63 | + ): |
| 64 | + self._linear_flow_func = linear_flow_func |
| 65 | + self._linear_flow_func_signature = inspect.signature(self._linear_flow_func) |
| 66 | + self._task_templates: Tuple[IsTaskTemplate] = task_templates |
| 67 | + name = name if name is not None else self._linear_flow_func.__name__ |
| 68 | + super().__init__(name=name, **kwargs) |
| 69 | + |
| 70 | + def _get_init_arg_values(self) -> Union[Tuple[()], Tuple[Any, ...]]: |
| 71 | + return self._linear_flow_func, *self._task_templates |
| 72 | + |
| 73 | + def _get_init_kwarg_public_property_keys(self) -> Tuple[str, ...]: |
| 74 | + return () |
| 75 | + |
| 76 | + @property |
| 77 | + def task_templates(self) -> Tuple[IsTaskTemplate]: |
| 78 | + return self._task_templates |
| 79 | + |
| 80 | + def has_coroutine_func(self) -> bool: |
| 81 | + return asyncio.iscoroutinefunction(self._linear_flow_func) |
| 82 | + |
| 83 | + @property |
| 84 | + def param_signatures(self) -> MappingProxyType: |
| 85 | + return self._linear_flow_func_signature.parameters |
| 86 | + |
| 87 | + @property |
| 88 | + def return_type(self) -> Type[Any]: |
| 89 | + return self._linear_flow_func_signature.return_annotation |
| 90 | + |
| 91 | + |
| 92 | +@callable_decorator_cls |
| 93 | +class LinearFlowTemplate(CallableDecoratingJobTemplateMixin['LinearFlowTemplate'], |
| 94 | + FlowTemplate, |
| 95 | + LinearFlowConfig): |
| 96 | + @classmethod |
| 97 | + def _get_job_subcls_for_apply(cls) -> Type[Job]: |
| 98 | + return LinearFlow |
| 99 | + |
| 100 | + def _apply_engine_decorator(self, flow: 'LinearFlow') -> 'LinearFlow': |
| 101 | + if self.engine is not None and isinstance(self.engine, IsLinearFlowRunnerEngine): |
| 102 | + return self.engine.linear_flow_decorator(flow) # noqa # Pycharm static type checker bug |
| 103 | + else: |
| 104 | + raise RuntimeError(f'Engine "{self.engine}" does not support DAG flows') |
| 105 | + |
| 106 | + |
| 107 | +class LinearFlow(Flow, LinearFlowConfig): |
| 108 | + @classmethod |
| 109 | + def _get_job_config_subcls_for_init(cls) -> Type[JobConfig]: |
| 110 | + return LinearFlowConfig |
| 111 | + |
| 112 | + @classmethod |
| 113 | + def _get_job_template_subcls_for_revise(cls) -> Type[JobTemplate]: |
| 114 | + return LinearFlowTemplate # noqa # Pycharm static type checker bug |
| 115 | + |
| 116 | + def _call_func(self, *args: Any, **kwargs: Any) -> Any: |
| 117 | + raise NotImplementedError |
| 118 | + |
| 119 | + def get_call_args(self, *args: object, **kwargs: object) -> Dict[str, object]: |
| 120 | + return inspect.signature(self._linear_flow_func).bind(*args, **kwargs).arguments |
| 121 | + |
| 122 | + |
53 | 123 | class DagFlowConfig(FlowConfig): |
54 | 124 | def __init__( |
55 | 125 | self, |
|
0 commit comments