Skip to content

Commit d0e378e

Browse files
authored
Azure-Core Exceptions: Type Complete (Azure#31056)
* First update * Some clean-up * Fix message typing * Type complete * Protocol is in typing_extensions * And also runtime_checkable * Pylint * Remove unecessary double quotes * Anna's feedback
1 parent 92299a2 commit d0e378e

File tree

1 file changed

+107
-42
lines changed

1 file changed

+107
-42
lines changed

sdk/core/azure-core/azure/core/exceptions.py

Lines changed: 107 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,39 @@
2323
# IN THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26-
26+
from __future__ import annotations
2727
import json
2828
import logging
2929
import sys
3030

31-
from typing import Callable, Any, Optional, Union, Type, List, Dict, TYPE_CHECKING
31+
from types import TracebackType
32+
from typing import (
33+
Callable,
34+
Any,
35+
Optional,
36+
Union,
37+
Type,
38+
List,
39+
Mapping,
40+
TypeVar,
41+
Generic,
42+
Dict,
43+
TYPE_CHECKING,
44+
)
45+
from typing_extensions import Protocol, runtime_checkable
3246

3347
_LOGGER = logging.getLogger(__name__)
3448

3549
if TYPE_CHECKING:
36-
from azure.core.pipeline.transport._base import _HttpResponseBase
3750
from azure.core.pipeline.policies import RequestHistory
3851

52+
HTTPResponseType = TypeVar("HTTPResponseType")
53+
HTTPRequestType = TypeVar("HTTPRequestType")
54+
KeyType = TypeVar("KeyType")
55+
ValueType = TypeVar("ValueType")
56+
# To replace when typing.Self is available in our baseline
57+
SelfODataV4Format = TypeVar("SelfODataV4Format", bound="ODataV4Format")
58+
3959

4060
__all__ = [
4161
"AzureError",
@@ -59,7 +79,7 @@
5979
]
6080

6181

62-
def raise_with_traceback(exception: Callable, *args, **kwargs) -> None:
82+
def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> None:
6383
"""Raise exception with a specified traceback.
6484
This MUST be called inside a "except" clause.
6585
@@ -83,26 +103,58 @@ def raise_with_traceback(exception: Callable, *args, **kwargs) -> None:
83103
raise error # pylint: disable=raise-missing-from
84104

85105

86-
class ErrorMap:
106+
@runtime_checkable
107+
class _HttpResponseCommonAPI(Protocol):
108+
"""Protocol used by exceptions for HTTP response.
109+
110+
As HttpResponseError uses very few properties of HttpResponse, a protocol
111+
is faster and simpler than import all the possible types (at least 6).
112+
"""
113+
114+
@property
115+
def reason(self) -> Optional[str]:
116+
pass
117+
118+
@property
119+
def status_code(self) -> Optional[int]:
120+
pass
121+
122+
def text(self) -> str:
123+
pass
124+
125+
@property
126+
def request(self) -> object: # object as type, since all we need is str() on it
127+
pass
128+
129+
130+
class ErrorMap(Generic[KeyType, ValueType]):
87131
"""Error Map class. To be used in map_error method, behaves like a dictionary.
88132
It returns the error type if it is found in custom_error_map. Or return default_error
89133
90134
:param dict custom_error_map: User-defined error map, it is used to map status codes to error types.
91135
:keyword error default_error: Default error type. It is returned if the status code is not found in custom_error_map
92136
"""
93137

94-
def __init__(self, custom_error_map=None, **kwargs):
138+
def __init__(
139+
self, # pylint: disable=unused-argument
140+
custom_error_map: Optional[Mapping[KeyType, ValueType]] = None,
141+
*,
142+
default_error: Optional[ValueType] = None,
143+
**kwargs: Any,
144+
) -> None:
95145
self._custom_error_map = custom_error_map or {}
96-
self._default_error = kwargs.pop("default_error", None)
146+
self._default_error = default_error
97147

98-
def get(self, key):
148+
def get(self, key: KeyType) -> Optional[ValueType]:
99149
ret = self._custom_error_map.get(key)
100150
if ret:
101151
return ret
102152
return self._default_error
103153

104154

105-
def map_error(status_code, response, error_map):
155+
def map_error(
156+
status_code: int, response: _HttpResponseCommonAPI, error_map: Mapping[int, Type[HttpResponseError]]
157+
) -> None:
106158
if not error_map:
107159
return
108160
error_type = error_map.get(status_code)
@@ -157,7 +209,7 @@ class ODataV4Format:
157209
DETAILS_LABEL = "details"
158210
INNERERROR_LABEL = "innererror"
159211

160-
def __init__(self, json_object: Dict[str, Any]):
212+
def __init__(self, json_object: Mapping[str, Any]) -> None:
161213
if "error" in json_object:
162214
json_object = json_object["error"]
163215
cls: Type[ODataV4Format] = self.__class__
@@ -180,10 +232,10 @@ def __init__(self, json_object: Dict[str, Any]):
180232
except Exception: # pylint: disable=broad-except
181233
pass
182234

183-
self.innererror: Dict[str, Any] = json_object.get(cls.INNERERROR_LABEL, {})
235+
self.innererror: Mapping[str, Any] = json_object.get(cls.INNERERROR_LABEL, {})
184236

185237
@property
186-
def error(self):
238+
def error(self: SelfODataV4Format) -> SelfODataV4Format:
187239
import warnings
188240

189241
warnings.warn(
@@ -192,7 +244,7 @@ def error(self):
192244
)
193245
return self
194246

195-
def __str__(self):
247+
def __str__(self) -> str:
196248
return "({}) {}\n{}".format(self.code, self.message, self.message_details())
197249

198250
def message_details(self) -> str:
@@ -220,7 +272,7 @@ def message_details(self) -> str:
220272
class AzureError(Exception):
221273
"""Base exception for all errors.
222274
223-
:param message: The message object stringified as 'message' attribute
275+
:param object message: The message object stringified as 'message' attribute
224276
:keyword error: The original exception if any
225277
:paramtype error: Exception
226278
@@ -235,16 +287,21 @@ class AzureError(Exception):
235287
and will be `None` where continuation is either unavailable or not applicable.
236288
"""
237289

238-
def __init__(self, message, *args, **kwargs):
239-
self.inner_exception = kwargs.get("error")
240-
self.exc_type, self.exc_value, self.exc_traceback = sys.exc_info()
241-
self.exc_type = self.exc_type.__name__ if self.exc_type else type(self.inner_exception)
242-
self.exc_msg = "{}, {}: {}".format(message, self.exc_type, self.exc_value)
243-
self.message = str(message)
244-
self.continuation_token = kwargs.get("continuation_token")
290+
def __init__(self, message: Optional[object], *args: Any, **kwargs: Any) -> None:
291+
self.inner_exception: Optional[BaseException] = kwargs.get("error")
292+
293+
exc_info = sys.exc_info()
294+
self.exc_type: Optional[Type[Any]] = exc_info[0]
295+
self.exc_value: Optional[BaseException] = exc_info[1]
296+
self.exc_traceback: Optional[TracebackType] = exc_info[2]
297+
298+
self.exc_type = self.exc_type if self.exc_type else type(self.inner_exception)
299+
self.exc_msg: str = "{}, {}: {}".format(message, self.exc_type.__name__, self.exc_value)
300+
self.message: Optional[str] = str(message)
301+
self.continuation_token: Optional[str] = kwargs.get("continuation_token")
245302
super(AzureError, self).__init__(self.message, *args)
246303

247-
def raise_with_traceback(self):
304+
def raise_with_traceback(self) -> None:
248305
"""Raise the exception with the existing traceback.
249306
250307
.. deprecated:: 1.22.0
@@ -253,7 +310,7 @@ def raise_with_traceback(self):
253310
try:
254311
raise super(AzureError, self).with_traceback(self.exc_traceback) # pylint: disable=raise-missing-from
255312
except AttributeError:
256-
self.__traceback__ = self.exc_traceback
313+
self.__traceback__: Optional[TracebackType] = self.exc_traceback
257314
raise self # pylint: disable=raise-missing-from
258315

259316

@@ -280,8 +337,7 @@ class ServiceResponseTimeoutError(ServiceResponseError):
280337
class HttpResponseError(AzureError):
281338
"""A request was made, and a non-success status code was received from the service.
282339
283-
:param message: HttpResponse's error message
284-
:type message: string
340+
:param object message: The message object stringified as 'message' attribute
285341
:param response: The response that triggered the exception.
286342
:type response: ~azure.core.pipeline.transport.HttpResponse or ~azure.core.pipeline.transport.AsyncHttpResponse
287343
@@ -297,24 +353,27 @@ class HttpResponseError(AzureError):
297353
:vartype error: ODataV4Format
298354
"""
299355

300-
def __init__(self, message=None, response=None, **kwargs):
356+
def __init__(
357+
self, message: Optional[object] = None, response: Optional[_HttpResponseCommonAPI] = None, **kwargs: Any
358+
) -> None:
301359
# Don't want to document this one yet.
302360
error_format = kwargs.get("error_format", ODataV4Format)
303361

304-
self.reason = None
305-
self.status_code = None
306-
self.response = response
362+
self.reason: Optional[str] = None
363+
self.status_code: Optional[int] = None
364+
self.response: Optional[_HttpResponseCommonAPI] = response
307365
if response:
308366
self.reason = response.reason
309367
self.status_code = response.status_code
310368

311369
# old autorest are setting "error" before calling __init__, so it might be there already
312370
# transferring into self.model
313371
model: Optional[Any] = kwargs.pop("model", None)
372+
self.model: Optional[Any]
314373
if model is not None: # autorest v5
315374
self.model = model
316375
else: # autorest azure-core, for KV 1.0, Storage 12.0, etc.
317-
self.model: Optional[Any] = getattr(self, "error", None)
376+
self.model = getattr(self, "error", None)
318377
self.error: Optional[ODataV4Format] = self._parse_odata_body(error_format, response)
319378

320379
# By priority, message is:
@@ -329,19 +388,23 @@ def __init__(self, message=None, response=None, **kwargs):
329388
super(HttpResponseError, self).__init__(message=message, **kwargs)
330389

331390
@staticmethod
332-
def _parse_odata_body(error_format: Type[ODataV4Format], response: "_HttpResponseBase") -> Optional[ODataV4Format]:
391+
def _parse_odata_body(
392+
error_format: Type[ODataV4Format], response: Optional[_HttpResponseCommonAPI]
393+
) -> Optional[ODataV4Format]:
333394
try:
334-
odata_json = json.loads(response.text())
395+
# https://github.com/python/mypy/issues/14743#issuecomment-1664725053
396+
odata_json = json.loads(response.text()) # type: ignore
335397
return error_format(odata_json)
336398
except Exception: # pylint: disable=broad-except
337399
# If the body is not JSON valid, just stop now
338400
pass
339401
return None
340402

341-
def __str__(self):
403+
def __str__(self) -> str:
342404
retval = super(HttpResponseError, self).__str__()
343405
try:
344-
body = self.response.text()
406+
# https://github.com/python/mypy/issues/14743#issuecomment-1664725053
407+
body = self.response.text() # type: ignore
345408
if body and not self.error:
346409
return "{}\nContent: {}".format(retval, body)[:2048]
347410
except Exception: # pylint: disable=broad-except
@@ -381,14 +444,16 @@ class ResourceNotModifiedError(HttpResponseError):
381444
This will not be raised directly by the Azure core pipeline."""
382445

383446

384-
class TooManyRedirectsError(HttpResponseError):
447+
class TooManyRedirectsError(HttpResponseError, Generic[HTTPRequestType, HTTPResponseType]):
385448
"""Reached the maximum number of redirect attempts.
386449
387450
:param history: The history of requests made while trying to fulfill the request.
388451
:type history: list[~azure.core.pipeline.policies.RequestHistory]
389452
"""
390453

391-
def __init__(self, history, *args, **kwargs):
454+
def __init__(
455+
self, history: "List[RequestHistory[HTTPRequestType, HTTPResponseType]]", *args: Any, **kwargs: Any
456+
) -> None:
392457
self.history = history
393458
message = "Reached maximum redirect attempts."
394459
super(TooManyRedirectsError, self).__init__(message, *args, **kwargs)
@@ -414,7 +479,7 @@ class ODataV4Error(HttpResponseError):
414479

415480
_ERROR_FORMAT = ODataV4Format
416481

417-
def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
482+
def __init__(self, response: _HttpResponseCommonAPI, **kwargs: Any) -> None:
418483
# Ensure field are declared, whatever can happen afterwards
419484
self.odata_json: Optional[Dict[str, Any]] = None
420485
try:
@@ -428,7 +493,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
428493
self.message: Optional[str] = kwargs.get("message", odata_message)
429494
self.target: Optional[str] = None
430495
self.details: Optional[List[Any]] = []
431-
self.innererror: Optional[Dict[str, Any]] = {}
496+
self.innererror: Optional[Mapping[str, Any]] = {}
432497

433498
if self.message and "message" not in kwargs:
434499
kwargs["message"] = self.message
@@ -445,7 +510,7 @@ def __init__(self, response: "_HttpResponseBase", **kwargs) -> None:
445510
_LOGGER.info("Received error message was not valid OdataV4 format.")
446511
self._error_format = "JSON was invalid for format " + str(self._ERROR_FORMAT)
447512

448-
def __str__(self):
513+
def __str__(self) -> str:
449514
if self._error_format:
450515
return str(self._error_format)
451516
return super(ODataV4Error, self).__str__()
@@ -461,7 +526,7 @@ class StreamConsumedError(AzureError):
461526
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
462527
"""
463528

464-
def __init__(self, response):
529+
def __init__(self, response: _HttpResponseCommonAPI) -> None:
465530
message = (
466531
"You are attempting to read or stream the content from request {}. "
467532
"You have likely already consumed this stream, so it can not be accessed anymore.".format(response.request)
@@ -479,7 +544,7 @@ class StreamClosedError(AzureError):
479544
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
480545
"""
481546

482-
def __init__(self, response):
547+
def __init__(self, response: _HttpResponseCommonAPI) -> None:
483548
message = (
484549
"The content for response from request {} can no longer be read or streamed, since the "
485550
"response has already been closed.".format(response.request)
@@ -497,7 +562,7 @@ class ResponseNotReadError(AzureError):
497562
:type response: ~azure.core.rest.HttpResponse or ~azure.core.rest.AsyncHttpResponse
498563
"""
499564

500-
def __init__(self, response):
565+
def __init__(self, response: _HttpResponseCommonAPI) -> None:
501566
message = (
502567
"You have not read in the bytes for the response from request {}. "
503568
"Call .read() on the response first.".format(response.request)

0 commit comments

Comments
 (0)