Skip to content

Commit abd4ea5

Browse files
authored
support xxx model (Azure#28225)
1 parent cfc996e commit abd4ea5

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel_for.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ class ParallelFor(LoopNode, NodeIOMixin):
3838
AssetTypes.URI_FILE: AssetTypes.MLTABLE,
3939
AssetTypes.URI_FOLDER: AssetTypes.MLTABLE,
4040
AssetTypes.MLTABLE: AssetTypes.MLTABLE,
41+
AssetTypes.MLFLOW_MODEL: AssetTypes.MLTABLE,
42+
AssetTypes.TRITON_MODEL: AssetTypes.MLTABLE,
43+
AssetTypes.CUSTOM_MODEL: AssetTypes.MLTABLE,
44+
# legacy path support
45+
"path": AssetTypes.MLTABLE,
4146
ComponentParameterTypes.NUMBER: ComponentParameterTypes.STRING,
4247
ComponentParameterTypes.STRING: ComponentParameterTypes.STRING,
4348
ComponentParameterTypes.BOOLEAN: ComponentParameterTypes.STRING,
@@ -129,6 +134,8 @@ def _convert_output_meta(self, outputs):
129134
if output.type in self.OUT_TYPE_MAPPING:
130135
new_type = self.OUT_TYPE_MAPPING[output.type]
131136
else:
137+
# when loop body introduces some new output type, this will be raised as a reminder to support is in
138+
# parallel for
132139
raise UserErrorException(
133140
"Referencing output with type {} is not supported in parallel_for node.".format(output.type)
134141
)

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_controlflow_pipeline.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ def my_pipeline():
259259
({"type": "uri_file"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
260260
({"type": "uri_folder"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
261261
({"type": "mltable"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
262+
({"type": "mlflow_model"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
263+
({"type": "triton_model"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
264+
({"type": "custom_model"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
265+
({"type": "path"}, {'job_output_type': 'mltable'}, {'type': 'mltable'}, True),
262266
({"type": "number"}, {}, {'type': 'string'}, False),
263267
({"type": "string", "is_control": True}, {}, {'type': 'string', "is_control": True}, False),
264268
({"type": "boolean", "is_control": True}, {}, {'type': 'string', "is_control": True}, False),
@@ -298,30 +302,3 @@ def my_pipeline():
298302
pipeline_component = my_job.component
299303
rest_component = pipeline_component._to_rest_object().as_dict()
300304
assert rest_component["properties"]["component_spec"]["outputs"] == {'output': component_out_dict}
301-
302-
@pytest.mark.parametrize(
303-
"out_type", ["mlflow_model", "triton_model", "custom_model"]
304-
)
305-
def test_parallel_for_output_unsupported_case(self, out_type):
306-
basic_component = load_component(
307-
source="./tests/test_configs/components/helloworld_component.yml",
308-
params_override=[
309-
{"outputs.component_out_path": {"type": out_type}}
310-
]
311-
)
312-
313-
@pipeline
314-
def my_pipeline():
315-
body = basic_component(component_in_path=Input(path="test_path1"))
316-
317-
parallel_for(
318-
body=body,
319-
items={
320-
"iter1": {"component_in_number": 1},
321-
"iter2": {"component_in_number": 2}
322-
}
323-
)
324-
325-
with pytest.raises(UserErrorException) as e:
326-
my_pipeline()
327-
assert f"Referencing output with type {out_type} is not supported" in str(e.value)

0 commit comments

Comments
 (0)