77import os
88from pathlib import Path
99from typing import Dict , List , Optional , Union , Iterable
10+ from contextlib import contextmanager
1011
1112from marshmallow .exceptions import ValidationError as SchemaValidationError
13+ from azure .ai .ml ._utils ._registry_utils import get_registry_client
1214
1315from azure .ai .ml ._utils ._experimental import experimental
1416from azure .ai .ml .entities import PipelineJob , PipelineJobSettings
@@ -278,11 +280,16 @@ def create_or_update(self, data: Data) -> Data:
278280 target = ErrorTarget .DATA ,
279281 error_category = ErrorCategory .USER_ERROR ,
280282 )
281- data = data ._to_rest_object ()
283+ data_res_obj = data ._to_rest_object ()
282284 result = self ._service_client .resource_management_asset_reference .begin_import_method (
283- resource_group_name = self ._resource_group_name , registry_name = self ._registry_name , body = data
284- )
285- return result
285+ resource_group_name = self ._resource_group_name ,
286+ registry_name = self ._registry_name ,
287+ body = data_res_obj ,
288+ ).result ()
289+
290+ if not result :
291+ data_res_obj = self ._get (name = data .name , version = data .version )
292+ return Data ._from_rest_object (data_res_obj )
286293
287294 sas_uri = get_sas_uri_for_registry_asset (
288295 service_client = self ._service_client ,
@@ -544,19 +551,24 @@ def _get_latest_version(self, name: str) -> Data:
544551 )
545552 return self .get (name , version = latest_version )
546553
547- # pylint: disable=no-self-use
548- def _prepare_to_copy (
549- self , data : Data , name : Optional [str ] = None , version : Optional [str ] = None
550- ) -> WorkspaceAssetReference :
551- """Returns WorkspaceAssetReference to copy a registered data to registry given the asset id.
554+ @monitor_with_activity (logger , "data.Share" , ActivityType .PUBLICAPI )
555+ def share (self , name , version , * , share_with_name , share_with_version , registry_name ) -> Data :
556+ """Share a data asset from workspace to registry.
552557
553- :param data: Registered data
554- :type data: Data
555- :param name: Destination name
558+ :param name: Name of data asset.
556559 :type name: str
557- :param version: Destination version
560+ :param version: Version of data asset.
558561 :type version: str
562+ :param share_with_name: Name of data asset to share with.
563+ :type share_with_name: str
564+ :param share_with_version: Version of data asset to share with.
565+ :type share_with_version: str
566+ :param registry_name: Name of the destination registry.
567+ :type registry_name: str
568+ :return: Data asset object.
569+ :rtype: ~azure.ai.ml.entities.Data
559570 """
571+
560572 # Get workspace info to get workspace GUID
561573 workspace = self ._service_client .workspaces .get (
562574 resource_group_name = self ._resource_group_name , workspace_name = self ._workspace_name
@@ -569,16 +581,47 @@ def _prepare_to_copy(
569581 workspace_location ,
570582 workspace_guid ,
571583 AzureMLResourceType .DATA ,
572- data . name ,
573- data . version ,
584+ name ,
585+ version ,
574586 )
575587
576- return WorkspaceAssetReference (
577- name = name if name else data . name ,
578- version = version if version else data . version ,
588+ data_ref = WorkspaceAssetReference (
589+ name = share_with_name if share_with_name else name ,
590+ version = share_with_version if share_with_version else version ,
579591 asset_id = asset_id ,
580592 )
581593
594+ with self ._set_registry_client (registry_name ):
595+ return self .create_or_update (data_ref )
596+
597+ @contextmanager
598+ def _set_registry_client (self , registry_name : str ) -> None :
599+ """Sets the registry client for the data operations.
600+
601+ :param registry_name: Name of the registry.
602+ :type registry_name: str
603+ """
604+ rg_ = self ._operation_scope ._resource_group_name
605+ sub_ = self ._operation_scope ._subscription_id
606+ registry_ = self ._operation_scope .registry_name
607+ client_ = self ._service_client
608+ data_versions_operation_ = self ._operation
609+
610+ try :
611+ _client , _rg , _sub = get_registry_client (self ._service_client ._config .credential , registry_name )
612+ self ._operation_scope .registry_name = registry_name
613+ self ._operation_scope ._resource_group_name = _rg
614+ self ._operation_scope ._subscription_id = _sub
615+ self ._service_client = _client
616+ self ._operation = _client .data_versions
617+ yield
618+ finally :
619+ self ._operation_scope .registry_name = registry_
620+ self ._operation_scope ._resource_group_name = rg_
621+ self ._operation_scope ._subscription_id = sub_
622+ self ._service_client = client_
623+ self ._operation = data_versions_operation_
624+
582625
583626def _assert_local_path_matches_asset_type (
584627 local_path : str ,
0 commit comments