Skip to content

Commit 5225120

Browse files
authored
[ML][Pipelines] Dynamic inputs: support output binding (Azure#28542)
* test: add case output binding in dynamic args * get annotation type when NodeOutput
1 parent 559bff1 commit 5225120

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ def _update_inputs(self, pipeline_inputs):
400400
value = value._data
401401
if isinstance(value, Input):
402402
anno = copy.copy(value)
403+
elif isinstance(value, NodeOutput):
404+
anno = Input(type=value.type)
403405
else:
404406
anno = _get_annotation_by_value(value)
405407
anno.name = input_name

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,24 @@ def pipeline_with_variable_args(key_1: int, **kargs):
21642164
match="pipeline_with_variable_args\(\) got multiple values for argument 'key_1'\."):
21652165
pipeline_with_variable_args(10, key_1=10)
21662166

2167+
def test_pipeline_with_output_binding_in_dynamic_args(self):
2168+
hello_world_func = load_component(components_dir / "helloworld_component.yml")
2169+
hello_world_no_inputs_func = load_component(components_dir / "helloworld_component_no_inputs.yml")
2170+
2171+
@dsl.pipeline
2172+
def pipeline_func_consume_dynamic_arg(**kwargs):
2173+
hello_world_func(component_in_number=kwargs["int_param"], component_in_path=kwargs["path_param"])
2174+
2175+
@dsl.pipeline
2176+
def root_pipeline_func():
2177+
node = hello_world_no_inputs_func()
2178+
kwargs = {"int_param": 0, "path_param": node.outputs.component_out_path}
2179+
pipeline_func_consume_dynamic_arg(**kwargs)
2180+
2181+
pipeline_job = root_pipeline_func()
2182+
pipeline_job.settings.default_compute = "cpu-cluster"
2183+
assert pipeline_job._customized_validate().passed is True
2184+
21672185
def test_condition_node_consumption(self):
21682186
from azure.ai.ml.dsl._condition import condition
21692187

0 commit comments

Comments
 (0)