2323# IN THE SOFTWARE.
2424#
2525# --------------------------------------------------------------------------
26-
26+ from __future__ import annotations
2727import json
2828import logging
2929import sys
3030
31- from typing import Callable , Any , Optional , Union , Type , List , Dict , TYPE_CHECKING
31+ from types import TracebackType
32+ from typing import (
33+ Callable ,
34+ Any ,
35+ Optional ,
36+ Union ,
37+ Type ,
38+ List ,
39+ Mapping ,
40+ TypeVar ,
41+ Generic ,
42+ Dict ,
43+ TYPE_CHECKING ,
44+ )
45+ from typing_extensions import Protocol , runtime_checkable
3246
3347_LOGGER = logging .getLogger (__name__ )
3448
3549if TYPE_CHECKING :
36- from azure .core .pipeline .transport ._base import _HttpResponseBase
3750 from azure .core .pipeline .policies import RequestHistory
3851
52+ HTTPResponseType = TypeVar ("HTTPResponseType" )
53+ HTTPRequestType = TypeVar ("HTTPRequestType" )
54+ KeyType = TypeVar ("KeyType" )
55+ ValueType = TypeVar ("ValueType" )
56+ # To replace when typing.Self is available in our baseline
57+ SelfODataV4Format = TypeVar ("SelfODataV4Format" , bound = "ODataV4Format" )
58+
3959
4060__all__ = [
4161 "AzureError" ,
5979]
6080
6181
62- def raise_with_traceback (exception : Callable , * args , ** kwargs ) -> None :
82+ def raise_with_traceback (exception : Callable , * args : Any , ** kwargs : Any ) -> None :
6383 """Raise exception with a specified traceback.
6484 This MUST be called inside a "except" clause.
6585
@@ -83,26 +103,58 @@ def raise_with_traceback(exception: Callable, *args, **kwargs) -> None:
83103 raise error # pylint: disable=raise-missing-from
84104
85105
86- class ErrorMap :
106+ @runtime_checkable
107+ class _HttpResponseCommonAPI (Protocol ):
108+ """Protocol used by exceptions for HTTP response.
109+
110+ As HttpResponseError uses very few properties of HttpResponse, a protocol
111+ is faster and simpler than import all the possible types (at least 6).
112+ """
113+
114+ @property
115+ def reason (self ) -> Optional [str ]:
116+ pass
117+
118+ @property
119+ def status_code (self ) -> Optional [int ]:
120+ pass
121+
122+ def text (self ) -> str :
123+ pass
124+
125+ @property
126+ def request (self ) -> object : # object as type, since all we need is str() on it
127+ pass
128+
129+
130+ class ErrorMap (Generic [KeyType , ValueType ]):
87131 """Error Map class. To be used in map_error method, behaves like a dictionary.
88132 It returns the error type if it is found in custom_error_map. Or return default_error
89133
90134 :param dict custom_error_map: User-defined error map, it is used to map status codes to error types.
91135 :keyword error default_error: Default error type. It is returned if the status code is not found in custom_error_map
92136 """
93137
94- def __init__ (self , custom_error_map = None , ** kwargs ):
138+ def __init__ (
139+ self , # pylint: disable=unused-argument
140+ custom_error_map : Optional [Mapping [KeyType , ValueType ]] = None ,
141+ * ,
142+ default_error : Optional [ValueType ] = None ,
143+ ** kwargs : Any ,
144+ ) -> None :
95145 self ._custom_error_map = custom_error_map or {}
96- self ._default_error = kwargs . pop ( " default_error" , None )
146+ self ._default_error = default_error
97147
98- def get (self , key ) :
148+ def get (self , key : KeyType ) -> Optional [ ValueType ] :
99149 ret = self ._custom_error_map .get (key )
100150 if ret :
101151 return ret
102152 return self ._default_error
103153
104154
105- def map_error (status_code , response , error_map ):
155+ def map_error (
156+ status_code : int , response : _HttpResponseCommonAPI , error_map : Mapping [int , Type [HttpResponseError ]]
157+ ) -> None :
106158 if not error_map :
107159 return
108160 error_type = error_map .get (status_code )
@@ -157,7 +209,7 @@ class ODataV4Format:
157209 DETAILS_LABEL = "details"
158210 INNERERROR_LABEL = "innererror"
159211
160- def __init__ (self , json_object : Dict [str , Any ]):
212+ def __init__ (self , json_object : Mapping [str , Any ]) -> None :
161213 if "error" in json_object :
162214 json_object = json_object ["error" ]
163215 cls : Type [ODataV4Format ] = self .__class__
@@ -180,10 +232,10 @@ def __init__(self, json_object: Dict[str, Any]):
180232 except Exception : # pylint: disable=broad-except
181233 pass
182234
183- self .innererror : Dict [str , Any ] = json_object .get (cls .INNERERROR_LABEL , {})
235+ self .innererror : Mapping [str , Any ] = json_object .get (cls .INNERERROR_LABEL , {})
184236
185237 @property
186- def error (self ) :
238+ def error (self : SelfODataV4Format ) -> SelfODataV4Format :
187239 import warnings
188240
189241 warnings .warn (
@@ -192,7 +244,7 @@ def error(self):
192244 )
193245 return self
194246
195- def __str__ (self ):
247+ def __str__ (self ) -> str :
196248 return "({}) {}\n {}" .format (self .code , self .message , self .message_details ())
197249
198250 def message_details (self ) -> str :
@@ -220,7 +272,7 @@ def message_details(self) -> str:
220272class AzureError (Exception ):
221273 """Base exception for all errors.
222274
223- :param message: The message object stringified as 'message' attribute
275+ :param object message: The message object stringified as 'message' attribute
224276 :keyword error: The original exception if any
225277 :paramtype error: Exception
226278
@@ -235,16 +287,21 @@ class AzureError(Exception):
235287 and will be `None` where continuation is either unavailable or not applicable.
236288 """
237289
238- def __init__ (self , message , * args , ** kwargs ):
239- self .inner_exception = kwargs .get ("error" )
240- self .exc_type , self .exc_value , self .exc_traceback = sys .exc_info ()
241- self .exc_type = self .exc_type .__name__ if self .exc_type else type (self .inner_exception )
242- self .exc_msg = "{}, {}: {}" .format (message , self .exc_type , self .exc_value )
243- self .message = str (message )
244- self .continuation_token = kwargs .get ("continuation_token" )
290+ def __init__ (self , message : Optional [object ], * args : Any , ** kwargs : Any ) -> None :
291+ self .inner_exception : Optional [BaseException ] = kwargs .get ("error" )
292+
293+ exc_info = sys .exc_info ()
294+ self .exc_type : Optional [Type [Any ]] = exc_info [0 ]
295+ self .exc_value : Optional [BaseException ] = exc_info [1 ]
296+ self .exc_traceback : Optional [TracebackType ] = exc_info [2 ]
297+
298+ self .exc_type = self .exc_type if self .exc_type else type (self .inner_exception )
299+ self .exc_msg : str = "{}, {}: {}" .format (message , self .exc_type .__name__ , self .exc_value )
300+ self .message : Optional [str ] = str (message )
301+ self .continuation_token : Optional [str ] = kwargs .get ("continuation_token" )
245302 super (AzureError , self ).__init__ (self .message , * args )
246303
247- def raise_with_traceback (self ):
304+ def raise_with_traceback (self ) -> None :
248305 """Raise the exception with the existing traceback.
249306
250307 .. deprecated:: 1.22.0
@@ -253,7 +310,7 @@ def raise_with_traceback(self):
253310 try :
254311 raise super (AzureError , self ).with_traceback (self .exc_traceback ) # pylint: disable=raise-missing-from
255312 except AttributeError :
256- self .__traceback__ = self .exc_traceback
313+ self .__traceback__ : Optional [ TracebackType ] = self .exc_traceback
257314 raise self # pylint: disable=raise-missing-from
258315
259316
@@ -280,8 +337,7 @@ class ServiceResponseTimeoutError(ServiceResponseError):
280337class HttpResponseError (AzureError ):
281338 """A request was made, and a non-success status code was received from the service.
282339
283- :param message: HttpResponse's error message
284- :type message: string
340+ :param object message: The message object stringified as 'message' attribute
285341 :param response: The response that triggered the exception.
286342 :type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse
287343
@@ -297,24 +353,27 @@ class HttpResponseError(AzureError):
297353 :vartype error: ODataV4Format
298354 """
299355
300- def __init__ (self , message = None , response = None , ** kwargs ):
356+ def __init__ (
357+ self , message : Optional [object ] = None , response : Optional [_HttpResponseCommonAPI ] = None , ** kwargs : Any
358+ ) -> None :
301359 # Don't want to document this one yet.
302360 error_format = kwargs .get ("error_format" , ODataV4Format )
303361
304- self .reason = None
305- self .status_code = None
306- self .response = response
362+ self .reason : Optional [ str ] = None
363+ self .status_code : Optional [ int ] = None
364+ self .response : Optional [ _HttpResponseCommonAPI ] = response
307365 if response :
308366 self .reason = response .reason
309367 self .status_code = response .status_code
310368
311369 # old autorest are setting "error" before calling __init__, so it might be there already
312370 # transferring into self.model
313371 model : Optional [Any ] = kwargs .pop ("model" , None )
372+ self .model : Optional [Any ]
314373 if model is not None : # autorest v5
315374 self .model = model
316375 else : # autorest azure-core, for KV 1.0, Storage 12.0, etc.
317- self .model : Optional [ Any ] = getattr (self , "error" , None )
376+ self .model = getattr (self , "error" , None )
318377 self .error : Optional [ODataV4Format ] = self ._parse_odata_body (error_format , response )
319378
320379 # By priority, message is:
@@ -329,19 +388,23 @@ def __init__(self, message=None, response=None, **kwargs):
329388 super (HttpResponseError , self ).__init__ (message = message , ** kwargs )
330389
331390 @staticmethod
332- def _parse_odata_body (error_format : Type [ODataV4Format ], response : "_HttpResponseBase" ) -> Optional [ODataV4Format ]:
391+ def _parse_odata_body (
392+ error_format : Type [ODataV4Format ], response : Optional [_HttpResponseCommonAPI ]
393+ ) -> Optional [ODataV4Format ]:
333394 try :
334- odata_json = json .loads (response .text ())
395+ # https://github.com/python/mypy/issues/14743#issuecomment-1664725053
396+ odata_json = json .loads (response .text ()) # type: ignore
335397 return error_format (odata_json )
336398 except Exception : # pylint: disable=broad-except
337399 # If the body is not JSON valid, just stop now
338400 pass
339401 return None
340402
341- def __str__ (self ):
403+ def __str__ (self ) -> str :
342404 retval = super (HttpResponseError , self ).__str__ ()
343405 try :
344- body = self .response .text ()
406+ # https://github.com/python/mypy/issues/14743#issuecomment-1664725053
407+ body = self .response .text () # type: ignore
345408 if body and not self .error :
346409 return "{}\n Content: {}" .format (retval , body )[:2048 ]
347410 except Exception : # pylint: disable=broad-except
@@ -381,14 +444,16 @@ class ResourceNotModifiedError(HttpResponseError):
381444 This will not be raised directly by the Azure core pipeline."""
382445
383446
384- class TooManyRedirectsError (HttpResponseError ):
447+ class TooManyRedirectsError (HttpResponseError , Generic [ HTTPRequestType , HTTPResponseType ] ):
385448 """Reached the maximum number of redirect attempts.
386449
387450 :param history: The history of requests made while trying to fulfill the request.
388451 :type history: list[~azure.core.pipeline.policies.RequestHistory]
389452 """
390453
391- def __init__ (self , history , * args , ** kwargs ):
454+ def __init__ (
455+ self , history : "List[RequestHistory[HTTPRequestType, HTTPResponseType]]" , * args : Any , ** kwargs : Any
456+ ) -> None :
392457 self .history = history
393458 message = "Reached maximum redirect attempts."
394459 super (TooManyRedirectsError , self ).__init__ (message , * args , ** kwargs )
@@ -414,7 +479,7 @@ class ODataV4Error(HttpResponseError):
414479
415480 _ERROR_FORMAT = ODataV4Format
416481
417- def __init__ (self , response : "_HttpResponseBase" , ** kwargs ) -> None :
482+ def __init__ (self , response : _HttpResponseCommonAPI , ** kwargs : Any ) -> None :
418483 # Ensure field are declared, whatever can happen afterwards
419484 self .odata_json : Optional [Dict [str , Any ]] = None
420485 try :
@@ -428,7 +493,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
428493 self .message : Optional [str ] = kwargs .get ("message" , odata_message )
429494 self .target : Optional [str ] = None
430495 self .details : Optional [List [Any ]] = []
431- self .innererror : Optional [Dict [str , Any ]] = {}
496+ self .innererror : Optional [Mapping [str , Any ]] = {}
432497
433498 if self .message and "message" not in kwargs :
434499 kwargs ["message" ] = self .message
@@ -445,7 +510,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
445510 _LOGGER .info ("Received error message was not valid OdataV4 format." )
446511 self ._error_format = "JSON was invalid for format " + str (self ._ERROR_FORMAT )
447512
448- def __str__ (self ):
513+ def __str__ (self ) -> str :
449514 if self ._error_format :
450515 return str (self ._error_format )
451516 return super (ODataV4Error , self ).__str__ ()
@@ -461,7 +526,7 @@ class StreamConsumedError(AzureError):
461526 :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
462527 """
463528
464- def __init__ (self , response ) :
529+ def __init__ (self , response : _HttpResponseCommonAPI ) -> None :
465530 message = (
466531 "You are attempting to read or stream the content from request {}. "
467532 "You have likely already consumed this stream, so it can not be accessed anymore." .format (response .request )
@@ -479,7 +544,7 @@ class StreamClosedError(AzureError):
479544 :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
480545 """
481546
482- def __init__ (self , response ) :
547+ def __init__ (self , response : _HttpResponseCommonAPI ) -> None :
483548 message = (
484549 "The content for response from request {} can no longer be read or streamed, since the "
485550 "response has already been closed." .format (response .request )
@@ -497,7 +562,7 @@ class ResponseNotReadError(AzureError):
497562 :type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
498563 """
499564
500- def __init__ (self , response ) :
565+ def __init__ (self , response : _HttpResponseCommonAPI ) -> None :
501566 message = (
502567 "You have not read in the bytes for the response from request {}. "
503568 "Call .read() on the response first." .format (response .request )
0 commit comments