Skip to content

Commit f84e2d4

Browse files
authored
[ML][Pipelines] fix: pipeline job schema validation fail for data-binding expression (Azure#28517)
* fix: pipeline job schema validation fail for data-binding expression * fix: test public json schema
1 parent 1e66bf7 commit f84e2d4

File tree

12 files changed

+271
-59
lines changed

12 files changed

+271
-59
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
455455
resolve_field_instance(cls_or_instance)
456456
for cls_or_instance in union_fields
457457
]
458+
# TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
458459
self.is_strict = is_strict # S\When True, combine fields with oneOf instead of anyOf at schema generation
459460
except FieldInstanceResolutionError as error:
460461
raise ValueError(

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_fields_provider.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from marshmallow import fields
66

7+
from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
78
from azure.ai.ml._schema.core.fields import NestedField, PrimitiveValueField, UnionField
89
from azure.ai.ml._schema.job.input_output_entry import (
910
DataInputSchema,
@@ -14,20 +15,30 @@
1415
)
1516

1617

17-
def InputsField(**kwargs):
18+
def InputsField(*, support_databinding: bool = False, **kwargs):
19+
value_fields = [
20+
NestedField(DataInputSchema),
21+
NestedField(ModelInputSchema),
22+
NestedField(MLTableInputSchema),
23+
NestedField(InputLiteralValueSchema),
24+
PrimitiveValueField(is_strict=False),
25+
# This ordering of types for the values keyword is intentional. The ordering of types
26+
# determines what order schema values are matched and cast in. Changing the current ordering can
27+
# result in values being mis-cast such as 1.0 translating into True.
28+
]
29+
30+
# As is_strict is set to True, 1 and only 1 value field must be matched.
31+
# root level data-binding expression has already been covered by PrimitiveValueField;
32+
# If support_databinding is True, we should only add data-binding expression support for nested fields.
33+
if support_databinding:
34+
for field_obj in value_fields:
35+
if isinstance(field_obj, NestedField):
36+
support_data_binding_expression_for_fields(field_obj.schema)
37+
1838
return fields.Dict(
1939
keys=fields.Str(),
2040
values=UnionField(
21-
[
22-
NestedField(DataInputSchema),
23-
NestedField(ModelInputSchema),
24-
NestedField(MLTableInputSchema),
25-
NestedField(InputLiteralValueSchema),
26-
PrimitiveValueField(is_strict=False),
27-
# This ordering of types for the values keyword is intentional. The ordering of types
28-
# determines what order schema values are matched and cast in. Changing the current ordering can
29-
# result in values being mis-cast such as 1.0 translating into True.
30-
],
41+
value_fields,
3142
metadata={"description": "Inputs to a job."},
3243
is_strict=True,
3344
**kwargs

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
class BaseNodeSchema(PathAwareSchema):
4949
unknown = INCLUDE
5050

51-
inputs = InputsField()
51+
inputs = InputsField(support_databinding=True)
5252
outputs = fields.Dict(
5353
keys=fields.Str(),
5454
values=UnionField([OutputBindingStr, NestedField(OutputSchema)], allow_none=True),
@@ -61,7 +61,7 @@ def __init__(self, *args, **kwargs):
6161
# data binding expression is not supported inside component field, while validation error
6262
# message will be very long when component is an object as error message will include
6363
# str(component), so just add component to skip list. The same to trial in Sweep.
64-
support_data_binding_expression_for_fields(self, ["type", "component", "trial"])
64+
support_data_binding_expression_for_fields(self, ["type", "component", "trial", "inputs"])
6565

6666
@post_dump(pass_original=True)
6767
def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint: disable=unused-argument

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_job.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from marshmallow import INCLUDE, ValidationError, post_load, pre_dump, pre_load
1010

11-
from azure.ai.ml._schema._utils.data_binding_expression import _add_data_binding_to_field
1211
from azure.ai.ml._schema.core.fields import (
1312
ArmVersionedStr,
1413
ComputeField,
@@ -36,7 +35,8 @@ class PipelineJobSchema(BaseJobSchema):
3635
type = StringTransformedEnum(allowed_values=[JobType.PIPELINE])
3736
compute = ComputeField()
3837
settings = NestedField(PipelineJobSettingsSchema, unknown=INCLUDE)
39-
inputs = InputsField()
38+
# Support databinding in inputs as we support macro like ${{name}}
39+
inputs = InputsField(support_databinding=True)
4040
outputs = OutputsField()
4141
jobs = PipelineJobsField()
4242
component = UnionField(
@@ -50,11 +50,6 @@ class PipelineJobSchema(BaseJobSchema):
5050
],
5151
)
5252

53-
def __init__(self, *args, **kwargs):
54-
super().__init__(*args, **kwargs)
55-
# Support databinding in inputs as we support macro like ${{name}}
56-
_add_data_binding_to_field(self.load_fields["inputs"], [], [])
57-
5853
@pre_dump()
5954
def backup_jobs_and_remove_component(self, job, **kwargs):
6055
# pylint: disable=protected-access

sdk/ml/azure-ai-ml/tests/conftest.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from azure.core.pipeline.transport import HttpTransport
4949
from azure.identity import AzureCliCredential, ClientSecretCredential, DefaultAzureCredential
5050

51+
from test_utilities.utils import reload_schema_for_nodes_in_pipeline_job
52+
5153
E2E_TEST_LOGGING_ENABLED = "E2E_TEST_LOGGING_ENABLED"
5254
test_folder = Path(os.path.abspath(__file__)).parent.absolute()
5355

@@ -840,14 +842,28 @@ def pytest_configure(config):
840842

841843
@pytest.fixture()
842844
def enable_private_preview_pipeline_node_types():
843-
# Update the node types in pipeline jobs to include the private preview node types
844-
from azure.ai.ml._schema.pipeline import pipeline_job
845+
with reload_schema_for_nodes_in_pipeline_job():
846+
yield
845847

846-
schema = pipeline_job.PipelineJobSchema
847-
original_jobs = schema._declared_fields["jobs"]
848-
schema._declared_fields["jobs"] = pipeline_job.PipelineJobsField()
849848

850-
try:
849+
@pytest.fixture()
850+
def disable_internal_components():
851+
"""Some global changes are made in enable_internal_components, so we need to explicitly disable it.
852+
It's not recommended to use this fixture along with other related fixtures like enable_internal_components
853+
and enable_private_preview_features, as the execution order of fixtures is not guaranteed.
854+
"""
855+
from azure.ai.ml._internal._schema.component import NodeType
856+
from azure.ai.ml._internal._util import _set_registered
857+
from azure.ai.ml.entities._component.component_factory import component_factory
858+
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
859+
860+
for _type in NodeType.all_values():
861+
pipeline_node_factory._create_instance_funcs.pop(_type, None) # pylint: disable=protected-access
862+
pipeline_node_factory._load_from_rest_object_funcs.pop(_type, None) # pylint: disable=protected-access
863+
component_factory._create_instance_funcs.pop(_type, None) # pylint: disable=protected-access
864+
component_factory._create_schema_funcs.pop(_type, None) # pylint: disable=protected-access
865+
866+
_set_registered(False)
867+
868+
with reload_schema_for_nodes_in_pipeline_job(revert_after_yield=False):
851869
yield
852-
finally:
853-
schema._declared_fields["jobs"] = original_jobs

sdk/ml/azure-ai-ml/tests/internal/_utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,3 @@ def extract_non_primitive(obj):
282282
if isinstance(obj, (float, int, str)):
283283
return None
284284
return obj
285-
286-
287-
def unregister_internal_components():
288-
from azure.ai.ml._internal._schema.component import NodeType
289-
from azure.ai.ml._internal._util import _set_registered
290-
from azure.ai.ml.entities._component.component_factory import component_factory
291-
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
292-
293-
for _type in NodeType.all_values():
294-
pipeline_node_factory._create_instance_funcs.pop(_type, None) # pylint: disable=protected-access
295-
pipeline_node_factory._load_from_rest_object_funcs.pop(_type, None) # pylint: disable=protected-access
296-
component_factory._create_instance_funcs.pop(_type, None) # pylint: disable=protected-access
297-
component_factory._create_schema_funcs.pop(_type, None) # pylint: disable=protected-access
298-
299-
_set_registered(False)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
import pytest
5+
from azure.ai.ml.constants._common import AZUREML_INTERNAL_COMPONENTS_ENV_VAR
6+
from azure.ai.ml.dsl._utils import environment_variable_overwrite
7+
from azure.ai.ml.exceptions import ValidationException
8+
9+
10+
@pytest.mark.usefixtures("disable_internal_components")
11+
@pytest.mark.unittest
12+
@pytest.mark.pipeline_test
13+
class TestInternalDisabled:
14+
def test_load_pipeline_job_with_internal_nodes_from_rest(self):
15+
# this is a simplified test case which avoid constructing a complete pipeline job rest object
16+
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
17+
18+
internal_node_type = "CommandComponent"
19+
with environment_variable_overwrite(AZUREML_INTERNAL_COMPONENTS_ENV_VAR, "False"):
20+
with pytest.raises(ValidationException, match=f"Unsupported component type: {internal_node_type}."):
21+
pipeline_node_factory.get_load_from_rest_object_func(internal_node_type)
22+
23+
with environment_variable_overwrite(AZUREML_INTERNAL_COMPONENTS_ENV_VAR, "True"):
24+
pipeline_node_factory.get_load_from_rest_object_func(internal_node_type)

sdk/ml/azure-ai-ml/tests/internal/unittests/test_pipeline_job.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,10 @@
3232
TargetSelector,
3333
)
3434
from azure.ai.ml._internal.entities import InternalBaseNode, InternalComponent, Scope
35-
from azure.ai.ml.constants._common import AZUREML_INTERNAL_COMPONENTS_ENV_VAR, AssetTypes
35+
from azure.ai.ml.constants._common import AssetTypes
3636
from azure.ai.ml.constants._job.job import JobComputePropertyFields
3737
from azure.ai.ml.dsl import pipeline
38-
from azure.ai.ml.dsl._utils import environment_variable_overwrite
3938
from azure.ai.ml.entities import CommandComponent, Data, PipelineJob
40-
from azure.ai.ml.exceptions import ValidationException
4139

4240
from .._utils import (
4341
DATA_VERSION,
@@ -46,7 +44,6 @@
4644
extract_non_primitive,
4745
get_expected_runsettings_items,
4846
set_run_settings,
49-
unregister_internal_components,
5047
)
5148

5249

@@ -602,15 +599,3 @@ def test_job_properties(self):
602599
assert len(node_dict["properties"]) == 1
603600
assert "AZURE_ML_PathOnCompute_" in list(node_dict["properties"].keys())[0]
604601
assert node_dict["properties"] == rest_node_dict["properties"]
605-
606-
def test_load_pipeline_job_with_internal_nodes_from_rest(self):
607-
# this is a simplified test case which avoid constructing a complete pipeline job rest object
608-
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
609-
610-
unregister_internal_components()
611-
internal_node_type = "CommandComponent"
612-
with environment_variable_overwrite(AZUREML_INTERNAL_COMPONENTS_ENV_VAR, "False"):
613-
with pytest.raises(ValidationException, match=f"Unsupported component type: {internal_node_type}."):
614-
pipeline_node_factory.get_load_from_rest_object_func(internal_node_type)
615-
616-
pipeline_node_factory.get_load_from_rest_object_func(internal_node_type)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from pathlib import Path
2+
import pytest
3+
import yaml
4+
from jsonschema.validators import validate
5+
from azure.ai.ml.entities import PipelineJob
6+
from test_utilities.json_schema import PatchedJSONSchema
7+
8+
from .._util import _PIPELINE_JOB_TIMEOUT_SECOND
9+
10+
11+
# schema of nodes will be reloaded with private preview features disabled in unregister_internal_components
12+
@pytest.mark.usefixtures("disable_internal_components")
13+
@pytest.mark.timeout(_PIPELINE_JOB_TIMEOUT_SECOND)
14+
@pytest.mark.unittest
15+
@pytest.mark.pipeline_test
16+
class TestPrivatePreviewDisabled:
17+
def test_public_json_schema(self):
18+
# public json schema is the json schema to be saved in
19+
# https://azuremlschemas.azureedge.net/latest/pipelineJob.schema.json
20+
base_dir = Path("./tests/test_configs/pipeline_jobs/json_schema_validation")
21+
target_schema = PatchedJSONSchema().dump(PipelineJob._create_schema_for_validation(context={"base_path": "./"}))
22+
23+
with open(base_dir.joinpath("component_spec.yaml"), "r") as f:
24+
yaml_data = yaml.safe_load(f.read())
25+
26+
validate(yaml_data, target_schema)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
$schema: https://azuremlschemas.azureedge.net/latest/pipelineJob.schema.json
2+
type: pipeline
3+
4+
settings:
5+
default_compute: azureml:cpu-cluster
6+
7+
8+
jobs:
9+
process_data:
10+
type: command
11+
component: ./1process_data_component.yaml
12+
inputs:
13+
raw_data: ${{parent.inputs.raw_data}}
14+
outputs:
15+
train:
16+
validation:
17+
compute: azureml:cpu-cluster

0 commit comments

Comments
 (0)