Skip to content

Commit 4318a6a

Browse files
authored
Support for data share() (Azure#29894)
* first * data * ggg * remove the testcase that is not valid anymore * import order
1 parent 235d855 commit 4318a6a

File tree

2 files changed

+61
-55
lines changed

2 files changed

+61
-55
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_data_operations.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import os
88
from pathlib import Path
99
from typing import Dict, List, Optional, Union, Iterable
10+
from contextlib import contextmanager
1011

1112
from marshmallow.exceptions import ValidationError as SchemaValidationError
13+
from azure.ai.ml._utils._registry_utils import get_registry_client
1214

1315
from azure.ai.ml._utils._experimental import experimental
1416
from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
@@ -278,11 +280,16 @@ def create_or_update(self, data: Data) -> Data:
278280
target=ErrorTarget.DATA,
279281
error_category=ErrorCategory.USER_ERROR,
280282
)
281-
data = data._to_rest_object()
283+
data_res_obj = data._to_rest_object()
282284
result = self._service_client.resource_management_asset_reference.begin_import_method(
283-
resource_group_name=self._resource_group_name, registry_name=self._registry_name, body=data
284-
)
285-
return result
285+
resource_group_name=self._resource_group_name,
286+
registry_name=self._registry_name,
287+
body=data_res_obj,
288+
).result()
289+
290+
if not result:
291+
data_res_obj = self._get(name=data.name, version=data.version)
292+
return Data._from_rest_object(data_res_obj)
286293

287294
sas_uri = get_sas_uri_for_registry_asset(
288295
service_client=self._service_client,
@@ -544,19 +551,24 @@ def _get_latest_version(self, name: str) -> Data:
544551
)
545552
return self.get(name, version=latest_version)
546553

547-
# pylint: disable=no-self-use
548-
def _prepare_to_copy(
549-
self, data: Data, name: Optional[str] = None, version: Optional[str] = None
550-
) -> WorkspaceAssetReference:
551-
"""Returns WorkspaceAssetReference to copy a registered data to registry given the asset id.
554+
@monitor_with_activity(logger, "data.Share", ActivityType.PUBLICAPI)
555+
def share(self, name, version, *, share_with_name, share_with_version, registry_name) -> Data:
556+
"""Share a data asset from workspace to registry.
552557
553-
:param data: Registered data
554-
:type data: Data
555-
:param name: Destination name
558+
:param name: Name of data asset.
556559
:type name: str
557-
:param version: Destination version
560+
:param version: Version of data asset.
558561
:type version: str
562+
:param share_with_name: Name of data asset to share with.
563+
:type share_with_name: str
564+
:param share_with_version: Version of data asset to share with.
565+
:type share_with_version: str
566+
:param registry_name: Name of the destination registry.
567+
:type registry_name: str
568+
:return: Data asset object.
569+
:rtype: ~azure.ai.ml.entities.Data
559570
"""
571+
560572
# Get workspace info to get workspace GUID
561573
workspace = self._service_client.workspaces.get(
562574
resource_group_name=self._resource_group_name, workspace_name=self._workspace_name
@@ -569,16 +581,47 @@ def _prepare_to_copy(
569581
workspace_location,
570582
workspace_guid,
571583
AzureMLResourceType.DATA,
572-
data.name,
573-
data.version,
584+
name,
585+
version,
574586
)
575587

576-
return WorkspaceAssetReference(
577-
name=name if name else data.name,
578-
version=version if version else data.version,
588+
data_ref = WorkspaceAssetReference(
589+
name=share_with_name if share_with_name else name,
590+
version=share_with_version if share_with_version else version,
579591
asset_id=asset_id,
580592
)
581593

594+
with self._set_registry_client(registry_name):
595+
return self.create_or_update(data_ref)
596+
597+
@contextmanager
598+
def _set_registry_client(self, registry_name: str) -> None:
599+
"""Sets the registry client for the data operations.
600+
601+
:param registry_name: Name of the registry.
602+
:type registry_name: str
603+
"""
604+
rg_ = self._operation_scope._resource_group_name
605+
sub_ = self._operation_scope._subscription_id
606+
registry_ = self._operation_scope.registry_name
607+
client_ = self._service_client
608+
data_versions_operation_ = self._operation
609+
610+
try:
611+
_client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name)
612+
self._operation_scope.registry_name = registry_name
613+
self._operation_scope._resource_group_name = _rg
614+
self._operation_scope._subscription_id = _sub
615+
self._service_client = _client
616+
self._operation = _client.data_versions
617+
yield
618+
finally:
619+
self._operation_scope.registry_name = registry_
620+
self._operation_scope._resource_group_name = rg_
621+
self._operation_scope._subscription_id = sub_
622+
self._service_client = client_
623+
self._operation = data_versions_operation_
624+
582625

583626
def _assert_local_path_matches_asset_type(
584627
local_path: str,

sdk/ml/azure-ai-ml/tests/dataset/unittests/test_data_operations.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -563,40 +563,3 @@ def test_create_with_datastore(
563563
show_progress=True,
564564
ignore_file=None,
565565
)
566-
567-
def test_promote_data_from_workspace(
568-
self, mock_data_operations_in_registry: DataOperations, mock_data_operations: DataOperations, tmp_path: Path
569-
) -> None:
570-
data_asset_name = f"data_random_string"
571-
p = tmp_path / "data_full.yml"
572-
data_path = tmp_path / "data.pkl"
573-
data_path.write_text("hello world")
574-
p.write_text(
575-
f"""
576-
name: {data_asset_name}
577-
path: ./data.pkl
578-
version: 3"""
579-
)
580-
581-
with patch(
582-
"azure.ai.ml._artifacts._artifact_utilities._upload_to_datastore",
583-
return_value=ArtifactStorageInfo(
584-
name=data_asset_name,
585-
version="3",
586-
relative_path="path",
587-
datastore_arm_id="/subscriptions/mock/resourceGroups/mock/providers/Microsoft.MachineLearningServices/workspaces/mock/datastores/datastore_id",
588-
container_name="containerName",
589-
),
590-
) as mock_upload, patch(
591-
"azure.ai.ml.operations._data_operations.Data._from_rest_object",
592-
return_value=Data(),
593-
):
594-
data = load_data(source=p)
595-
data_to_promote = mock_data_operations._prepare_to_copy(data, "new_name", "new_version")
596-
assert data_to_promote.name == "new_name"
597-
assert data_to_promote.version == "new_version"
598-
mock_data_operations_in_registry._operation.get.side_effect = Mock(
599-
side_effect=ResourceNotFoundError("Test")
600-
)
601-
mock_data_operations_in_registry.create_or_update(data_to_promote)
602-
mock_data_operations_in_registry._service_client.resource_management_asset_reference.begin_import_method.assert_called_once()

0 commit comments

Comments
 (0)