Skip to content

Commit ae81f89

Browse files
authored
[Do_while] Fix pipeline component with do-while operator cannot be submitted. (Azure#27063)
* remove empty mapping check * fix error * fix test case * fix comment
1 parent 28d53ac commit ae81f89

File tree

6 files changed

+4718
-67
lines changed

6 files changed

+4718
-67
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/do_while.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def get_port_obj(body, port_name, is_input=True, validate_port=True):
101101
port = body.inputs.get(port_name, None)
102102
else:
103103
port = body.outputs.get(port_name, None)
104-
if not port:
104+
if port is None:
105105
if validate_port:
106106
raise ValidationError(
107107
message=f"Cannot find {port_name} in do_while loop body {'inputs' if is_input else 'outputs'}.",
@@ -229,7 +229,7 @@ def _validate_loop_condition(self, raise_error=True):
229229
if validation_result.passed:
230230
# Check condition is a control output.
231231
condition_name = self.condition if isinstance(self.condition, str) else self.condition._name
232-
if not self.body.component.outputs[condition_name].is_control:
232+
if not self.body._outputs[condition_name].is_control:
233233
validation_result.append_error(
234234
yaml_path="condition",
235235
message=(
@@ -251,18 +251,14 @@ def _validate_do_while_limit(self, raise_error=True):
251251
elif self.limits.max_iteration_count > DO_WHILE_MAX_ITERATION or self.limits.max_iteration_count < 0:
252252
validation_result.append_error(
253253
yaml_path="limit.max_iteration_count",
254-
message=f"The max iteration count cannot be less than 0 and larger than {DO_WHILE_MAX_ITERATION}.",
254+
message=f"The max iteration count cannot be less than 0 or larger than {DO_WHILE_MAX_ITERATION}.",
255255
)
256256
return validation_result.try_raise(self._get_validation_error_target(), raise_error=raise_error)
257257

258258
def _validate_body_output_mapping(self, raise_error=True):
259259
# pylint disable=protected-access
260260
validation_result = self._create_empty_validation_result()
261-
if not self.mapping:
262-
validation_result.append_error(
263-
yaml_path="mapping", message="The mapping of body output to input cannot be empty."
264-
)
265-
elif not isinstance(self.mapping, dict):
261+
if not isinstance(self.mapping, dict):
266262
validation_result.append_error(
267263
yaml_path="mapping", message=f"Mapping expects a dict type but passes in a {type(self.mapping)} type."
268264
)
@@ -277,7 +273,7 @@ def _validate_body_output_mapping(self, raise_error=True):
277273
output, self.body.outputs, port_type="output", yaml_path="mapping"
278274
)
279275
if validate_results.passed:
280-
is_control_output = self.body.component.outputs[output_name].is_control
276+
is_control_output = self.body._outputs[output_name].is_control
281277
inputs = inputs if isinstance(inputs, list) else [inputs]
282278
for item in inputs:
283279
input_validate_results = self._validate_port(
@@ -287,11 +283,12 @@ def _validate_body_output_mapping(self, raise_error=True):
287283
# pylint: disable=protected-access
288284
input_name = item if isinstance(item, str) else item._name
289285
input_output_mapping[input_name] = input_output_mapping.get(input_name, []) + [output_name]
286+
is_primitive_type = self.body._inputs[input_name]._meta._is_primitive_type
290287

291288
if (
292289
input_validate_results.passed
293290
and not is_control_output
294-
and self.body.component.inputs[input_name]._is_primitive_type # pylint: disable=protected-access
291+
and is_primitive_type # pylint: disable=protected-access
295292
):
296293
validate_results.append_error(
297294
yaml_path="mapping",
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
from typing import Callable
3+
from devtools_testutils import AzureRecordedTestCase
4+
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD
5+
6+
from azure.ai.ml import MLClient, load_job
7+
from azure.ai.ml._utils.utils import load_yaml
8+
from azure.ai.ml._schema.pipeline import pipeline_job
9+
from azure.ai.ml.entities._builders import Command, Pipeline
10+
from azure.ai.ml.entities._builders.do_while import DoWhile
11+
12+
from .._util import _PIPELINE_JOB_TIMEOUT_SECOND
13+
from .test_pipeline_job import assert_job_cancel
14+
15+
16+
@pytest.fixture()
17+
def update_pipeline_schema():
18+
# Update the job type that the pipeline is supported.
19+
schema = pipeline_job.PipelineJobSchema
20+
schema._declared_fields['jobs'] = pipeline_job.PipelineJobsField()
21+
22+
23+
@pytest.mark.usefixtures(
24+
"recorded_test",
25+
"mock_code_hash",
26+
"enable_pipeline_private_preview_features",
27+
"update_pipeline_schema",
28+
"mock_asset_name",
29+
"mock_component_hash",
30+
)
31+
@pytest.mark.timeout(timeout=_PIPELINE_JOB_TIMEOUT_SECOND, method=_PYTEST_TIMEOUT_METHOD)
32+
@pytest.mark.e2etest
33+
@pytest.mark.pipeline_test
34+
class TestConditionalNodeInPipeline(AzureRecordedTestCase):
35+
def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
36+
params_override = [{"name": randstr('name')}]
37+
pipeline_job = load_job(
38+
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline.yml",
39+
params_override=params_override,
40+
)
41+
created_pipeline = assert_job_cancel(pipeline_job, client)
42+
assert len(created_pipeline.jobs) == 5
43+
assert isinstance(created_pipeline.jobs["pipeline_body_node"], Pipeline)
44+
assert isinstance(created_pipeline.jobs["do_while_job_with_pipeline_job"], DoWhile)
45+
assert isinstance(created_pipeline.jobs["do_while_job_with_command_component"], DoWhile)
46+
assert isinstance(created_pipeline.jobs["command_component_body_node"], Command)
47+
assert isinstance(created_pipeline.jobs["get_do_while_result"], Command)
48+
49+
def test_do_while_pipeline_with_primitive_inputs(self, client: MLClient, randstr: Callable[[], str]) -> None:
50+
params_override = [{"name": randstr('name')}]
51+
pipeline_job = load_job(
52+
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline_with_primitive_inputs.yml",
53+
params_override=params_override,
54+
)
55+
created_pipeline = assert_job_cancel(pipeline_job, client)
56+
assert len(created_pipeline.jobs) == 5
57+
assert isinstance(created_pipeline.jobs["pipeline_body_node"], Pipeline)
58+
assert isinstance(created_pipeline.jobs["do_while_job_with_pipeline_job"], DoWhile)
59+
assert isinstance(created_pipeline.jobs["do_while_job_with_command_component"], DoWhile)
60+
assert isinstance(created_pipeline.jobs["command_component_body_node"], Command)
61+
assert isinstance(created_pipeline.jobs["get_do_while_result"], Command)

sdk/ml/azure-ai-ml/tests/pipeline_job/e2etests/test_pipeline_job.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,63 +1416,6 @@ def test_pipeline_with_pipeline_component(self, client: MLClient, randstr: Calla
14161416
"_source": "YAML.JOB",
14171417
}
14181418

1419-
@pytest.mark.skip(reason="Currently do_while only enable in master region.")
1420-
def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
1421-
params_override = [{"name": randstr()}]
1422-
pipeline_job = load_job(
1423-
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline.yml",
1424-
params_override=params_override,
1425-
)
1426-
created_pipeline = assert_job_cancel(pipeline_job, client)
1427-
assert len(created_pipeline.jobs) == 5
1428-
assert isinstance(created_pipeline.jobs["pipeline_body_node"], Pipeline)
1429-
assert isinstance(created_pipeline.jobs["do_while_job_with_pipeline_job"], DoWhile)
1430-
assert isinstance(created_pipeline.jobs["do_while_job_with_command_component"], DoWhile)
1431-
assert isinstance(created_pipeline.jobs["command_component_body_node"], Command)
1432-
assert isinstance(created_pipeline.jobs["get_do_while_result"], Command)
1433-
1434-
@pytest.mark.skip(reason="Currently not enable submit a pipeline with primitive inputs")
1435-
def test_do_while_pipeline_with_primitive_inputs(self, client: MLClient, randstr: Callable[[], str]) -> None:
1436-
params_override = [{"name": randstr()}]
1437-
pipeline_job = load_job(
1438-
path="./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline_with_primitive_inputs.yml",
1439-
params_override=params_override,
1440-
)
1441-
created_pipeline = assert_job_cancel(pipeline_job, client)
1442-
assert len(created_pipeline.jobs) == 5
1443-
assert isinstance(created_pipeline.jobs["pipeline_body_node"], Pipeline)
1444-
assert isinstance(created_pipeline.jobs["do_while_job_with_pipeline_job"], DoWhile)
1445-
assert isinstance(created_pipeline.jobs["do_while_job_with_command_component"], DoWhile)
1446-
assert isinstance(created_pipeline.jobs["command_component_body_node"], Command)
1447-
assert isinstance(created_pipeline.jobs["get_do_while_result"], Command)
1448-
1449-
@pytest.mark.skip(reason="Currently do_while only enable in master region.")
1450-
def test_pipeline_with_invalid_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
1451-
params_override = [{"name": randstr()}]
1452-
with pytest.raises(ValidationError) as exception:
1453-
load_job(
1454-
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/invalid_pipeline.yml",
1455-
params_override=params_override,
1456-
)
1457-
error_message_str = re.findall(r"(\{.*\})", exception.value.args[0].replace("\n", ""))[0]
1458-
error_messages = json.loads(error_message_str.replace("\\", "\\\\"))
1459-
1460-
def assert_error_message(path, except_message, error_messages):
1461-
msgs = next(filter(lambda item: item["path"] == path, error_messages))
1462-
assert except_message == msgs["message"]
1463-
1464-
assert_error_message("jobs.empty_mapping.mapping", "Missing data for required field.", error_messages["errors"])
1465-
assert_error_message(
1466-
"jobs.out_of_range_max_iteration_count.limits.max_iteration_count",
1467-
"Must be greater than or equal to 1 and less than or equal to 1000.",
1468-
error_messages["errors"],
1469-
)
1470-
assert_error_message(
1471-
"jobs.invalid_max_iteration_count.limits.max_iteration_count",
1472-
"Not a valid integer.",
1473-
error_messages["errors"],
1474-
)
1475-
14761419
def test_pipeline_component_job(self, client: MLClient):
14771420
test_path = "./tests/test_configs/pipeline_jobs/pipeline_component_job.yml"
14781421
job: PipelineJob = load_job(source=test_path)

sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_validate.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
import json
13
from pathlib import Path
24
from unittest.mock import patch
35

@@ -12,6 +14,7 @@
1214
from azure.ai.ml.exceptions import ValidationException
1315

1416
from .._util import _PIPELINE_JOB_TIMEOUT_SECOND
17+
from ..e2etests.test_control_flow_node_in_pipeline_job import update_pipeline_schema
1518

1619

1720
def assert_the_same_path(actual_path, expected_path):
@@ -637,3 +640,31 @@ def pipeline_with_compute_binding(compute_name: str):
637640
pipeline_job = pipeline_with_compute_binding('cpu-cluster')
638641
# Assert compute binding validate not raise error when validate
639642
assert pipeline_job._validate().passed
643+
644+
@pytest.mark.usefixtures(
645+
"enable_pipeline_private_preview_features",
646+
"update_pipeline_schema"
647+
)
648+
def test_pipeline_with_invalid_do_while_node(self) -> None:
649+
with pytest.raises(ValidationError) as exception:
650+
load_job(
651+
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/invalid_pipeline.yml",
652+
)
653+
error_message_str = re.findall(r"(\{.*\})", exception.value.args[0].replace("\n", ""))[0]
654+
error_messages = json.loads(error_message_str.replace("\\", "\\\\"))
655+
656+
def assert_error_message(path, except_message, error_messages):
657+
msgs = next(filter(lambda item: item["path"] == path, error_messages))
658+
assert except_message == msgs["message"]
659+
660+
assert_error_message("jobs.empty_mapping.mapping", "Missing data for required field.", error_messages["errors"])
661+
assert_error_message(
662+
"jobs.out_of_range_max_iteration_count.limits.max_iteration_count",
663+
"Must be greater than or equal to 1 and less than or equal to 1000.",
664+
error_messages["errors"],
665+
)
666+
assert_error_message(
667+
"jobs.invalid_max_iteration_count.limits.max_iteration_count",
668+
"Not a valid integer.",
669+
error_messages["errors"],
670+
)

0 commit comments

Comments
 (0)