22# Copyright (c) Microsoft Corporation. All rights reserved.
33# ---------------------------------------------------------
44import json
5+ import os
56from typing import Dict , Union
67
7- from azure .ai .ml import Output
8+ from azure .ai .ml import Output , Input
89from azure .ai .ml ._schema import PathAwareSchema
910from azure .ai .ml ._schema .pipeline .control_flow_job import ParallelForSchema
1011from azure .ai .ml ._utils .utils import is_data_binding_expression
@@ -28,7 +29,7 @@ class ParallelFor(LoopNode, NodeIOMixin):
2829 :param body: Pipeline job for the parallel for loop body.
2930 :type body: Pipeline
3031 :param items: The loop body's input which will bind to the loop node.
31- :type items: Union[list, dict, str, PipelineInput, NodeOutput ]
32+ :type items: typing. Union[list, dict, str, NodeOutput, PipelineInput ]
3233 :param max_concurrency: Maximum number of concurrent iterations to run. All loop body nodes will be executed
3334 in parallel if not specified.
3435 :type max_concurrency: int
@@ -105,17 +106,92 @@ def _attr_type_map(cls) -> dict:
105106 "items" : (dict , list , str , PipelineInput , NodeOutput ),
106107 }
107108
109+ @classmethod
110+ def _to_rest_item (cls , item : dict ) -> dict :
111+ """Convert item to rest object."""
112+ primitive_inputs , asset_inputs = {}, {}
113+ # validate item
114+ for key , val in item .items ():
115+ if isinstance (val , Input ):
116+ asset_inputs [key ] = val
117+ elif isinstance (val , (PipelineInput , NodeOutput )):
118+ # convert binding object to string
119+ primitive_inputs [key ] = str (val )
120+ else :
121+ primitive_inputs [key ] = val
122+ return {
123+ # asset type inputs will be converted to JobInput dict:
124+ # {"asset_param": {"uri": "xxx", "job_input_type": "uri_file"}}
125+ ** cls ._input_entity_to_rest_inputs (input_entity = asset_inputs ),
126+ # primitive inputs has primitive type value like this
127+ # {"int_param": 1}
128+ ** primitive_inputs
129+ }
130+
131+ @classmethod
132+ def _to_rest_items (cls , items : Union [list , dict , str , NodeOutput , PipelineInput ]) -> str :
133+ """Convert items to rest object."""
134+ # validate items.
135+ cls ._validate_items (items = items , raise_error = True , body_component = None )
136+ # convert items to rest object
137+ if isinstance (items , list ):
138+ rest_items = [cls ._to_rest_item (item = i ) for i in items ]
139+ rest_items = json .dumps (rest_items )
140+ elif isinstance (items , dict ):
141+ rest_items = {k : cls ._to_rest_item (item = v ) for k , v in items .items ()}
142+ rest_items = json .dumps (rest_items )
143+ elif isinstance (items , (NodeOutput , PipelineInput )):
144+ rest_items = str (items )
145+ elif isinstance (items , str ):
146+ rest_items = items
147+ else :
148+ raise UserErrorException ("Unsupported items type: {}" .format (type (items )))
149+ return rest_items
150+
108151 def _to_rest_object (self , ** kwargs ) -> dict : # pylint: disable=unused-argument
109152 """Convert self to a rest object for remote call."""
110153 rest_node = super (ParallelFor , self )._to_rest_object (** kwargs )
111- rest_node .update (dict (outputs = self ._to_rest_outputs ()))
154+ # convert items to rest object
155+ rest_items = self ._to_rest_items (items = self .items )
156+ rest_node .update (dict (
157+ items = rest_items ,
158+ outputs = self ._to_rest_outputs ()
159+ ))
112160 return convert_ordered_dict_to_dict (rest_node )
113161
162+ @classmethod
163+ def _from_rest_item (cls , rest_item ):
164+ """Convert rest item to item."""
165+ primitive_inputs , asset_inputs = {}, {}
166+ for key , val in rest_item .items ():
167+ if isinstance (val , dict ) and val .get ("job_input_type" ):
168+ asset_inputs [key ] = val
169+ else :
170+ primitive_inputs [key ] = val
171+ return {
172+ ** cls ._from_rest_inputs (inputs = asset_inputs ),
173+ ** primitive_inputs
174+ }
175+
176+ @classmethod
177+ def _from_rest_items (cls , rest_items : str ) -> Union [dict , list , str ]:
178+ """Convert items from rest object."""
179+ try :
180+ items = json .loads (rest_items )
181+ except json .JSONDecodeError :
182+ # return original items when failed to load
183+ return rest_items
184+ if isinstance (items , list ):
185+ return [cls ._from_rest_item (rest_item = i ) for i in items ]
186+ if isinstance (items , dict ):
187+ return {k : cls ._from_rest_item (rest_item = v ) for k , v in items .items ()}
188+ return rest_items
189+
114190 @classmethod
115191 def _from_rest_object (cls , obj : dict , pipeline_jobs : dict ) -> "ParallelFor" :
116192 # pylint: disable=protected-access
117-
118193 obj = BaseNode ._from_rest_object_to_init_params (obj )
194+ obj ["items" ] = cls ._from_rest_items (rest_items = obj .get ("items" , "" ))
119195 return cls ._create_instance_from_schema_dict (pipeline_jobs = pipeline_jobs , loaded_data = obj )
120196
121197 @classmethod
@@ -149,11 +225,21 @@ def _convert_output_meta(self, outputs):
149225 aggregate_outputs [name ] = resolved_output
150226 return aggregate_outputs
151227
152- def _validate_items (self , raise_error = True ):
153- validation_result = self ._create_empty_validation_result ()
154- if self .items is not None :
155- items = self .items
228+ def _customized_validate (self ):
229+ """Customized validation for parallel for node."""
230+ # pylint: disable=protected-access
231+ validation_result = self ._validate_body (raise_error = False )
232+ validation_result .merge_with (
233+ self ._validate_items (items = self .items , raise_error = False , body_component = self .body ._component )
234+ )
235+ return validation_result
236+
237+ @classmethod
238+ def _validate_items (cls , items , raise_error = True , body_component = None ):
239+ validation_result = cls ._create_empty_validation_result ()
240+ if items is not None :
156241 if isinstance (items , str ):
242+ # TODO: remove the validation
157243 # try to deserialize str if it's a json string
158244 try :
159245 items = json .loads (items )
@@ -168,7 +254,7 @@ def _validate_items(self, raise_error=True):
168254 items = list (items .values ())
169255 if isinstance (items , list ):
170256 if len (items ) > 0 :
171- self ._validate_items_list (items , validation_result )
257+ cls ._validate_items_list (items , validation_result , body_component = body_component )
172258 else :
173259 validation_result .append_error (
174260 yaml_path = "items" ,
@@ -179,17 +265,12 @@ def _validate_items(self, raise_error=True):
179265 message = "Items is required for parallel_for node" ,
180266 )
181267 return validation_result .try_raise (
182- self ._get_validation_error_target (),
268+ cls ._get_validation_error_target (),
183269 raise_error = raise_error ,
184270 )
185271
186- def _customized_validate (self ):
187- """Customized validation for parallel for node."""
188- validation_result = self ._validate_body (raise_error = False )
189- validation_result .merge_with (self ._validate_items (raise_error = False ))
190- return validation_result
191-
192- def _validate_items_list (self , items : list , validation_result ):
272+ @classmethod
273+ def _validate_items_list (cls , items : list , validation_result , body_component = None ):
193274 # pylint: disable=protected-access
194275 meta = {}
195276 # all items have to be dict and have matched meta
@@ -213,10 +294,41 @@ def _validate_items_list(self, items: list, validation_result):
213294 message = msg
214295 )
215296 # items' keys should appear in body's inputs
216- body_component = self .body ._component
217297 if isinstance (body_component , Component ) and (not item .keys () <= body_component .inputs .keys ()):
218298 msg = f"Item { item } got unmatched inputs with loop body component inputs { body_component .inputs } ."
219299 validation_result .append_error (
220300 yaml_path = "items" ,
221301 message = msg
222302 )
303+ # validate item value type
304+ cls ._validate_item_value_type (item = item , validation_result = validation_result )
305+
306+ @classmethod
307+ def _validate_item_value_type (cls , item : dict , validation_result ):
308+ # pylint: disable=protected-access
309+ supported_types = (Input , str , bool , int , float , PipelineInput )
310+ for _ , val in item .items ():
311+ if not isinstance (val , supported_types ):
312+ validation_result .append_error (
313+ yaml_path = "items" ,
314+ message = "Unsupported type {} in parallel_for items. Supported types are: {}" .format (
315+ type (val ), supported_types
316+ )
317+ )
318+ if isinstance (val , Input ):
319+ cls ._validate_input_item_value (entry = val , validation_result = validation_result )
320+
321+ @classmethod
322+ def _validate_input_item_value (cls , entry : Input , validation_result ):
323+ if not isinstance (entry , Input ):
324+ return
325+ if not entry .path :
326+ validation_result .append_error (
327+ yaml_path = "items" ,
328+ message = f"Input path not provided for { entry } ." ,
329+ )
330+ if isinstance (entry .path , str ) and os .path .exists (entry .path ):
331+ validation_result .append_error (
332+ yaml_path = "items" ,
333+ message = f"Local file input { entry } is not supported, please create it as a dataset." ,
334+ )
0 commit comments