@@ -2,7 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 
-# pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes,client-method-missing-type-annotations,missing-client-constructor-parameter-kwargs
+# pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes,client-method-missing-type-annotations,missing-client-constructor-parameter-kwargs,logging-format-interpolation
 
 import logging
 import os
@@ -14,15 +14,17 @@
 
 from colorama import Fore
 
-from azure.ai.ml._artifacts._constants import UPLOAD_CONFIRMATION
+from azure.ai.ml._artifacts._constants import UPLOAD_CONFIRMATION, FILE_SIZE_WARNING
 from azure.ai.ml._utils._asset_utils import (
     AssetNotChangedError,
     IgnoreFile,
     _build_metadata_dict,
     generate_asset_id,
+    get_directory_size,
     upload_directory,
     upload_file,
 )
+from azure.ai.ml._azure_environments import _get_cloud_details
 from azure.ai.ml.constants._common import STORAGE_AUTH_MISMATCH_ERROR
 from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationException
 from azure.core.exceptions import ResourceExistsError
@@ -34,6 +36,7 @@
 class Gen2StorageClient:
     def __init__(self, credential: str, file_system: str, account_url: str):
         service_client = DataLakeServiceClient(account_url=account_url, credential=credential)
+        self.account_name = account_url.split(".")[0].split("//")[1]
         self.file_system = file_system
         self.file_system_client = service_client.get_file_system_client(file_system=file_system)
         try:
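For orientation: `get_directory_size` (newly imported above) returns a tuple whose first element is the total size in bytes, which the upload path below divides by 10**6 to get megabytes. A minimal sketch of that calculation, assuming the (total_bytes, per_file_sizes) return shape:

    from azure.ai.ml._utils._asset_utils import get_directory_size

    total_bytes, _ = get_directory_size("./data")  # second element assumed to be per-file sizes
    size_in_mb = total_bytes / 10**6               # decimal MB, matching the 100 MB threshold below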
@@ -77,6 +80,16 @@ def upload(
         # configure progress bar description
         msg = Fore.GREEN + f"Uploading {formatted_path}"
 
+        # warn if large file (> 100 MB)
+        file_size, _ = get_directory_size(source)
+        file_size_in_mb = file_size / 10**6
+
+        cloud = _get_cloud_details()
+        cloud_endpoint = cloud['storage_endpoint']  # make sure proper cloud endpoint is used
+        full_storage_url = f"https://{self.account_name}.dfs.{cloud_endpoint}/{self.file_system}/{dest}"
+        if file_size_in_mb > 100:
+            module_logger.warning(FILE_SIZE_WARNING.format(source=source, destination=full_storage_url))
+
         # start upload
         self.directory_client = self.file_system_client.get_directory_client(asset_id)
         self.check_blob_exists()
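The exact text of FILE_SIZE_WARNING lives in azure.ai.ml._artifacts._constants and is not shown in this diff; the only contract the code above relies on is that it is a format string with `{source}` and `{destination}` placeholders. A hypothetical stand-in to illustrate how the warning renders:

    # hypothetical stand-in; the real constant's wording may differ
    FILE_SIZE_WARNING = (
        "Transferring more than 100 MB from '{source}' to '{destination}'. "
        "Large transfers may be slow; a tool such as AzCopy may be faster."
    )
    # e.g. module_logger.warning(FILE_SIZE_WARNING.format(source="./data", destination=full_storage_url))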
@@ -159,6 +172,7 @@ def download(self, starts_with: str, destination: str = Path.home()) -> None:
         prefix `starts_with` to the destination folder."""
         try:
             mylist = self.file_system_client.get_paths(path=starts_with)
+            download_size_in_mb = 0
             for item in mylist:
                 file_name = item.name[len(starts_with) :].lstrip("/") or Path(starts_with).name
@@ -168,6 +182,15 @@ def download(self, starts_with: str, destination: str = Path.home()) -> None:
 
                 target_path = Path(destination, file_name)
                 file_client = self.file_system_client.get_file_client(item.name)
+
+                # check if total size of download has exceeded 100 MB
+                cloud = _get_cloud_details()
+                cloud_endpoint = cloud['storage_endpoint']  # make sure proper cloud endpoint is used
+                full_storage_url = f"https://{self.account_name}.dfs.{cloud_endpoint}/{self.file_system}/{starts_with}"
+                download_size_in_mb += file_client.get_file_properties().size / 10**6
+                if download_size_in_mb > 100:
+                    module_logger.warning(FILE_SIZE_WARNING.format(source=full_storage_url, destination=destination))
+
                 file_content = file_client.download_file().readall()
                 try:
                     os.makedirs(str(target_path.parent), exist_ok=True)
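Finally, a self-contained check of the account-name parsing added in `__init__`, assuming an account_url of the usual ADLS Gen2 form (the URL here is illustrative):

    account_url = "https://myaccount.dfs.core.windows.net"
    # split(".")[0] -> "https://myaccount"; split("//")[1] -> "myaccount"
    account_name = account_url.split(".")[0].split("//")[1]
    assert account_name == "myaccount"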