Skip to content

Commit bc7cb77

Browse files
committed
Fixed parameter handling bugs for async iterate_over_data_types. output_dataset param is now keyword-only
1 parent 85507d1 commit bc7cb77

File tree

5 files changed

+78
-15
lines changed

5 files changed

+78
-15
lines changed

src/omnipy/compute/mixins/iterate.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def _generate_new_signature_for_iteration(self, job_func: Callable) -> None:
178178
else:
179179
output_dataset_param = Parameter(
180180
name=self._output_dataset_param,
181-
kind=Parameter.POSITIONAL_OR_KEYWORD,
181+
kind=Parameter.KEYWORD_ONLY,
182+
default=None,
182183
annotation=output_dataset_cls,
183184
)
184185
rest_params = rest_params + [output_dataset_param]
@@ -200,19 +201,28 @@ def _extract_output_dataset(
200201
self_as_signature_func_job_base_mixin = cast(SignatureFuncJobBaseMixin, self)
201202

202203
if self._output_dataset_param:
204+
if self._output_dataset_param_in_func and self._output_dataset_param not in kwargs:
205+
kwargs = kwargs.copy()
206+
kwargs[self._output_dataset_param] = cast(
207+
Dataset, self_as_signature_func_job_base_mixin.return_type())
208+
203209
bound_args = self_as_signature_func_job_base_mixin.get_bound_args(
204210
dataset, *args, **kwargs)
205-
output_dataset: Dataset = bound_args.arguments[self._output_dataset_param]
206-
207-
if self._output_dataset_param_in_func:
208-
return_args = bound_args.args[1:]
209-
else:
210-
return_args = bound_args.args[1:-1]
211+
output_dataset: Dataset | None = bound_args.arguments[self._output_dataset_param]
211212

212-
return output_dataset, return_args, bound_args.kwargs
213+
return_args = bound_args.args[1:]
214+
return_kwargs = bound_args.kwargs
215+
if not self._output_dataset_param_in_func:
216+
return_kwargs.pop(self._output_dataset_param)
213217
else:
218+
output_dataset = None
219+
return_args = args
220+
return_kwargs = kwargs
221+
222+
if output_dataset is None:
214223
output_dataset = cast(Dataset, self_as_signature_func_job_base_mixin.return_type())
215-
return output_dataset, args, kwargs
224+
225+
return output_dataset, return_args, return_kwargs
216226

217227
def _prepare_data_arg(self, data_file):
218228
return data_file if is_model_subclass(self._input_dataset_type) else data_file.contents

tests/compute/cases/iterate_tasks.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
async_single_int_plus_future_return_alphanum_string_func,
1313
async_single_int_plus_int_return_str_func,
1414
async_single_int_plus_int_return_str_model_with_output_str_dataset_func,
15+
single_int_model_plus_default_int_pair_return_str_model_func,
1516
single_int_model_plus_int_return_str_model_func,
1617
single_int_plus_int_return_str_func,
1718
single_int_plus_int_return_str_model_with_output_str_dataset_func,
@@ -89,6 +90,28 @@ def case_sync_single_int_model_plus_int_return_str_model_func() -> IterateDataFi
8990
)
9091

9192

93+
@pc.case(
94+
id='sync_single_int_model_plus_default_int_pair_return_str_model_func',
95+
tags=[
96+
'sync',
97+
'function',
98+
'iterate',
99+
'no_output_dataset',
100+
'str_output_dataset',
101+
'int_output_dataset'
102+
],
103+
)
104+
def case_sync_single_int_model_plus_default_int_pair_return_str_model_func(
105+
) -> IterateDataFilesCase:
106+
return IterateDataFilesCase(
107+
task_func=single_int_model_plus_default_int_pair_return_str_model_func,
108+
args=(1,),
109+
kwargs={'other_number': 1},
110+
func_is_async=False,
111+
iterate_over_data_files=True,
112+
)
113+
114+
92115
@pc.case(
93116
id='sync_single_int_plus_int_return_str_func',
94117
tags=[

tests/compute/cases/raw/functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,15 @@ def single_int_model_plus_int_return_str_model_func(
6464
data_number: Model[int],
6565
number: int,
6666
) -> Model[str]:
67-
return str(data_number.contents + number)
67+
return str(data_number.contents + number) # type: ignore[return-value]
68+
69+
70+
def single_int_model_plus_default_int_pair_return_str_model_func(
71+
data_number: Model[int],
72+
number: int = 0,
73+
other_number: int = 0,
74+
) -> Model[str]:
75+
return str(data_number.contents + number + other_number) # type: ignore[return-value]
6876

6977

7078
def single_int_plus_int_return_str_func(data_number: int, number: int) -> str:
@@ -76,7 +84,7 @@ def single_int_plus_int_return_str_model_with_output_str_dataset_func(
7684
number: int,
7785
output_dataset: Dataset[Model[str]],
7886
) -> Model[str]:
79-
return str(data_number + number)
87+
return str(data_number + number) # type: ignore[return-value]
8088

8189

8290
def single_int_plus_int_return_str_with_output_int_dataset_func(
@@ -92,7 +100,7 @@ async def async_single_int_model_plus_int_return_str_model_func(
92100
number: int,
93101
) -> Model[str]:
94102
await asyncio.sleep(random() / 10.0)
95-
return str(data_number.contents + number)
103+
return str(data_number.contents + number) # type: ignore[return-value]
96104

97105

98106
async def async_single_int_plus_int_return_str_func(data_number: int, number: int) -> str:
@@ -114,7 +122,7 @@ async def async_single_int_plus_int_return_str_model_with_output_str_dataset_fun
114122
output_dataset: Dataset[Model[str]],
115123
) -> Model[str]:
116124
await asyncio.sleep(random() / 10.0)
117-
return str(data_number + number)
125+
return str(data_number + number) # type: ignore[return-value]
118126

119127

120128
async def async_single_int_plus_future_int_return_str_func(

tests/compute/mixins/test_iterate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ async def test_iterate_over_data_files_task(case: IterateDataFilesCase) -> None:
131131
_assert_str_result(case, returned_dataset)
132132

133133

134+
@pc.parametrize_with_cases(
135+
'case', cases='..cases.iterate_tasks', has_tag=['iterate', 'str_output_dataset'])
136+
async def test_iterate_over_data_files_with_default_output_dataset_param_task(
137+
case: IterateDataFilesCase) -> None:
138+
139+
task_template = TaskTemplate(
140+
iterate_over_data_files=case.iterate_over_data_files,
141+
output_dataset_param='output_dataset')(
142+
case.task_func)
143+
144+
dataset = Dataset[Model[int]](dict(a=3, b=5, c=-2))
145+
146+
dataset_or_task = _run_task_template(case, task_template, dataset)
147+
returned_dataset = await _ensure_dataset_await_if_task(case, dataset_or_task)
148+
149+
_assert_str_result(case, returned_dataset)
150+
151+
134152
@pc.parametrize_with_cases(
135153
'case', cases='..cases.iterate_tasks', has_tag=['iterate', 'str_output_dataset'])
136154
async def test_iterate_over_data_files_with_output_dataset_param_task(

tests/compute/mixins/test_mixin_integration.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def test_iterate_over_data_files_func_signature_output_dataset_param() -> None:
6262
'output_dataset':
6363
Parameter(
6464
'output_dataset',
65-
Parameter.POSITIONAL_OR_KEYWORD,
65+
Parameter.KEYWORD_ONLY,
66+
default=None,
6667
annotation=Dataset[Model[str]])
6768
}
6869
assert task_obj.return_type is Dataset[Model[str]]
@@ -85,7 +86,10 @@ def test_iterate_over_data_files_func_signature_output_dataset_param_and_cls() -
8586
Parameter('number', Parameter.POSITIONAL_OR_KEYWORD, annotation=int),
8687
'output_dataset':
8788
Parameter(
88-
'output_dataset', Parameter.POSITIONAL_OR_KEYWORD, annotation=CustomStrDataset)
89+
'output_dataset',
90+
Parameter.KEYWORD_ONLY,
91+
default=None,
92+
annotation=CustomStrDataset)
8993
}
9094
assert task_obj.return_type is CustomStrDataset
9195

0 commit comments

Comments
 (0)