
Commit dffd1a6

[ML][Pipelines] feat: support concurrent artifact downloading (Azure#29597)
* feat: support concurrent artifact downloading
* refactor: unify cache folder
* feat: enable concurrent artifact uploading within a component
* fix: typo
* feat: resolve all leaf nodes concurrently
* fix: handle loop node
* revert: revert leaf node first optimization
* fix: pylint
1 parent 3bc5595 commit dffd1a6

File tree

5 files changed: +110 −79 lines changed

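All of the concurrency introduced in this commit follows one gating pattern: fan the work out on a ThreadPoolExecutor only when there is more than one item to process and both the concurrent-registration and private-preview flags are on; otherwise fall back to a plain map. A minimal sketch of the pattern (the flag helpers are stubbed here; the SDK imports the real ones from azure.ai.ml._utils.utils, and max_workers is an illustrative default):

from concurrent.futures import ThreadPoolExecutor

def is_concurrent_component_registration_enabled() -> bool:
    return True  # stand-in; the real helper reads a feature flag

def is_private_preview_enabled() -> bool:
    return True  # stand-in, same caveat

def resolve_all(items, resolve_one, max_workers=4):
    """Apply resolve_one to every item, concurrently when the flags allow."""
    if len(items) > 1 and is_concurrent_component_registration_enabled() and is_private_preview_enabled():
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # list() drains the lazy iterator so worker exceptions surface here
            return list(executor.map(resolve_one, items))
    return list(map(resolve_one, items))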

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_additional_includes.py

Lines changed: 68 additions & 43 deletions
@@ -6,12 +6,16 @@
 import shutil
 import tempfile
 import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import cpu_count
 from pathlib import Path
 from typing import Union
 
 import yaml
 
+from ..._artifacts._constants import PROCESSES_PER_CORE
 from ..._utils._asset_utils import IgnoreFile, traverse_directory
+from ..._utils.utils import is_concurrent_component_registration_enabled, is_private_preview_enabled
 from ...entities._util import _general_copy
 from ...entities._validation import MutableValidationResult, _ValidationResultBuilder
 from ._artifact_cache import ArtifactCache
@@ -269,6 +273,50 @@ def _artifact_validate_result(self):
         self._load_artifact_additional_includes()
         return self.__artifact_validate_result
 
+    @classmethod
+    def merge_local_path_to_additional_includes(cls, local_path, config_info, conflict_files):
+        file_name = Path(local_path).name
+        conflicts = conflict_files.get(file_name, set())
+        conflicts.add(config_info)
+        conflict_files[file_name] = conflicts
+
+    @classmethod
+    def _get_artifacts_by_config(cls, artifact_config):
+        artifact_cache = ArtifactCache()
+        if any(item not in artifact_config for item in ["feed", "name", "version"]):
+            raise RuntimeError("Feed, name and version are required for artifacts config.")
+        return artifact_cache.get(
+            organization=artifact_config.get("organization", None),
+            project=artifact_config.get("project", None),
+            feed=artifact_config["feed"],
+            name=artifact_config["name"],
+            version=artifact_config["version"],
+            scope=artifact_config.get("scope", "organization"),
+            resolve=True,
+        )
+
+    def _resolve_additional_include_config(self, additional_include_config):
+        result = []
+        if isinstance(additional_include_config, dict) and additional_include_config.get("type") == ARTIFACT_KEY:
+            try:
+                # Get the artifacts package from devops to the local
+                artifact_path = self._get_artifacts_by_config(additional_include_config)
+                for item in os.listdir(artifact_path):
+                    config_info = (
+                        f"{additional_include_config['name']}:{additional_include_config['version']} in "
+                        f"{additional_include_config['feed']}"
+                    )
+                    result.append((os.path.join(artifact_path, item), config_info))
+            except Exception as e:  # pylint: disable=broad-except
+                self._artifact_validate_result.append_error(message=e.args[0])
+        elif isinstance(additional_include_config, str):
+            result.append((additional_include_config, additional_include_config))
+        else:
+            self._artifact_validate_result.append_error(
+                message=f"Unexpected format in additional_includes, {additional_include_config}"
+            )
+        return result
+
     def _load_artifact_additional_includes(self):
         """
         Load the additional includes by yaml format.
@@ -290,57 +338,34 @@ def _load_artifact_additional_includes(self):
         :return additional_includes: Path list of additional_includes
         :rtype additional_includes: List[str]
         """
-        additional_includes, conflict_files = [], {}
         self.__artifact_validate_result = _ValidationResultBuilder.success()
 
-        def merge_local_path_to_additional_includes(local_path, config_info):
-            additional_includes.append(local_path)
-            file_name = Path(local_path).name
-            conflicts = conflict_files.get(file_name, set())
-            conflicts.add(config_info)
-            conflict_files[file_name] = conflicts
-
-        def get_artifacts_by_config(artifact_config):
-            artifact_cache = ArtifactCache()
-            if any(item not in artifact_config for item in ["feed", "name", "version"]):
-                raise RuntimeError("Feed, name and version are required for artifacts config.")
-            artifact_path = artifact_cache.get(
-                organization=artifact_config.get("organization", None),
-                project=artifact_config.get("project", None),
-                feed=artifact_config["feed"],
-                name=artifact_config["name"],
-                version=artifact_config["version"],
-                scope=artifact_config.get("scope", "organization"),
-                resolve=True,
-            )
-            return artifact_path
-
         # Load the artifacts config from additional_includes
         with open(self._additional_includes_file_path) as f:
            additional_includes_configs = yaml.safe_load(f)
            additional_includes_configs = additional_includes_configs.get(ADDITIONAL_INCLUDES_KEY, [])
 
-        for additional_include in additional_includes_configs:
-            if isinstance(additional_include, dict) and additional_include.get("type") == ARTIFACT_KEY:
-                try:
-                    # Get the artifacts package from devops to the local
-                    artifact_path = get_artifacts_by_config(additional_include)
-                    for item in os.listdir(artifact_path):
-                        config_info = (
-                            f"{additional_include['name']}:{additional_include['version']} in "
-                            f"{additional_include['feed']}"
-                        )
-                        merge_local_path_to_additional_includes(
-                            local_path=os.path.join(artifact_path, item), config_info=config_info
+        additional_includes, conflict_files = [], {}
+        num_threads = int(cpu_count()) * PROCESSES_PER_CORE
+        if (
+            len(additional_includes_configs) > 1
+            and is_concurrent_component_registration_enabled()
+            and is_private_preview_enabled()
+        ):
+            with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                for result in executor.map(self._resolve_additional_include_config, additional_includes_configs):
+                    for local_path, config_info in result:
+                        additional_includes.append(local_path)
+                        self.merge_local_path_to_additional_includes(
+                            local_path=local_path, config_info=config_info, conflict_files=conflict_files
                         )
-                except Exception as e:  # pylint: disable=broad-except
-                    self._artifact_validate_result.append_error(message=e.args[0])
-            elif isinstance(additional_include, str):
-                merge_local_path_to_additional_includes(local_path=additional_include, config_info=additional_include)
-            else:
-                self._artifact_validate_result.append_error(
-                    message=f"Unexpected format in additional_includes, {additional_include}"
-                )
+        else:
+            for result in map(self._resolve_additional_include_config, additional_includes_configs):
+                for local_path, config_info in result:
+                    additional_includes.append(local_path)
+                    self.merge_local_path_to_additional_includes(
+                        local_path=local_path, config_info=config_info, conflict_files=conflict_files
+                    )
 
         # Check the file conflict in local path and artifact package.
         conflict_files = {k: v for k, v in conflict_files.items() if len(v) > 1}
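A property the rewritten loop quietly relies on: executor.map yields results in the order of its inputs, not in completion order, so additional_includes and the conflict bookkeeping stay deterministic no matter which download finishes first. A standalone illustration (not SDK code):

import time
from concurrent.futures import ThreadPoolExecutor

def fetch(name: str) -> str:
    # make the first submission finish last to show order is still preserved
    time.sleep(0.2 if name == "a" else 0.0)
    return name.upper()

with ThreadPoolExecutor(max_workers=3) as executor:
    results = list(executor.map(fetch, ["a", "b", "c"]))

assert results == ["A", "B", "C"]  # input order, not completion order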

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_cache_utils.py

Lines changed: 34 additions & 21 deletions
@@ -10,22 +10,20 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from functools import partial
 from pathlib import Path
-from typing import List, Dict, Optional, Union, Callable
+from typing import Callable, Dict, List, Optional, Union
 
 from azure.ai.ml._utils._asset_utils import get_object_hash
 from azure.ai.ml._utils.utils import (
-    is_on_disk_cache_enabled,
     is_concurrent_component_registration_enabled,
+    is_on_disk_cache_enabled,
     is_private_preview_enabled,
     write_to_shared_file,
 )
-from azure.ai.ml.constants._common import AzureMLResourceType, AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS
+from azure.ai.ml.constants._common import AZUREML_COMPONENT_REGISTRATION_MAX_WORKERS, AzureMLResourceType
 from azure.ai.ml.entities import Component
 from azure.ai.ml.entities._builders import BaseNode
 
-
 logger = logging.getLogger(__name__)
 
 _ANONYMOUS_HASH_PREFIX = "anonymous-component-"
@@ -47,6 +45,9 @@ class _CacheContent:
     on_disk_hash: Optional[str] = None
     arm_id: Optional[str] = None
 
+    def update_on_disk_hash(self):
+        self.on_disk_hash = CachedNodeResolver.calc_on_disk_hash_for_component(self.component_ref, self.in_memory_hash)
+
 
 class CachedNodeResolver(object):
     """Class to resolve component in nodes with cached component resolution results.
@@ -176,7 +177,7 @@ def _get_in_memory_hash_for_component(component: Component) -> str:
         return _ANONYMOUS_HASH_PREFIX + component._get_anonymous_hash()  # pylint: disable=protected-access
 
     @staticmethod
-    def _get_on_disk_hash_for_component(component: Component, in_memory_hash: str) -> str:
+    def calc_on_disk_hash_for_component(component: Component, in_memory_hash: str) -> str:
         """Get a hash for a component.
 
         This function will calculate the hash based on the component's code folder if the component has code, so it's
@@ -261,21 +262,23 @@ def _save_to_on_disk_cache(self, on_disk_hash: str, arm_id: str) -> None:
 
     def _resolve_cache_contents(self, cache_contents_to_resolve: List[_CacheContent], resolver):
         """Resolve all components to resolve and save the results in cache."""
-        _components = list(map(lambda x: x.component_ref, cache_contents_to_resolve))
-        _map_func = partial(resolver, azureml_type=AzureMLResourceType.COMPONENT)
 
-        if len(_components) > 1 and is_concurrent_component_registration_enabled() and is_private_preview_enabled():
+        def _map_func(_cache_content: _CacheContent):
+            _cache_content.arm_id = resolver(_cache_content.component_ref, azureml_type=AzureMLResourceType.COMPONENT)
+            if is_on_disk_cache_enabled() and is_private_preview_enabled():
+                self._save_to_on_disk_cache(_cache_content.on_disk_hash, _cache_content.arm_id)
+
+        if (
+            len(cache_contents_to_resolve) > 1
+            and is_concurrent_component_registration_enabled()
+            and is_private_preview_enabled()
+        ):
             # given deduplication has already been done, we can safely assume that there is no
             # conflict in concurrent local cache access
             with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
-                resolution_results = executor.map(_map_func, _components)
+                list(executor.map(_map_func, cache_contents_to_resolve))
         else:
-            resolution_results = map(_map_func, _components)
-
-        for cache_content, resolution_results in zip(cache_contents_to_resolve, resolution_results):
-            cache_content.arm_id = resolution_results
-            if is_on_disk_cache_enabled() and is_private_preview_enabled():
-                self._save_to_on_disk_cache(cache_content.on_disk_hash, cache_content.arm_id)
+            list(map(_map_func, cache_contents_to_resolve))
 
     def _prepare_items_to_resolve(self):
         """Pop all nodes in self._nodes_to_resolve to prepare cache contents to resolve and nodes to resolve. Nodes in
@@ -306,11 +309,21 @@ def _resolve_cache_contents_from_disk(self, cache_contents_to_resolve: List[_Cac
         """Check on-disk cache to resolve cache contents in cache_contents_to_resolve and return unresolved cache
         contents."""
         # Note that we should recalculate the hash based on code for local cache, as
-        # we can't assume that the code folder won't change among dependency resolution
-        for cache_content in cache_contents_to_resolve:
-            cache_content.on_disk_hash = self._get_on_disk_hash_for_component(
-                cache_content.component_ref, cache_content.in_memory_hash
-            )
+        # we can't assume that the code folder won't change among dependency resolution.
+        # On-disk hash calculation can be slow as it involves data copying and artifact downloading.
+        # It is thread-safe given:
+        # 1. artifact downloading is thread-safe as we have a lock in ArtifactCache
+        # 2. data copying is thread-safe as there is only read operation on source folder
+        #    and target folder is unique for each thread
+        if (
+            len(cache_contents_to_resolve) > 1
+            and is_concurrent_component_registration_enabled()
+            and is_private_preview_enabled()
+        ):
+            with ThreadPoolExecutor(max_workers=self._get_component_registration_max_workers()) as executor:
+                executor.map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve)
+        else:
+            list(map(_CacheContent.update_on_disk_hash, cache_contents_to_resolve))
 
         left_cache_contents_to_resolve = []
         # need to deduplicate disk hash first if concurrent resolution is enabled
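Two idioms in this file are worth calling out. _CacheContent.update_on_disk_hash is an unbound method, so it can be handed to executor.map directly as a one-argument callable over instances, and results are written onto the objects in place instead of being zipped back afterwards (the old zip loop also shadowed its own resolution_results variable, which this rewrite removes). Wrapping map in list() drains the lazy iterator so a worker exception is re-raised rather than silently dropped. A self-contained sketch of both:

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class Item:
    value: int = 0

    def update(self):
        # mutate in place; the return value is discarded, as with update_on_disk_hash
        self.value += 1

items = [Item(i) for i in range(4)]
with ThreadPoolExecutor(max_workers=2) as executor:
    # Item.update is an unbound method: a plain function whose first argument
    # is the instance, so mapping it over items calls item.update() for each.
    # list() consumes the iterator so worker exceptions would surface here.
    list(executor.map(Item.update, items))

assert [item.value for item in items] == [1, 2, 3, 4]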

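The thread-safety comment in the hunk above leans on a lock inside ArtifactCache; its internals are not part of this diff, but the shape it implies is a lock-guarded check-then-download. A toy stand-in under that assumption (all names here are illustrative, not the SDK's):

import threading
from pathlib import Path
from typing import Callable

class DownloadCache:
    # Toy model of a lock-guarded artifact cache: many threads may ask for
    # the same artifact, but only the first caller performs the download.
    def __init__(self, root: str) -> None:
        self._root = Path(root)
        self._lock = threading.Lock()

    def get(self, key: str, download: Callable[[Path], None]) -> Path:
        target = self._root / key
        with self._lock:
            if not target.exists():  # later callers see the cached copy and skip
                target.mkdir(parents=True)
                download(target)
        return target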
sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_feature_store_utils.py

Lines changed: 4 additions & 5 deletions
@@ -6,20 +6,19 @@
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, Union
 from urllib.parse import urlparse
 
 import yaml
 
-from azure.ai.ml._restclient.v2023_02_01_preview.operations import (  # pylint: disable = unused-import
+from .._artifacts._artifact_utilities import get_datastore_info, get_storage_client
+from .._restclient.v2023_02_01_preview.operations import (  # pylint: disable = unused-import
     FeaturesetContainersOperations,
     FeaturesetVersionsOperations,
     FeaturestoreEntityContainersOperations,
     FeaturestoreEntityVersionsOperations,
 )
-from azure.ai.ml._artifacts._artifact_utilities import get_datastore_info, get_storage_client
-from azure.ai.ml.operations._datastore_operations import DatastoreOperations
-
+from ..operations._datastore_operations import DatastoreOperations
 from ._storage_utils import AzureMLDatastorePathUri
 from .utils import load_yaml
 
sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py

Lines changed: 3 additions & 8 deletions
@@ -8,7 +8,7 @@
 import types
 from functools import partial
 from inspect import Parameter, signature
-from typing import Callable, Dict, Iterable, Optional, Union, List, Any
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
     AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
@@ -21,12 +21,7 @@
     OperationScope,
     _ScopeDependentOperations,
 )
-
-from azure.ai.ml._telemetry import (
-    ActivityType,
-    monitor_with_activity,
-    monitor_with_telemetry_mixin,
-)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin
 from azure.ai.ml._utils._asset_utils import (
     _archive_or_restore,
     _create_or_update_autoincrement,
@@ -45,8 +40,8 @@
 )
 from azure.ai.ml.entities import Component, ValidationResult
 from azure.ai.ml.exceptions import ComponentException, ErrorCategory, ErrorTarget, ValidationException
-from .._utils._cache_utils import CachedNodeResolver
 
+from .._utils._cache_utils import CachedNodeResolver
 from .._utils._experimental import experimental
 from .._utils.utils import is_data_binding_expression
 from ..entities._builders import BaseNode

sdk/ml/azure-ai-ml/tests/internal_utils/unittests/test_cache_utils.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 
 import mock
 import pytest
-
 from azure.ai.ml import MLClient, load_job
 from azure.ai.ml._utils._cache_utils import CachedNodeResolver
 from azure.ai.ml.entities import Component, PipelineJob
@@ -24,7 +23,7 @@ def _mock_resolver(component: Union[str, Component], azureml_type: str) -> str:
     @staticmethod
     def _get_cache_path(component: Component, resolver: CachedNodeResolver) -> Path:
         in_memory_hash = resolver._get_in_memory_hash_for_component(component)
-        on_disk_hash = resolver._get_on_disk_hash_for_component(component=component, in_memory_hash=in_memory_hash)
+        on_disk_hash = resolver.calc_on_disk_hash_for_component(component=component, in_memory_hash=in_memory_hash)
         return resolver._get_on_disk_cache_path(on_disk_hash)
 
     @staticmethod
