|
4 | 4 |
|
5 | 5 | # pylint: disable=protected-access |
6 | 6 |
|
| 7 | +import json |
7 | 8 | import logging |
8 | 9 | import os |
9 | 10 | import uuid |
|
13 | 14 |
|
14 | 15 | from azure.ai.ml._artifacts._blob_storage_helper import BlobStorageClient |
15 | 16 | from azure.ai.ml._artifacts._gen2_storage_helper import Gen2StorageClient |
16 | | -from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata |
| 17 | +from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata, _get_cloud_details |
| 18 | +from azure.ai.ml._restclient.v2022_05_01.models import Workspace |
17 | 19 | from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType |
18 | 20 | from azure.ai.ml._scope_dependent_operations import OperationScope |
19 | 21 | from azure.ai.ml._utils._arm_id_utils import ( |
|
29 | 31 | _validate_path, |
30 | 32 | get_ignore_file, |
31 | 33 | get_object_hash, |
| 34 | + get_content_hash, |
| 35 | + get_content_hash_version, |
32 | 36 | ) |
| 37 | +from azure.ai.ml._utils._http_utils import HttpPipeline |
33 | 38 | from azure.ai.ml._utils._storage_utils import ( |
34 | 39 | AzureMLDatastorePathUri, |
35 | 40 | get_artifact_path_from_storage_url, |
36 | 41 | get_storage_client, |
37 | 42 | ) |
38 | | -from azure.ai.ml._utils.utils import is_mlflow_uri, is_url |
39 | | -from azure.ai.ml.constants._common import SHORT_URI_FORMAT, STORAGE_ACCOUNT_URLS |
| 43 | +from azure.ai.ml._utils.utils import is_mlflow_uri, is_url, retry, replace_between |
| 44 | +from azure.ai.ml.constants._common import ( |
| 45 | + SHORT_URI_FORMAT, |
| 46 | + STORAGE_ACCOUNT_URLS, |
| 47 | + MAX_ASSET_STORE_API_CALL_RETRIES, |
| 48 | + HTTPS_PREFIX, |
| 49 | +) |
40 | 50 | from azure.ai.ml.entities import Environment |
41 | 51 | from azure.ai.ml.entities._assets._artifacts.artifact import Artifact, ArtifactStorageInfo |
42 | 52 | from azure.ai.ml.entities._credentials import AccountKeyConfiguration |
43 | 53 | from azure.ai.ml.entities._datastore._constants import WORKSPACE_BLOB_STORE |
44 | 54 | from azure.ai.ml.exceptions import ErrorTarget, ValidationException |
45 | 55 | from azure.ai.ml.operations._datastore_operations import DatastoreOperations |
| 56 | +from azure.core.exceptions import HttpResponseError |
46 | 57 | from azure.storage.blob import BlobSasPermissions, generate_blob_sas |
47 | 58 | from azure.storage.filedatalake import FileSasPermissions, generate_file_sas |
48 | 59 |
|
@@ -359,6 +370,126 @@ def _update_gen2_metadata(name, version, indicator_file, storage_client) -> None |
359 | 370 | artifact_directory_client.set_metadata(_build_metadata_dict(name=name, version=version)) |
360 | 371 |
|
361 | 372 |
|
| 373 | +def _generate_temporary_data_reference_id() -> str: |
| 374 | + """Generate a temporary data reference id.""" |
| 375 | + return str(uuid.uuid4()) |
| 376 | + |
| 377 | + |
| 378 | +@retry( |
| 379 | + exceptions=HttpResponseError, |
| 380 | + failure_msg="Artifact upload exceeded maximum retries. Try again.", |
| 381 | + logger=module_logger, |
| 382 | + max_attempts=MAX_ASSET_STORE_API_CALL_RETRIES, |
| 383 | +) |
| 384 | +def _get_snapshot_temporary_data_reference( |
| 385 | + asset_name: str, |
| 386 | + asset_version: str, |
| 387 | + request_headers: Dict[str, str], |
| 388 | + workspace: Workspace, |
| 389 | + requests_pipeline: HttpPipeline, |
| 390 | +) -> Tuple[str, str]: |
| 391 | + """ |
| 392 | + Make a temporary data reference for an asset and return SAS uri and blob storage uri. |
| 393 | + :param asset_name: Name of the asset to be created |
| 394 | + :type asset_name: str |
| 395 | + :param asset_version: Version of the asset to be created |
| 396 | + :type asset_version: str |
| 397 | + :param request_headers: Request headers for API call |
| 398 | + :type request_headers: Dict[str, str] |
| 399 | + :param workspace: Workspace object |
| 400 | + :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace |
| 401 | + :param requests_pipeline: Proxy for sending HTTP requests |
| 402 | + :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline |
| 403 | +    :return: SAS uri for uploading the snapshot and the blob uri recorded on the asset
| 404 | + :rtype: Tuple[str, str] |
| 405 | + """ |
| 406 | + |
| 407 | + # create temporary data reference |
| 408 | + temporary_data_reference_id = _generate_temporary_data_reference_id() |
| 409 | + |
| 410 | + # build and send request |
| 411 | + asset_id = ( |
| 412 | + f"azureml://locations/{workspace.location}/workspaces/{workspace.workspace_id}/" |
| 413 | + f"codes/{asset_name}/versions/{asset_version}" |
| 414 | + ) |
| 415 | + data = { |
| 416 | + "assetId": asset_id, |
| 417 | + "temporaryDataReferenceId": temporary_data_reference_id, |
| 418 | + "temporaryDataReferenceType": "TemporaryBlobReference", |
| 419 | + } |
| 420 | +    serialized_data = json.loads(json.dumps(data))
| 422 | + |
| 423 | + # make sure correct cloud endpoint is used |
| 424 | + cloud_endpoint = _get_cloud_details()["registry_discovery_endpoint"] |
| 425 | + service_url = replace_between(cloud_endpoint, HTTPS_PREFIX, ".", workspace.location) |
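| | +    # minimal example of the rewrite (endpoint shape is an assumption for illustration):
| | +    # given a discovery endpoint like https://master.api.azureml.ms/ and location
| | +    # "eastus", replace_between swaps the first host label -> https://eastus.api.azureml.ms/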
| 426 | + |
| 427 | + # send request |
| 428 | + request_url = f"{service_url}assetstore/v1.0/temporaryDataReference/createOrGet" |
| 429 | + response = requests_pipeline.post(request_url, json=serialized_data, headers=request_headers) |
| 430 | + |
| 431 | + if response.status_code != 200: |
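| | +        # raise so the @retry decorator above re-attempts the call, up to
| | +        # MAX_ASSET_STORE_API_CALL_RETRIES attempts, before surfacing the error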
| 432 | + raise HttpResponseError(response=response) |
| 433 | + |
| 434 | + response_json = json.loads(response.text()) |
| 435 | + |
| 436 | + # get SAS uri for upload and blob uri for asset creation |
| 437 | + blob_uri = response_json["blobReferenceForConsumption"]["blobUri"] |
| 438 | + sas_uri = response_json["blobReferenceForConsumption"]["credential"]["sasUri"] |
| 439 | + |
| 440 | + return sas_uri, blob_uri |
| 441 | + |
| 442 | + |
| 443 | +def _get_asset_by_hash( |
| 444 | + operations: "DatastoreOperations", |
| 445 | + hash_str: str, |
| 446 | + request_headers: Dict[str, str], |
| 447 | + workspace: Workspace, |
| 448 | + requests_pipeline: HttpPipeline, |
| 449 | +) -> Optional[Dict[str, str]]:
| 450 | + """ |
| 451 | + Check if an asset with the same hash already exists in the workspace. If so, return the asset name and version. |
| 452 | + :param operations: Datastore Operations object from MLClient |
| 453 | + :type operations: azure.ai.ml.operations._datastore_operations.DatastoreOperations |
| 454 | + :param hash_str: The hash of the specified local upload |
| 455 | + :type hash_str: str |
| 456 | + :param request_headers: Request headers for API call |
| 457 | + :type request_headers: Dict[str, str] |
| 458 | + :param workspace: Workspace object |
| 459 | + :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace |
| 460 | + :param requests_pipeline: Proxy for sending HTTP requests |
| 461 | + :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline |
| 462 | + :return: Existing asset's name and version, if found |
| 463 | + :rtype: Optional[Dict[str, str]] |
| 464 | + """ |
| 465 | + existing_asset = {} |
| 466 | + hash_version = get_content_hash_version() |
| 467 | + |
| 468 | +    # workspace scope identifiers used to build the request url
| 469 | + subscription_id = operations._subscription_id |
| 470 | + resource_group_name = operations._resource_group_name |
| 471 | + |
| 472 | + # make sure correct cloud endpoint is used |
| 473 | + cloud_endpoint = _get_cloud_details()["registry_discovery_endpoint"] |
| 474 | + service_url = replace_between(cloud_endpoint, HTTPS_PREFIX, ".", workspace.location) |
| 475 | + request_url = ( |
| 476 | + f"{service_url}content/v2.0/subscriptions/{subscription_id}/" |
| 477 | + f"resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/" |
| 478 | + f"{workspace.name}/snapshots/getByHash?hash={hash_str}&hashVersion={hash_version}" |
| 479 | + ) |
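| | +    # hashVersion tells the service which client-side hashing scheme produced the
| | +    # hash (get_content_hash_version), so only like-for-like hashes can match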
| 480 | + |
| 481 | + response = requests_pipeline.get(request_url, headers=request_headers) |
| 482 | + if response.status_code != 200: |
| 483 | +        # lookup failed or no match found; return None so the caller creates a new asset
| 484 | + return None |
| 485 | + |
| 486 | + response_json = json.loads(response.text()) |
| 487 | + existing_asset["name"] = response_json["name"] |
| 488 | + existing_asset["version"] = response_json["version"] |
| 489 | + |
| 490 | + return existing_asset |
| 491 | + |
| 492 | + |
362 | 493 | T = TypeVar("T", bound=Artifact) |
363 | 494 |
|
364 | 495 |
|
@@ -422,6 +553,193 @@ def _check_and_upload_path( |
422 | 553 | return artifact, indicator_file |
423 | 554 |
|
424 | 555 |
|
| 556 | +def _get_snapshot_path_info(artifact) -> Optional[Tuple[Path, IgnoreFile, str]]:
| 557 | + """ |
| 558 | + Validate an Artifact's local path and get its resolved path, ignore file, and hash |
| 559 | + :param artifact: Artifact object |
| 560 | + :type artifact: azure.ai.ml.entities._assets._artifacts.artifact.Artifact |
| 561 | +    :return: The artifact's resolved path, ignore file, and content hash, or None if no local path is set
| 562 | +    :rtype: Optional[Tuple[Path, IgnoreFile, str]]
| 563 | + """ |
| 564 | +    if (
| 565 | +        (hasattr(artifact, "local_path") and artifact.local_path is not None)
| 566 | +        or (
| 567 | +            hasattr(artifact, "path")
| 568 | +            and artifact.path is not None
| 569 | +            and not (is_url(artifact.path) or is_mlflow_uri(artifact.path))
| 570 | +        )
| 571 | +    ):
| 573 | + path = ( |
| 574 | + Path(artifact.path) |
| 575 | + if hasattr(artifact, "path") and artifact.path is not None |
| 576 | + else Path(artifact.local_path) |
| 577 | + ) |
| 578 | + if not path.is_absolute(): |
| 579 | + path = Path(artifact.base_path, path).resolve() |
| 580 | + |
| 581 | + _validate_path(path, _type=ErrorTarget.CODE) |
| 582 | + |
| 583 | + ignore_file = get_ignore_file(path) |
| 584 | + asset_hash = get_content_hash(path, ignore_file) |
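| | +        # the content hash covers the files that survive the ignore file; getByHash
| | +        # later uses it to find an identical snapshot already in the workspace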
| 585 | + |
| 586 | + return path, ignore_file, asset_hash |
| 587 | + |
| 588 | + |
| 589 | +def _get_existing_snapshot_by_hash( |
| 590 | + datastore_operation, |
| 591 | + asset_hash, |
| 592 | + workspace: Workspace, |
| 593 | + requests_pipeline: HttpPipeline, |
| 594 | +) -> Optional[Dict[str, str]]:
| 595 | + """ |
| 596 | + Check if an asset with the same hash already exists in the workspace. If so, return the asset name and version. |
| 597 | + :param datastore_operation: Datastore Operations object from MLClient |
| 598 | +    :type datastore_operation: azure.ai.ml.operations._datastore_operations.DatastoreOperations
| 599 | + :param asset_hash: The hash of the specified local upload |
| 600 | + :type asset_hash: str |
| 601 | + :param workspace: Workspace object |
| 602 | + :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace |
| 603 | + :param requests_pipeline: Proxy for sending HTTP requests |
| 604 | + :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline |
| 605 | + :return: Existing asset's name and version, if found |
| 606 | + :rtype: Optional[Dict[str, str]] |
| 607 | + """ |
| 608 | + ws_base_url = datastore_operation._operation._client._base_url |
| 609 | + token = datastore_operation._credential.get_token(ws_base_url + "/.default").token |
| 610 | + request_headers = {"Authorization": "Bearer " + token} |
| 611 | + request_headers["Content-Type"] = "application/json; charset=UTF-8" |
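| | +    # reuse the workspace credential: request an AAD token scoped to the workspace
| | +    # base url and pass it as a bearer token to the snapshot service endpoints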
| 612 | + |
| 613 | + existing_asset = _get_asset_by_hash( |
| 614 | + operations=datastore_operation, |
| 615 | + hash_str=asset_hash, |
| 616 | + request_headers=request_headers, |
| 617 | + workspace=workspace, |
| 618 | + requests_pipeline=requests_pipeline, |
| 619 | + ) |
| 620 | + |
| 621 | + return existing_asset |
| 622 | + |
| 623 | + |
| 624 | +def _upload_snapshot_to_datastore( |
| 625 | + operation_scope: OperationScope, |
| 626 | + datastore_operation: DatastoreOperations, |
| 627 | + path: Union[str, Path, os.PathLike], |
| 628 | + workspace: Workspace, |
| 629 | + requests_pipeline: HttpPipeline, |
| 630 | + datastore_name: str = None, |
| 631 | + show_progress: bool = True, |
| 632 | + asset_name: str = None, |
| 633 | + asset_version: str = None, |
| 634 | + asset_hash: str = None, |
| 635 | + ignore_file: IgnoreFile = IgnoreFile(), |
| 636 | + sas_uri: str = None, # contains registry sas url |
| 637 | +) -> ArtifactStorageInfo: |
| 638 | + """ |
| 639 | + Upload a code snapshot to workspace datastore. |
| 640 | + :param operation_scope: Workspace scope information |
| 641 | + :type operation_scope: azure.ai.ml._scope_dependent_operations.OperationScope |
| 642 | + :param datastore_operation: Datastore Operations object from MLClient |
| 643 | + :type datastore_operation: azure.ai.ml.operations._datastore_operations.DatastoreOperations |
| 644 | + :param path: The local path of the artifact |
| 645 | + :type path: Union[str, Path, os.PathLike] |
| 646 | + :param workspace: Workspace object |
| 647 | + :type workspace: azure.ai.ml._restclient.v2022_05_01.models.Workspace |
| 648 | + :param requests_pipeline: Proxy for sending HTTP requests |
| 649 | + :type requests_pipeline: azure.ai.ml._utils._http_utils.HttpPipeline |
| 650 | + :param datastore_name: Name of the datastore to upload to |
| 651 | + :type datastore_name: str |
| 652 | + :param show_progress: Whether or not to show progress bar during upload, defaults to True |
| 653 | + :type show_progress: bool |
| 654 | + :param asset_name: Name of the asset to be created |
| 655 | + :type asset_name: str |
| 656 | + :param asset_version: Version of the asset to be created |
| 657 | + :type asset_version: str |
| 658 | + :param asset_hash: The hash of the specified local upload |
| 659 | + :type asset_hash: str |
| 660 | + :param ignore_file: Information about the path's .gitignore or .amlignore file, if exists |
| 661 | + :type ignore_file: azure.ai.ml._utils._asset_utils.IgnoreFile |
| 662 | + :param sas_uri: SAS uri for uploading to datastore |
| 663 | + :type sas_uri: str |
| 664 | + :return: Uploaded artifact's storage information |
| 665 | + :rtype: azure.ai.ml.entities._assets._artifacts.artifact.ArtifactStorageInfo |
| 666 | + """ |
| 667 | + ws_base_url = datastore_operation._operation._client._base_url |
| 668 | + token = datastore_operation._credential.get_token(ws_base_url + "/.default").token |
| 669 | + request_headers = {"Authorization": "Bearer " + token} |
| 670 | + request_headers["Content-Type"] = "application/json; charset=UTF-8" |
| 671 | + |
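| | +    # a caller-supplied sas_uri (the registry upload path) already identifies the
| | +    # destination; otherwise mint a temporary data reference to obtain one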
| | +    blob_uri = None  # remains None when a sas_uri was supplied by the caller
| 672 | +    if not sas_uri:
| 673 | + sas_uri, blob_uri = _get_snapshot_temporary_data_reference( |
| 674 | + requests_pipeline=requests_pipeline, |
| 675 | + asset_name=asset_name, |
| 676 | + asset_version=asset_version, |
| 677 | + request_headers=request_headers, |
| 678 | + workspace=workspace, |
| 679 | + ) |
| 680 | + |
| 681 | + artifact = upload_artifact( |
| 682 | + str(path), |
| 683 | + datastore_operation, |
| 684 | + operation_scope, |
| 685 | + datastore_name, |
| 686 | + show_progress=show_progress, |
| 687 | + asset_hash=asset_hash, |
| 688 | + asset_name=asset_name, |
| 689 | + asset_version=asset_version, |
| 690 | + ignore_file=ignore_file, |
| 691 | + sas_uri=sas_uri, |
| 692 | + ) |
| 693 | +    if blob_uri:
| | +        artifact.storage_account_url = blob_uri
| 694 | + |
| 695 | + return artifact |
| 696 | + |
| 697 | + |
| 698 | +def _check_and_upload_snapshot( |
| 699 | + artifact: T, |
| 700 | + asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations"], |
| 701 | + path: Union[str, Path, os.PathLike], |
| 702 | + workspace: Workspace, |
| 703 | + requests_pipeline: HttpPipeline, |
| 704 | + ignore_file: IgnoreFile = None, |
| 705 | + datastore_name: str = WORKSPACE_BLOB_STORE, |
| 706 | + sas_uri: str = None, |
| 707 | + show_progress: bool = True, |
| 708 | +) -> T:
| 709 | +    """Upload the snapshot at `path` to the datastore and update `artifact`
| 710 | +    with the resulting storage information.
| 711 | +    :param artifact: Artifact to check and upload
| 712 | +    :type artifact: T
| 713 | +    :param asset_operations: The asset operations to use for uploading
| 714 | +    :type asset_operations: Union["DataOperations", "ModelOperations", "CodeOperations"]
| 715 | +    :param datastore_name: Name of the datastore to upload to
| | +    :type datastore_name: str
| | +    :param sas_uri: SAS uri to use for uploading
| | +    :type sas_uri: str
| | +    :return: The artifact, updated with its uploaded storage information
| | +    :rtype: T
| 716 | +    """
| 717 | + uploaded_artifact = _upload_snapshot_to_datastore( |
| 718 | + operation_scope=asset_operations._operation_scope, |
| 719 | + datastore_operation=asset_operations._datastore_operation, |
| 720 | + path=path, |
| 721 | + workspace=workspace, |
| 722 | + requests_pipeline=requests_pipeline, |
| 723 | + datastore_name=datastore_name, |
| 724 | + asset_name=artifact.name, |
| 725 | + asset_version=str(artifact.version), |
| 726 | +        asset_hash=getattr(artifact, "_upload_hash", None),
| 727 | + show_progress=show_progress, |
| 728 | + sas_uri=sas_uri, |
| 729 | + ignore_file=ignore_file, |
| 730 | + ) |
| 731 | + |
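| | +    # anonymous assets adopt the name/version assigned during upload (for anonymous
| | +    # assets these are typically derived from the content hash)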
| 732 | + if artifact._is_anonymous: |
| 733 | + artifact.name, artifact.version = ( |
| 734 | + uploaded_artifact.name, |
| 735 | + uploaded_artifact.version, |
| 736 | + ) |
| 737 | + # Pass all of the upload information to the assets, and they will each construct the URLs that they support |
| 738 | + artifact._update_path(uploaded_artifact) |
| 739 | + |
| 740 | + return artifact |
| 741 | + |
| 742 | + |
425 | 743 | def _check_and_upload_env_build_context( |
426 | 744 | environment: Environment, |
427 | 745 | operations: "EnvironmentOperations", |
|