Skip to content

Commit ed7eaa5

Browse files
committed
Added support for ParamDataset
1 parent 402cfc7 commit ed7eaa5

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

src/omnipy/modules/general/tasks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from omnipy.compute.task import TaskTemplate
55
from omnipy.compute.typing import mypy_fix_task_template
6-
from omnipy.data.dataset import Dataset, ModelT
6+
from omnipy.data.dataset import Dataset, ModelT, ParamDataset
77
from omnipy.data.model import Model
88

99

@@ -58,5 +58,5 @@ def import_directory(directory: str,
5858

5959
@mypy_fix_task_template
6060
@TaskTemplate
61-
def convert_dataset(dataset: Dataset, dataset_cls: type[_DatasetT]) -> _DatasetT:
62-
return dataset_cls(dataset)
61+
def convert_dataset(dataset: Dataset, dataset_cls: type[_DatasetT], **kwargs: object) -> _DatasetT:
62+
return dataset_cls(dataset, **kwargs)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from omnipy import convert_dataset
2+
from omnipy.data.dataset import Dataset, ParamDataset
3+
from omnipy.data.model import Model, ParamModel
4+
5+
6+
class FloatModel(Model[float]):
7+
...
8+
9+
10+
class IntModel(Model[int]):
11+
...
12+
13+
14+
class RoundedIntModel(ParamModel[float | int, bool]):
15+
@classmethod
16+
def _parse_data(cls, data: float | int, round_to_nearest: bool = False) -> int:
17+
if isinstance(data, int):
18+
return data
19+
20+
return round(data) if round_to_nearest else int(data)
21+
22+
...
23+
24+
25+
class FloatDataset(Dataset[FloatModel]):
26+
...
27+
28+
29+
class IntDataset(Dataset[IntModel]):
30+
...
31+
32+
33+
class RoundedIntDataset(ParamDataset[RoundedIntModel, bool]):
34+
...
35+
36+
37+
def test_convert_dataset():
38+
floats = FloatDataset(a=1.23, b=3.6)
39+
ints = convert_dataset.run(floats, dataset_cls=IntDataset)
40+
assert isinstance(ints, IntDataset)
41+
assert ints.to_data() == dict(a=1, b=3)
42+
43+
44+
def test_convert_dataset_with_default_params():
45+
floats = FloatDataset(a=1.23, b=3.6)
46+
ints = convert_dataset.run(floats, dataset_cls=RoundedIntDataset)
47+
assert isinstance(ints, RoundedIntDataset)
48+
assert ints.to_data() == dict(a=1, b=3)
49+
50+
51+
def test_convert_dataset_with_params():
52+
floats = FloatDataset(a=1.23, b=3.6)
53+
ints = convert_dataset.run(floats, dataset_cls=RoundedIntDataset, round_to_nearest=True)
54+
assert isinstance(ints, RoundedIntDataset)
55+
assert ints.to_data() == dict(a=1, b=4)

0 commit comments

Comments
 (0)