Skip to content

Commit e4ec7e6

Browse files
authored
[ML][Pipelines] feat: concurrent artifact download (Azure#28988)
* feat: make artifact download thread-safe
* fix: avoid wait in calling az artifacts
1 parent f970c11 commit e4ec7e6

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_artifact_cache.py

Lines changed: 31 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import subprocess
1111
import tempfile
1212
import zipfile
13+
from collections import defaultdict
1314
from io import BytesIO
1415
from pathlib import Path
1516
from threading import Lock
@@ -48,17 +49,21 @@ def __new__(cls):
4849
def __init__(self, cache_directory=None):
4950
self._cache_directory = cache_directory or self.DEFAULT_DISK_CACHE_DIRECTORY
5051
Path(self._cache_directory).mkdir(exist_ok=True, parents=True)
51-
# check az extension azure-devops installed
52+
# check az extension azure-devops installed. Install it if not installed.
5253
process = subprocess.Popen(
53-
"az artifacts --help",
54+
"az artifacts --help --yes",
5455
shell=True, # nosec B602
5556
stdout=subprocess.PIPE,
5657
stderr=subprocess.PIPE,
5758
)
5859
process.communicate()
5960
if process.returncode != 0:
60-
subprocess.check_call("az extension add --name azure-devops", shell=True)
61+
raise RuntimeError(
62+
"Auto-installation failed. Please install azure-devops "
63+
"extension by 'az extension add --name azure-devops'."
64+
)
6165
self._artifacts_tool_path = None
66+
self._download_locks = defaultdict(Lock)
6267

6368
@property
6469
def cache_directory(self):
@@ -253,27 +258,29 @@ def get(self, feed, name, version, scope, organization=None, project=None, resol
253258
/ name
254259
/ version
255260
)
256-
if self._check_artifacts(artifact_package_path):
257-
# When the cache folder of artifact package exists, it's sure that the package has been downloaded.
258-
return artifact_package_path.absolute().resolve()
259-
if resolve:
260-
check_sum_path = self._get_checksum_path(artifact_package_path)
261-
if Path(check_sum_path).exists():
262-
os.unlink(check_sum_path)
263-
if artifact_package_path.exists():
264-
# Remove invalid artifact package to avoid affecting download artifact.
265-
temp_folder = tempfile.mktemp() # nosec B306
266-
os.rename(artifact_package_path, temp_folder)
267-
shutil.rmtree(temp_folder)
268-
# Download artifact
269-
return self.set(
270-
feed=feed,
271-
name=name,
272-
version=version,
273-
organization=organization,
274-
project=project,
275-
scope=scope,
276-
)
261+
# Use lock to avoid downloading the same package at the same time.
262+
with self._download_locks[artifact_package_path]:
263+
if self._check_artifacts(artifact_package_path):
264+
# When the cache folder of artifact package exists, it's sure that the package has been downloaded.
265+
return artifact_package_path.absolute().resolve()
266+
if resolve:
267+
check_sum_path = self._get_checksum_path(artifact_package_path)
268+
if Path(check_sum_path).exists():
269+
os.unlink(check_sum_path)
270+
if artifact_package_path.exists():
271+
# Remove invalid artifact package to avoid affecting download artifact.
272+
temp_folder = tempfile.mktemp() # nosec B306
273+
os.rename(artifact_package_path, temp_folder)
274+
shutil.rmtree(temp_folder)
275+
# Download artifact
276+
return self.set(
277+
feed=feed,
278+
name=name,
279+
version=version,
280+
organization=organization,
281+
project=project,
282+
scope=scope,
283+
)
277284
return None
278285

279286
def set(self, feed, name, version, scope, organization=None, project=None):

0 commit comments

Comments
 (0)