Commit 23dbd6e

[ML][Pipelines] Support assets in parallel_for items. (Azure#28673)
* support yaml cases
* support in SDK
* fix tests
* expose _to_rest_items interface
* remove unnecessary tests
* fix empty items
* fix tests
1 parent 183ee8b commit 23dbd6e

13 files changed: +2318 −100 lines
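
In short, parallel_for items can now carry asset inputs (azure.ai.ml.Input) alongside primitive values and bindings, both in YAML and in the SDK, and they are serialized into the REST "items" JSON accordingly. Below is a minimal sketch of the SDK scenario, mirroring the e2e test added in this commit; the helloworld component file, the Titanic URI, and the parallel_for import path are assumptions for illustration (parallel_for is a private-preview control-flow node), not part of this diff:

from azure.ai.ml import Input, load_component
from azure.ai.ml.dsl import pipeline
# parallel_for is a private-preview control-flow node; this import path is assumed.
from azure.ai.ml.dsl._parallel_for import parallel_for

hello_world_component = load_component(source="./helloworld_component.yml")
# Asset input reused by every iteration (same URI as the e2e test below).
test_input = Input(
    type="uri_file",
    path="https://dprepdata.blob.core.windows.net/demo/Titanic.csv",
)

@pipeline
def parallel_for_pipeline():
    parallel_body = hello_world_component()
    # Each item holds one iteration's inputs; asset inputs are now allowed next to primitives.
    parallel_for(
        body=parallel_body,
        items=[
            {"component_in_number": 1, "component_in_path": test_input},
            {"component_in_number": 2, "component_in_path": test_input},
        ],
    )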

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/control_flow_job.py

Lines changed: 41 additions & 10 deletions
@@ -4,13 +4,14 @@
 import copy
 import json

-from marshmallow import INCLUDE, fields, pre_dump
+from marshmallow import INCLUDE, fields, pre_dump, pre_load

 from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, StringTransformedEnum, UnionField
 from azure.ai.ml._schema.core.schema import PathAwareSchema
 from azure.ai.ml.constants._component import ControlFlowType

 from ..job.input_output_entry import OutputSchema
+from ..job.input_output_fields_provider import InputsField
 from ..job.job_limits import DoWhileLimitsSchema
 from .component_job import _resolve_outputs
 from .pipeline_job_io import OutputBindingStr
@@ -73,14 +74,21 @@ def resolve_inputs_outputs(self, data, **kwargs):  # pylint: disable=no-self-use

         return result

+    @pre_dump
+    def convert_control_flow_body_to_binding_str(self, data, **kwargs):  # pylint: disable=no-self-use, unused-argument
+
+        return super(DoWhileSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs)
+

 class ParallelForSchema(BaseLoopSchema):
     type = StringTransformedEnum(allowed_values=[ControlFlowType.PARALLEL_FOR])
     items = UnionField(
         [
+            fields.Dict(keys=fields.Str(), values=InputsField()),
+            fields.List(InputsField()),
+            # put str in last to make sure other type items won't become string when dumps.
+            # TODO: only support binding here
             fields.Str(),
-            fields.Dict(keys=fields.Str(), values=fields.Dict()),
-            fields.List(fields.Dict()),
         ],
         required=True,
     )
@@ -90,19 +98,42 @@ class ParallelForSchema(BaseLoopSchema):
         values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
     )

+    @pre_load
+    def load_items(self, data, **kwargs):  # pylint: disable=no-self-use, unused-argument
+        # load items from json to convert the assets in it to rest
+        try:
+            items = data["items"]
+            if isinstance(items, str):
+                items = json.loads(items)
+            data["items"] = items
+        except Exception:  # pylint: disable=broad-except
+            pass
+        return data
+
     @pre_dump
-    def serialize_items(self, data, **kwargs):  # pylint: disable=no-self-use, unused-argument
-        from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+    def convert_control_flow_body_to_binding_str(self, data, **kwargs):  # pylint: disable=no-self-use, unused-argument

-        result = copy.copy(data)
-        if isinstance(result.items, (dict, list)):
-            # use str to serialize input/output builder
-            result._items = json.dumps(result.items, default=lambda x: str(x) if isinstance(x, InputOutputBase) else x)
-        return result
+        return super(ParallelForSchema, self).convert_control_flow_body_to_binding_str(data, **kwargs)

     @pre_dump
     def resolve_outputs(self, job, **kwargs):  # pylint: disable=unused-argument

         result = copy.copy(job)
         _resolve_outputs(result, job)
         return result
+
+    @pre_dump
+    def serialize_items(self, data, **kwargs):  # pylint: disable=no-self-use, unused-argument
+        # serialize items to json string to avoid being removed by _dump_for_validation
+        from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+
+        def _binding_handler(obj):
+            if isinstance(obj, InputOutputBase):
+                return str(obj)
+            return repr(obj)
+
+        result = copy.copy(data)
+        if isinstance(result.items, (dict, list)):
+            # use str to serialize input/output builder
+            result._items = json.dumps(result.items, default=_binding_handler)
+        return result
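
The net effect of this schema change: on dump, serialize_items turns a dict/list items value into a JSON string (so _dump_for_validation does not drop it), rendering binding objects as their binding string and any other non-JSON-serializable value via repr(); on load, load_items parses that string back into a dict/list. A self-contained sketch of the dump-side handler, using a hypothetical FakeBinding class in place of InputOutputBase:

import json

class FakeBinding:
    """Hypothetical stand-in for InputOutputBase, only to show the default handler."""
    def __init__(self, binding: str):
        self._binding = binding
    def __str__(self) -> str:
        return self._binding

def _binding_handler(obj):
    # Same shape as the handler added above: binding objects dump as their
    # binding string, other non-serializable objects fall back to repr().
    if isinstance(obj, FakeBinding):
        return str(obj)
    return repr(obj)

items = [{"component_in_number": 1, "component_in_path": FakeBinding("${{parent.inputs.data}}")}]
print(json.dumps(items, default=_binding_handler))
# [{"component_in_number": 1, "component_in_path": "${{parent.inputs.data}}"}]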

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel_for.py

Lines changed: 130 additions & 18 deletions
@@ -2,9 +2,10 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 import json
+import os
 from typing import Dict, Union

-from azure.ai.ml import Output
+from azure.ai.ml import Output, Input
 from azure.ai.ml._schema import PathAwareSchema
 from azure.ai.ml._schema.pipeline.control_flow_job import ParallelForSchema
 from azure.ai.ml._utils.utils import is_data_binding_expression
@@ -28,7 +29,7 @@ class ParallelFor(LoopNode, NodeIOMixin):
     :param body: Pipeline job for the parallel for loop body.
     :type body: Pipeline
     :param items: The loop body's input which will bind to the loop node.
-    :type items: Union[list, dict, str, PipelineInput, NodeOutput]
+    :type items: typing.Union[list, dict, str, NodeOutput, PipelineInput]
     :param max_concurrency: Maximum number of concurrent iterations to run. All loop body nodes will be executed
         in parallel if not specified.
     :type max_concurrency: int
@@ -105,17 +106,92 @@ def _attr_type_map(cls) -> dict:
             "items": (dict, list, str, PipelineInput, NodeOutput),
         }

+    @classmethod
+    def _to_rest_item(cls, item: dict) -> dict:
+        """Convert item to rest object."""
+        primitive_inputs, asset_inputs = {}, {}
+        # validate item
+        for key, val in item.items():
+            if isinstance(val, Input):
+                asset_inputs[key] = val
+            elif isinstance(val, (PipelineInput, NodeOutput)):
+                # convert binding object to string
+                primitive_inputs[key] = str(val)
+            else:
+                primitive_inputs[key] = val
+        return {
+            # asset type inputs will be converted to JobInput dict:
+            # {"asset_param": {"uri": "xxx", "job_input_type": "uri_file"}}
+            **cls._input_entity_to_rest_inputs(input_entity=asset_inputs),
+            # primitive inputs has primitive type value like this
+            # {"int_param": 1}
+            **primitive_inputs
+        }
+
+    @classmethod
+    def _to_rest_items(cls, items: Union[list, dict, str, NodeOutput, PipelineInput]) -> str:
+        """Convert items to rest object."""
+        # validate items.
+        cls._validate_items(items=items, raise_error=True, body_component=None)
+        # convert items to rest object
+        if isinstance(items, list):
+            rest_items = [cls._to_rest_item(item=i) for i in items]
+            rest_items = json.dumps(rest_items)
+        elif isinstance(items, dict):
+            rest_items = {k: cls._to_rest_item(item=v) for k, v in items.items()}
+            rest_items = json.dumps(rest_items)
+        elif isinstance(items, (NodeOutput, PipelineInput)):
+            rest_items = str(items)
+        elif isinstance(items, str):
+            rest_items = items
+        else:
+            raise UserErrorException("Unsupported items type: {}".format(type(items)))
+        return rest_items
+
     def _to_rest_object(self, **kwargs) -> dict:  # pylint: disable=unused-argument
         """Convert self to a rest object for remote call."""
         rest_node = super(ParallelFor, self)._to_rest_object(**kwargs)
-        rest_node.update(dict(outputs=self._to_rest_outputs()))
+        # convert items to rest object
+        rest_items = self._to_rest_items(items=self.items)
+        rest_node.update(dict(
+            items=rest_items,
+            outputs=self._to_rest_outputs()
+        ))
         return convert_ordered_dict_to_dict(rest_node)

+    @classmethod
+    def _from_rest_item(cls, rest_item):
+        """Convert rest item to item."""
+        primitive_inputs, asset_inputs = {}, {}
+        for key, val in rest_item.items():
+            if isinstance(val, dict) and val.get("job_input_type"):
+                asset_inputs[key] = val
+            else:
+                primitive_inputs[key] = val
+        return {
+            **cls._from_rest_inputs(inputs=asset_inputs),
+            **primitive_inputs
+        }
+
+    @classmethod
+    def _from_rest_items(cls, rest_items: str) -> Union[dict, list, str]:
+        """Convert items from rest object."""
+        try:
+            items = json.loads(rest_items)
+        except json.JSONDecodeError:
+            # return original items when failed to load
+            return rest_items
+        if isinstance(items, list):
+            return [cls._from_rest_item(rest_item=i) for i in items]
+        if isinstance(items, dict):
+            return {k: cls._from_rest_item(rest_item=v) for k, v in items.items()}
+        return rest_items
+
     @classmethod
     def _from_rest_object(cls, obj: dict, pipeline_jobs: dict) -> "ParallelFor":
         # pylint: disable=protected-access
-
         obj = BaseNode._from_rest_object_to_init_params(obj)
+        obj["items"] = cls._from_rest_items(rest_items=obj.get("items", ""))
         return cls._create_instance_from_schema_dict(pipeline_jobs=pipeline_jobs, loaded_data=obj)

     @classmethod
@@ -149,11 +225,21 @@ def _convert_output_meta(self, outputs):
             aggregate_outputs[name] = resolved_output
         return aggregate_outputs

-    def _validate_items(self, raise_error=True):
-        validation_result = self._create_empty_validation_result()
-        if self.items is not None:
-            items = self.items
+    def _customized_validate(self):
+        """Customized validation for parallel for node."""
+        # pylint: disable=protected-access
+        validation_result = self._validate_body(raise_error=False)
+        validation_result.merge_with(
+            self._validate_items(items=self.items, raise_error=False, body_component=self.body._component)
+        )
+        return validation_result
+
+    @classmethod
+    def _validate_items(cls, items, raise_error=True, body_component=None):
+        validation_result = cls._create_empty_validation_result()
+        if items is not None:
             if isinstance(items, str):
+                # TODO: remove the validation
                 # try to deserialize str if it's a json string
                 try:
                     items = json.loads(items)
@@ -168,7 +254,7 @@ def _validate_items(self, raise_error=True):
                 items = list(items.values())
             if isinstance(items, list):
                 if len(items) > 0:
-                    self._validate_items_list(items, validation_result)
+                    cls._validate_items_list(items, validation_result, body_component=body_component)
                 else:
                     validation_result.append_error(
                         yaml_path="items",
@@ -179,17 +265,12 @@ def _validate_items(self, raise_error=True):
                 message="Items is required for parallel_for node",
             )
         return validation_result.try_raise(
-            self._get_validation_error_target(),
+            cls._get_validation_error_target(),
             raise_error=raise_error,
         )

-    def _customized_validate(self):
-        """Customized validation for parallel for node."""
-        validation_result = self._validate_body(raise_error=False)
-        validation_result.merge_with(self._validate_items(raise_error=False))
-        return validation_result
-
-    def _validate_items_list(self, items: list, validation_result):
+    @classmethod
+    def _validate_items_list(cls, items: list, validation_result, body_component=None):
         # pylint: disable=protected-access
         meta = {}
         # all items have to be dict and have matched meta
@@ -213,10 +294,41 @@ def _validate_items_list(self, items: list, validation_result):
                     message=msg
                 )
             # items' keys should appear in body's inputs
-            body_component = self.body._component
             if isinstance(body_component, Component) and (not item.keys() <= body_component.inputs.keys()):
                 msg = f"Item {item} got unmatched inputs with loop body component inputs {body_component.inputs}."
                 validation_result.append_error(
                     yaml_path="items",
                     message=msg
                 )
+            # validate item value type
+            cls._validate_item_value_type(item=item, validation_result=validation_result)
+
+    @classmethod
+    def _validate_item_value_type(cls, item: dict, validation_result):
+        # pylint: disable=protected-access
+        supported_types = (Input, str, bool, int, float, PipelineInput)
+        for _, val in item.items():
+            if not isinstance(val, supported_types):
+                validation_result.append_error(
+                    yaml_path="items",
+                    message="Unsupported type {} in parallel_for items. Supported types are: {}".format(
+                        type(val), supported_types
+                    )
+                )
+            if isinstance(val, Input):
+                cls._validate_input_item_value(entry=val, validation_result=validation_result)
+
+    @classmethod
+    def _validate_input_item_value(cls, entry: Input, validation_result):
+        if not isinstance(entry, Input):
+            return
+        if not entry.path:
+            validation_result.append_error(
+                yaml_path="items",
+                message=f"Input path not provided for {entry}.",
+            )
+        if isinstance(entry.path, str) and os.path.exists(entry.path):
+            validation_result.append_error(
+                yaml_path="items",
+                message=f"Local file input {entry} is not supported, please create it as a dataset.",
+            )
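
To summarize the entity-side change: _to_rest_items validates the items, converts each item's asset inputs into JobInput-style dicts (via _input_entity_to_rest_inputs) while passing primitives and binding strings through, then JSON-encodes the result; _from_rest_items reverses this, treating any dict value carrying a job_input_type key as an asset. A small standalone sketch of that classification follows; split_rest_item is an illustrative helper, not SDK API, and the real code additionally rebuilds Input entities through _from_rest_inputs:

import json

# Shape of one serialized item as produced by _to_rest_items; the URI is illustrative.
rest_items = json.dumps([
    {
        "component_in_path": {"uri": "https://example.blob.core.windows.net/data/titanic.csv",
                              "job_input_type": "uri_file"},
        "component_in_number": 1,
    },
])

def split_rest_item(rest_item: dict):
    # Mirrors the check in _from_rest_item: dict values with "job_input_type" are assets.
    asset_inputs = {k: v for k, v in rest_item.items()
                    if isinstance(v, dict) and v.get("job_input_type")}
    primitive_inputs = {k: v for k, v in rest_item.items() if k not in asset_inputs}
    return asset_inputs, primitive_inputs

for item in json.loads(rest_items):
    assets, primitives = split_rest_item(item)
    print(assets)      # {'component_in_path': {'uri': ..., 'job_input_type': 'uri_file'}}
    print(primitives)  # {'component_in_number': 1}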

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_io/mixin.py

Lines changed: 5 additions & 1 deletion
@@ -154,10 +154,14 @@ def _to_rest_inputs(self) -> Dict[str, Dict]:
         }
         """
         built_inputs = self._build_inputs()
+        return self._input_entity_to_rest_inputs(input_entity=built_inputs)
+
+    @classmethod
+    def _input_entity_to_rest_inputs(cls, input_entity: Dict[str, Input]) -> Dict[str, Dict]:

         # Convert io entity to rest io objects
         input_bindings, dataset_literal_inputs = process_sdk_component_job_io(
-            built_inputs, [ComponentJobConstants.INPUT_PATTERN]
+            input_entity, [ComponentJobConstants.INPUT_PATTERN]
         )

         # parse input_bindings to InputLiteral(value=str(binding))

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_util.py

Lines changed: 0 additions & 1 deletion
@@ -205,7 +205,6 @@ def is_empty_target(obj):

 def convert_ordered_dict_to_dict(target_object: Union[Dict, List], remove_empty=True) -> Union[Dict, List]:
     """Convert ordered dict to dict. Remove keys with None value.
-
     This is a workaround for rest request must be in dict instead of
     ordered dict.
     """

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_controlflow_pipeline.py

Lines changed: 39 additions & 0 deletions
@@ -645,3 +645,42 @@ def parallel_for_pipeline():
                 'value': '${{parent.outputs.component_out_path}}'}},
             'type': 'parallel_for'}
         }
+
+    def test_parallel_for_pipeline_with_asset_items(self, client: MLClient):
+        hello_world_component = load_component(
+            source="./tests/test_configs/components/helloworld_component.yml"
+        )
+
+        @pipeline
+        def parallel_for_pipeline():
+            parallel_body = hello_world_component()
+            parallel_node = parallel_for(
+                body=parallel_body,
+                items=[
+                    {"component_in_number": 1, "component_in_path": test_input},
+                    {"component_in_number": 2, "component_in_path": test_input},
+                ]
+            )
+            after_node = hello_world_component(
+                component_in_path=parallel_node.outputs.component_out_path,
+            )
+
+        pipeline_job = parallel_for_pipeline()
+        pipeline_job.settings.default_compute = "cpu-cluster"
+
+        with include_private_preview_nodes_in_pipeline():
+            pipeline_job = assert_job_cancel(pipeline_job, client)
+
+        dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
+        assert dsl_pipeline_job_dict["properties"]["jobs"]["parallel_node"] == {
+            'body': '${{parent.jobs.parallel_body}}',
+            'items': '[{"component_in_path": {"uri": '
+                     '"https://dprepdata.blob.core.windows.net/demo/Titanic.csv", '
+                     '"job_input_type": "uri_file"}, '
+                     '"component_in_number": 1}, {"component_in_path": '
+                     '{"uri": '
+                     '"https://dprepdata.blob.core.windows.net/demo/Titanic.csv", '
+                     '"job_input_type": "uri_file"}, '
+                     '"component_in_number": 2}]',
+            'type': 'parallel_for'
+        }
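
For reference, the asserted 'items' value is itself a JSON string; decoding it shows the per-iteration structure the service receives (asset inputs as JobInput dicts, primitives inline). A quick standalone check of that shape:

import json

expected_items = (
    '[{"component_in_path": {"uri": '
    '"https://dprepdata.blob.core.windows.net/demo/Titanic.csv", '
    '"job_input_type": "uri_file"}, '
    '"component_in_number": 1}, {"component_in_path": '
    '{"uri": '
    '"https://dprepdata.blob.core.windows.net/demo/Titanic.csv", '
    '"job_input_type": "uri_file"}, '
    '"component_in_number": 2}]'
)
parsed = json.loads(expected_items)
assert parsed[0]["component_in_path"]["job_input_type"] == "uri_file"
assert parsed[1]["component_in_number"] == 2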
