Commit f970c11
Switch to using temporary data reference SAS uri for non-registry snapshot uploads (Azure#28957)
* Refactor code create_or_update to use temporary data reference for uploads
1 parent: f7465ff

311 files changed: +118,058 −161,983 lines

sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_artifact_utilities.py

Lines changed: 321 additions & 3 deletions
@@ -4,6 +4,7 @@
 
 # pylint: disable=protected-access
 
+import json
 import logging
 import os
 import uuid
@@ -13,7 +14,8 @@
 
 from azure.ai.ml._artifacts._blob_storage_helper import BlobStorageClient
 from azure.ai.ml._artifacts._gen2_storage_helper import Gen2StorageClient
-from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
+from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata, _get_cloud_details
+from azure.ai.ml._restclient.v2022_05_01.models import Workspace
 from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
 from azure.ai.ml._scope_dependent_operations import OperationScope
 from azure.ai.ml._utils._arm_id_utils import (
@@ -29,20 +31,29 @@
     _validate_path,
     get_ignore_file,
     get_object_hash,
+    get_content_hash,
+    get_content_hash_version,
 )
+from azure.ai.ml._utils._http_utils import HttpPipeline
 from azure.ai.ml._utils._storage_utils import (
     AzureMLDatastorePathUri,
     get_artifact_path_from_storage_url,
     get_storage_client,
 )
-from azure.ai.ml._utils.utils import is_mlflow_uri, is_url
-from azure.ai.ml.constants._common import SHORT_URI_FORMAT, STORAGE_ACCOUNT_URLS
+from azure.ai.ml._utils.utils import is_mlflow_uri, is_url, retry, replace_between
+from azure.ai.ml.constants._common import (
+    SHORT_URI_FORMAT,
+    STORAGE_ACCOUNT_URLS,
+    MAX_ASSET_STORE_API_CALL_RETRIES,
+    HTTPS_PREFIX,
+)
 from azure.ai.ml.entities import Environment
 from azure.ai.ml.entities._assets._artifacts.artifact import Artifact, ArtifactStorageInfo
 from azure.ai.ml.entities._credentials import AccountKeyConfiguration
 from azure.ai.ml.entities._datastore._constants import WORKSPACE_BLOB_STORE
 from azure.ai.ml.exceptions import ErrorTarget, ValidationException
 from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.exceptions import HttpResponseError
 from azure.storage.blob import BlobSasPermissions, generate_blob_sas
 from azure.storage.filedatalake import FileSasPermissions, generate_file_sas
 
@@ -359,6 +370,126 @@ def _update_gen2_metadata(name, version, indicator_file, storage_client) -> None
     artifact_directory_client.set_metadata(_build_metadata_dict(name=name, version=version))
 
 
+def _generate_temporary_data_reference_id() -> str:
+    """Generate a temporary data reference id."""
+    return str(uuid.uuid4())
+
+
+@retry(
+    exceptions=HttpResponseError,
+    failure_msg="Artifact upload exceeded maximum retries. Try again.",
+    logger=module_logger,
+    max_attempts=MAX_ASSET_STORE_API_CALL_RETRIES,
+)
+def _get_snapshot_temporary_data_reference(
+    asset_name: str,
+    asset_version: str,
+    request_headers: Dict[str, str],
+    workspace: Workspace,
+    requests_pipeline: HttpPipeline,
+) -> Tuple[str, str]:
+    """
+    Create a temporary data reference for an asset and return its SAS URI and blob storage URI.
+
+    :param asset_name: Name of the asset to be created
+    :type asset_name: str
+    :param asset_version: Version of the asset to be created
+    :type asset_version: str
+    :param request_headers: Request headers for API call
+    :type request_headers: Dict[str, str]
+    :param workspace: Workspace object
+    :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace
+    :param requests_pipeline: Proxy for sending HTTP requests
+    :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline
+    :return: SAS URI for the upload and blob URI for asset creation
+    :rtype: Tuple[str, str]
+    """
+
+    # create temporary data reference id
+    temporary_data_reference_id = _generate_temporary_data_reference_id()
+
+    # build request body
+    asset_id = (
+        f"azureml://locations/{workspace.location}/workspaces/{workspace.workspace_id}/"
+        f"codes/{asset_name}/versions/{asset_version}"
+    )
+    data = {
+        "assetId": asset_id,
+        "temporaryDataReferenceId": temporary_data_reference_id,
+        "temporaryDataReferenceType": "TemporaryBlobReference",
+    }
+    data_encoded = json.dumps(data).encode("utf-8")
+    serialized_data = json.loads(data_encoded)
+
+    # make sure the correct cloud endpoint is used
+    cloud_endpoint = _get_cloud_details()["registry_discovery_endpoint"]
+    service_url = replace_between(cloud_endpoint, HTTPS_PREFIX, ".", workspace.location)
+
+    # send request
+    request_url = f"{service_url}assetstore/v1.0/temporaryDataReference/createOrGet"
+    response = requests_pipeline.post(request_url, json=serialized_data, headers=request_headers)
+
+    if response.status_code != 200:
+        raise HttpResponseError(response=response)
+
+    response_json = json.loads(response.text())
+
+    # get SAS uri for upload and blob uri for asset creation
+    blob_uri = response_json["blobReferenceForConsumption"]["blobUri"]
+    sas_uri = response_json["blobReferenceForConsumption"]["credential"]["sasUri"]
+
+    return sas_uri, blob_uri
+
+
+def _get_asset_by_hash(
+    operations: "DatastoreOperations",
+    hash_str: str,
+    request_headers: Dict[str, str],
+    workspace: Workspace,
+    requests_pipeline: HttpPipeline,
+) -> Dict[str, str]:
+    """
+    Check if an asset with the same hash already exists in the workspace. If so, return the asset's name and version.
+
+    :param operations: Datastore Operations object from MLClient
+    :type operations: azure.ai.ml.operations._datastore_operations.DatastoreOperations
+    :param hash_str: The hash of the specified local upload
+    :type hash_str: str
+    :param request_headers: Request headers for API call
+    :type request_headers: Dict[str, str]
+    :param workspace: Workspace object
+    :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace
+    :param requests_pipeline: Proxy for sending HTTP requests
+    :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline
+    :return: Existing asset's name and version, if found
+    :rtype: Optional[Dict[str, str]]
+    """
+    existing_asset = {}
+    hash_version = get_content_hash_version()
+
+    # get workspace credentials
+    subscription_id = operations._subscription_id
+    resource_group_name = operations._resource_group_name
+
+    # make sure the correct cloud endpoint is used
+    cloud_endpoint = _get_cloud_details()["registry_discovery_endpoint"]
+    service_url = replace_between(cloud_endpoint, HTTPS_PREFIX, ".", workspace.location)
+    request_url = (
+        f"{service_url}content/v2.0/subscriptions/{subscription_id}/"
+        f"resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/"
+        f"{workspace.name}/snapshots/getByHash?hash={hash_str}&hashVersion={hash_version}"
+    )
+
+    response = requests_pipeline.get(request_url, headers=request_headers)
+    if response.status_code != 200:
+        # If the API is unresponsive, create a new asset
+        return None
+
+    response_json = json.loads(response.text())
+    existing_asset["name"] = response_json["name"]
+    existing_asset["version"] = response_json["version"]
+
+    return existing_asset
+
+
 T = TypeVar("T", bound=Artifact)
 
 
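The two helpers above wrap a pair of asset-store REST calls: `temporaryDataReference/createOrGet` mints a short-lived SAS-backed blob reference, and `snapshots/getByHash` looks up an existing snapshot by content hash. Below is a minimal sketch of those calls using `requests` instead of the SDK's `HttpPipeline`; the URL shapes and JSON field names follow the diff, while the endpoint value, identifiers, and function names are illustrative assumptions.

# Sketch of the two asset-store calls, assuming the payload/response shapes
# shown in the diff. Region and workspace values are placeholders.
import uuid

import requests

SERVICE_URL = "https://eastus.api.azureml.ms/"  # assumed regional endpoint shape


def create_or_get_temp_data_reference(asset_id: str, headers: dict) -> tuple:
    """Mint (or fetch) a temporary blob reference; return (sas_uri, blob_uri)."""
    body = {
        "assetId": asset_id,
        "temporaryDataReferenceId": str(uuid.uuid4()),
        "temporaryDataReferenceType": "TemporaryBlobReference",
    }
    resp = requests.post(
        SERVICE_URL + "assetstore/v1.0/temporaryDataReference/createOrGet",
        json=body,
        headers=headers,
    )
    resp.raise_for_status()
    ref = resp.json()["blobReferenceForConsumption"]
    return ref["credential"]["sasUri"], ref["blobUri"]


def get_snapshot_by_hash(sub_id, rg, ws_name, hash_str, hash_version, headers):
    """Return {'name': ..., 'version': ...} for a matching snapshot, else None."""
    url = (
        f"{SERVICE_URL}content/v2.0/subscriptions/{sub_id}/resourceGroups/{rg}/"
        f"providers/Microsoft.MachineLearningServices/workspaces/{ws_name}/"
        f"snapshots/getByHash?hash={hash_str}&hashVersion={hash_version}"
    )
    resp = requests.get(url, headers=headers)
    if resp.status_code != 200:
        return None  # treat any failure as "no match" and create a new asset
    body = resp.json()
    return {"name": body["name"], "version": body["version"]}

In the SDK, the createOrGet call is additionally wrapped in the `retry` decorator, so transient `HttpResponseError`s are retried up to `MAX_ASSET_STORE_API_CALL_RETRIES` times before the upload is abandoned.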
@@ -422,6 +553,193 @@ def _check_and_upload_path(
     return artifact, indicator_file
 
 
+def _get_snapshot_path_info(artifact) -> Tuple[str, str, str]:
+    """
+    Validate an Artifact's local path and get its resolved path, ignore file, and hash.
+
+    :param artifact: Artifact object
+    :type artifact: azure.ai.ml.entities._assets._artifacts.artifact.Artifact
+    :return: Artifact's path, ignore file, and hash
+    :rtype: Tuple[str, str, str]
+    """
+    if (
+        hasattr(artifact, "local_path")
+        and artifact.local_path is not None
+        or (
+            hasattr(artifact, "path")
+            and artifact.path is not None
+            and not (is_url(artifact.path) or is_mlflow_uri(artifact.path))
+        )
+    ):
+        path = (
+            Path(artifact.path)
+            if hasattr(artifact, "path") and artifact.path is not None
+            else Path(artifact.local_path)
+        )
+        if not path.is_absolute():
+            path = Path(artifact.base_path, path).resolve()
+
+        _validate_path(path, _type=ErrorTarget.CODE)
+
+        ignore_file = get_ignore_file(path)
+        asset_hash = get_content_hash(path, ignore_file)
+
+        return path, ignore_file, asset_hash
+
+
+def _get_existing_snapshot_by_hash(
+    datastore_operation,
+    asset_hash,
+    workspace: Workspace,
+    requests_pipeline: HttpPipeline,
+) -> Dict[str, str]:
+    """
+    Check if an asset with the same hash already exists in the workspace. If so, return the asset's name and version.
+
+    :param datastore_operation: Datastore Operations object from MLClient
+    :type datastore_operation: azure.ai.ml.operations._datastore_operations.DatastoreOperations
+    :param asset_hash: The hash of the specified local upload
+    :type asset_hash: str
+    :param workspace: Workspace object
+    :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace
+    :param requests_pipeline: Proxy for sending HTTP requests
+    :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline
+    :return: Existing asset's name and version, if found
+    :rtype: Optional[Dict[str, str]]
+    """
+    ws_base_url = datastore_operation._operation._client._base_url
+    token = datastore_operation._credential.get_token(ws_base_url + "/.default").token
+    request_headers = {"Authorization": "Bearer " + token}
+    request_headers["Content-Type"] = "application/json; charset=UTF-8"
+
+    existing_asset = _get_asset_by_hash(
+        operations=datastore_operation,
+        hash_str=asset_hash,
+        request_headers=request_headers,
+        workspace=workspace,
+        requests_pipeline=requests_pipeline,
+    )
+
+    return existing_asset
+
+
+def _upload_snapshot_to_datastore(
+    operation_scope: OperationScope,
+    datastore_operation: DatastoreOperations,
+    path: Union[str, Path, os.PathLike],
+    workspace: Workspace,
+    requests_pipeline: HttpPipeline,
+    datastore_name: str = None,
+    show_progress: bool = True,
+    asset_name: str = None,
+    asset_version: str = None,
+    asset_hash: str = None,
+    ignore_file: IgnoreFile = IgnoreFile(),
+    sas_uri: str = None,  # contains registry sas url
+) -> ArtifactStorageInfo:
+    """
+    Upload a code snapshot to the workspace datastore.
+
+    :param operation_scope: Workspace scope information
+    :type operation_scope: azure.ai.ml._scope_dependent_operations.OperationScope
+    :param datastore_operation: Datastore Operations object from MLClient
+    :type datastore_operation: azure.ai.ml.operations._datastore_operations.DatastoreOperations
+    :param path: The local path of the artifact
+    :type path: Union[str, Path, os.PathLike]
+    :param workspace: Workspace object
+    :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace
+    :param requests_pipeline: Proxy for sending HTTP requests
+    :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline
+    :param datastore_name: Name of the datastore to upload to
+    :type datastore_name: str
+    :param show_progress: Whether to show a progress bar during upload, defaults to True
+    :type show_progress: bool
+    :param asset_name: Name of the asset to be created
+    :type asset_name: str
+    :param asset_version: Version of the asset to be created
+    :type asset_version: str
+    :param asset_hash: The hash of the specified local upload
+    :type asset_hash: str
+    :param ignore_file: Information about the path's .gitignore or .amlignore file, if one exists
+    :type ignore_file: azure.ai.ml._utils._asset_utils.IgnoreFile
+    :param sas_uri: SAS uri for uploading to datastore
+    :type sas_uri: str
+    :return: Uploaded artifact's storage information
+    :rtype: azure.ai.ml.entities._assets._artifacts.artifact.ArtifactStorageInfo
+    """
+    ws_base_url = datastore_operation._operation._client._base_url
+    token = datastore_operation._credential.get_token(ws_base_url + "/.default").token
+    request_headers = {"Authorization": "Bearer " + token}
+    request_headers["Content-Type"] = "application/json; charset=UTF-8"
+
+    if not sas_uri:
+        sas_uri, blob_uri = _get_snapshot_temporary_data_reference(
+            requests_pipeline=requests_pipeline,
+            asset_name=asset_name,
+            asset_version=asset_version,
+            request_headers=request_headers,
+            workspace=workspace,
+        )
+
+    artifact = upload_artifact(
+        str(path),
+        datastore_operation,
+        operation_scope,
+        datastore_name,
+        show_progress=show_progress,
+        asset_hash=asset_hash,
+        asset_name=asset_name,
+        asset_version=asset_version,
+        ignore_file=ignore_file,
+        sas_uri=sas_uri,
+    )
+    artifact.storage_account_url = blob_uri
+
+    return artifact
+
+
+def _check_and_upload_snapshot(
+    artifact: T,
+    asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations"],
+    path: Union[str, Path, os.PathLike],
+    workspace: Workspace,
+    requests_pipeline: HttpPipeline,
+    ignore_file: IgnoreFile = None,
+    datastore_name: str = WORKSPACE_BLOB_STORE,
+    sas_uri: str = None,
+    show_progress: bool = True,
+) -> Tuple[T, str]:
+    """Checks whether `artifact` is a path or a uri and uploads it to the
+    datastore if necessary.
+
+    :param T artifact: artifact to check and upload
+    :param Union["DataOperations", "ModelOperations", "CodeOperations"] asset_operations:
+        the asset operations to use for uploading
+    :param str datastore_name: the name of the datastore to upload to
+    :param str sas_uri: the sas uri to use for uploading
+    """
+    uploaded_artifact = _upload_snapshot_to_datastore(
+        operation_scope=asset_operations._operation_scope,
+        datastore_operation=asset_operations._datastore_operation,
+        path=path,
+        workspace=workspace,
+        requests_pipeline=requests_pipeline,
+        datastore_name=datastore_name,
+        asset_name=artifact.name,
+        asset_version=str(artifact.version),
+        asset_hash=artifact._upload_hash if hasattr(artifact, "_upload_hash") else None,
+        show_progress=show_progress,
+        sas_uri=sas_uri,
+        ignore_file=ignore_file,
+    )
+
+    if artifact._is_anonymous:
+        artifact.name, artifact.version = (
+            uploaded_artifact.name,
+            uploaded_artifact.version,
+        )
+    # Pass all of the upload information to the assets, and they will each construct the URLs that they support
+    artifact._update_path(uploaded_artifact)
+
+    return artifact
+
+
 def _check_and_upload_env_build_context(
     environment: Environment,
     operations: "EnvironmentOperations",
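
Taken together, the helpers added in this file compose into the snapshot upload path the commit message describes for code `create_or_update`: hash the local directory, reuse a matching snapshot if the service already has one, and otherwise upload through a temporary data reference. A rough sketch of that composition, assuming access to this module's private helpers (the wrapper function itself is illustrative, not SDK API):

# Illustrative composition of the helpers added in this file; not SDK API.
from azure.ai.ml._artifacts._artifact_utilities import (
    _check_and_upload_snapshot,
    _get_existing_snapshot_by_hash,
    _get_snapshot_path_info,
)


def upload_code_snapshot(code, code_operations, workspace, requests_pipeline):
    # 1. Resolve the local path, honoring any ignore file, and content-hash it.
    path, ignore_file, asset_hash = _get_snapshot_path_info(code)

    # 2. If the workspace already has a snapshot with this hash, reuse it.
    existing = _get_existing_snapshot_by_hash(
        code_operations._datastore_operation, asset_hash, workspace, requests_pipeline
    )
    if existing:
        code.name, code.version = existing["name"], existing["version"]
        return code

    # 3. Otherwise mint a temporary blob reference and upload through its SAS URI.
    return _check_and_upload_snapshot(
        code,
        code_operations,
        path,
        workspace,
        requests_pipeline,
        ignore_file=ignore_file,
    )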

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 1 addition & 0 deletions
@@ -350,6 +350,7 @@ def __init__(
             self._operation_config,
             self._service_client_10_2021_dataplanepreview if registry_name else self._service_client_05_2022,
             self._datastores,
+            requests_pipeline=self._requests_pipeline,
             **ops_kwargs,
         )
         self._operation_container.add(AzureMLResourceType.CODE, self._code)
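
This passes MLClient's shared `HttpPipeline` into `CodeOperations` so that snapshot `create_or_update` can issue the asset-store calls shown earlier. The receiving side is not part of this diff; a hedged sketch of what the constructor change presumably looks like:

# Hypothetical receiving side of the new keyword argument; the real
# CodeOperations signature is not shown in this diff.
class CodeOperations:
    def __init__(
        self,
        operation_scope,
        operation_config,
        service_client,
        datastore_operations,
        requests_pipeline=None,  # new: shared HttpPipeline from MLClient
        **kwargs,
    ):
        self._requests_pipeline = requests_pipeline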

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/utils.py

Lines changed: 6 additions & 1 deletion
@@ -2,7 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 
-# pylint: disable=protected-access
+# pylint: disable=protected-access,too-many-lines
 import copy
 import decimal
 import hashlib
@@ -996,3 +996,8 @@ def get_valid_dot_keys_with_wildcard(
     """
     left_reversed_parts = dot_key_wildcard.split(".")[::-1]
     return _get_valid_dot_keys_with_wildcard_impl(left_reversed_parts, root, validate_func=validate_func)
+
+
+def replace_between(s: str, start: str, end: str, replace: str) -> str:
+    """Replace string between two substrings."""
+    return start + replace + s[s.find(end) - 1 + len(end) :]
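
`replace_between` is used in `_artifact_utilities.py` as `replace_between(cloud_endpoint, HTTPS_PREFIX, ".", workspace.location)` to swap the leading host label of the registry discovery endpoint for the workspace region. A quick illustration with an assumed endpoint value; note that the index arithmetic keeps the delimiter only when `end` is a single character, as at the call sites above, and that the result assumes `s` begins with `start`:

from azure.ai.ml._utils.utils import replace_between

endpoint = "https://master.api.azureml.ms/"  # assumed discovery endpoint shape
print(replace_between(endpoint, "https://", ".", "eastus"))
# https://eastus.api.azureml.ms/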
