Skip to content

Commit b8faace

Browse files
authored
[ML][Pipelines] Bugfix: exclude group input during validate init/finalize (Azure#28226)
* exclude group when validating init/finalize
* add comment
* handle group when trying to get data binding
* replace with GroupInput._is_group_attr_dict
* iterate over all items in group
1 parent abd4ea5 commit b8faace

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/pipeline_job.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import typing
99
from functools import partial
1010
from pathlib import Path
11-
from typing import Dict, Optional, Union
11+
from typing import Dict, List, Optional, Union
1212

1313
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobBase
1414
from azure.ai.ml._restclient.v2022_10_01_preview.models import PipelineJob as RestPipelineJob
@@ -33,6 +33,7 @@
3333
from azure.ai.ml.entities._builders.pipeline import Pipeline
3434
from azure.ai.ml.entities._component.component import Component
3535
from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
36+
from azure.ai.ml.entities._inputs_outputs.group_input import GroupInput
3637

3738
# from azure.ai.ml.entities._job.identity import AmlToken, Identity, ManagedIdentity, UserIdentity
3839
from azure.ai.ml.entities._credentials import (
@@ -337,31 +338,44 @@ def _is_control_flow_node(_validate_job_name: str) -> bool:
337338
return issubclass(type(_validate_job), ControlFlowNode)
338339

339340
def _is_isolated_job(_validate_job_name: str) -> bool:
340-
def _try_get_data_binding(_input_output_data) -> Union[str, None]:
341-
"""Try to get data binding from input/output data, return None if not found."""
341+
def _try_get_data_bindings(_name: str, _input_output_data) -> Union[List[str], None]:
342+
"""Try to get data bindings from input/output data, return None if not found."""
343+
# handle group input
344+
if GroupInput._is_group_attr_dict(_input_output_data):
345+
# flatten to avoid nested cases
346+
flattened_values = list(_input_output_data.flatten(_name).values())
347+
# handle invalid empty group
348+
if len(flattened_values) == 0:
349+
return None
350+
return [_value.path for _value in flattened_values]
351+
_input_output_data = _input_output_data._data
342352
if isinstance(_input_output_data, str):
343-
return _input_output_data
353+
return [_input_output_data]
344354
if not hasattr(_input_output_data, "_data_binding"):
345355
return None
346-
return _input_output_data._data_binding()
356+
return [_input_output_data._data_binding()]
347357

348358
_validate_job = self.jobs[_validate_job_name]
349359
# no input to validate job
350360
for _input_name in _validate_job.inputs:
351-
_data_binding = _try_get_data_binding(_validate_job.inputs[_input_name]._data)
352-
if _data_binding is not None and is_data_binding_expression(_data_binding, ["parent", "jobs"]):
353-
return False
354-
# no output from validate job
361+
_data_bindings = _try_get_data_bindings(_input_name, _validate_job.inputs[_input_name])
362+
if _data_bindings is None:
363+
continue
364+
for _data_binding in _data_bindings:
365+
if is_data_binding_expression(_data_binding, ["parent", "jobs"]):
366+
return False
367+
# no output from validate job - iterate other jobs input(s) to validate
355368
for _job_name, _job in self.jobs.items():
356369
# exclude control flow node as it does not have inputs
357370
if _is_control_flow_node(_job_name):
358371
continue
359372
for _input_name in _job.inputs:
360-
_data_binding = _try_get_data_binding(_job.inputs[_input_name]._data)
361-
if _data_binding is not None and is_data_binding_expression(
362-
_data_binding, ["parent", "jobs", _validate_job_name]
363-
):
364-
return False
373+
_data_bindings = _try_get_data_bindings(_input_name, _job.inputs[_input_name])
374+
if _data_bindings is None:
375+
continue
376+
for _data_binding in _data_bindings:
377+
if is_data_binding_expression(_data_binding, ["parent", "jobs", _validate_job_name]):
378+
return False
365379
return True
366380

367381
# validate on_init

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_group.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,6 @@ def my_pipeline() -> ParentOutputs:
619619
}
620620
assert "Nested group annotation is not supported in pipeline output." in str(e.value)
621621

622-
623-
624622
with pytest.raises(UserErrorException) as e:
625623
@group
626624
class GroupOutputs:

0 commit comments

Comments
 (0)