Skip to content

Commit 62116d4

Browse files
authored
[rest] unify response headers behaviors across transports (Azure#20234)
1 parent c06add2 commit 62116d4

File tree

10 files changed

+596
-132
lines changed

10 files changed

+596
-132
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ an `encoding` parameter.
1515

1616
### Bugs Fixed
1717

18+
- The behaviour of the headers returned in `azure.core.rest` responses now aligns across sync and async. Items can now be checked case-insensitively and without raising an error for format.
19+
1820
### Other Changes
1921

2022
## 1.17.0 (2021-08-05)

sdk/core/azure-core/azure/core/rest/_aiohttp.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,97 @@
2323
# IN THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26-
26+
import collections.abc
2727
import asyncio
28+
from itertools import groupby
2829
from typing import AsyncIterator
2930
from multidict import CIMultiDict
3031
from . import HttpRequest, AsyncHttpResponse
3132
from ._helpers_py3 import iter_raw_helper, iter_bytes_helper
3233
from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator
3334

35+
class _ItemsView(collections.abc.ItemsView):
36+
def __init__(self, ref):
37+
super().__init__(ref)
38+
self._ref = ref
39+
40+
def __iter__(self):
41+
for key, groups in groupby(self._ref.__iter__(), lambda x: x[0]):
42+
yield tuple([key, ", ".join(group[1] for group in groups)])
43+
44+
def __contains__(self, item):
45+
if not (isinstance(item, (list, tuple)) and len(item) == 2):
46+
return False
47+
for k, v in self.__iter__():
48+
if item[0].lower() == k.lower() and item[1] == v:
49+
return True
50+
return False
51+
52+
def __repr__(self):
53+
return f"dict_items({list(self.__iter__())})"
54+
55+
class _KeysView(collections.abc.KeysView):
56+
def __init__(self, items):
57+
super().__init__(items)
58+
self._items = items
59+
60+
def __iter__(self):
61+
for key, _ in self._items:
62+
yield key
63+
64+
def __contains__(self, key):
65+
for k in self.__iter__():
66+
if key.lower() == k.lower():
67+
return True
68+
return False
69+
def __repr__(self):
70+
return f"dict_keys({list(self.__iter__())})"
71+
72+
class _ValuesView(collections.abc.ValuesView):
73+
def __init__(self, items):
74+
super().__init__(items)
75+
self._items = items
76+
77+
def __iter__(self):
78+
for _, value in self._items:
79+
yield value
80+
81+
def __contains__(self, value):
82+
for v in self.__iter__():
83+
if value == v:
84+
return True
85+
return False
86+
87+
def __repr__(self):
88+
return f"dict_values({list(self.__iter__())})"
89+
90+
91+
class _CIMultiDict(CIMultiDict):
92+
"""Dictionary with the support for duplicate case-insensitive keys."""
93+
94+
def __iter__(self):
95+
return iter(self.keys())
96+
97+
def keys(self):
98+
"""Return a new view of the dictionary's keys."""
99+
return _KeysView(self.items())
100+
101+
def items(self):
102+
"""Return a new view of the dictionary's items."""
103+
return _ItemsView(super().items())
104+
105+
def values(self):
106+
"""Return a new view of the dictionary's values."""
107+
return _ValuesView(self.items())
108+
109+
def __getitem__(self, key: str) -> str:
110+
return ", ".join(self.getall(key, []))
111+
112+
def get(self, key, default=None):
113+
values = self.getall(key, None)
114+
if values:
115+
values = ", ".join(values)
116+
return values or default
34117

35118
class RestAioHttpTransportResponse(AsyncHttpResponse):
36119
def __init__(
@@ -41,7 +124,7 @@ def __init__(
41124
):
42125
super().__init__(request=request, internal_response=internal_response)
43126
self.status_code = internal_response.status
44-
self.headers = CIMultiDict(internal_response.headers) # type: ignore
127+
self.headers = _CIMultiDict(internal_response.headers) # type: ignore
45128
self.reason = internal_response.reason
46129
self.content_type = internal_response.headers.get('content-type')
47130

sdk/core/azure-core/azure/core/rest/_requests_basic.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,42 @@
2323
# IN THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26+
try:
27+
import collections.abc as collections
28+
except ImportError:
29+
import collections # type: ignore
30+
2631
from typing import TYPE_CHECKING, cast
32+
from requests.structures import CaseInsensitiveDict
2733

2834
from ..exceptions import ResponseNotReadError, StreamConsumedError, StreamClosedError
2935
from ._rest import _HttpResponseBase, HttpResponse
3036
from ..pipeline.transport._requests_basic import StreamDownloadGenerator
3137

38+
class _ItemsView(collections.ItemsView):
39+
40+
def __contains__(self, item):
41+
if not (isinstance(item, (list, tuple)) and len(item) == 2):
42+
return False # requests raises here, we just return False
43+
for k, v in self.__iter__():
44+
if item[0].lower() == k.lower() and item[1] == v:
45+
return True
46+
return False
47+
48+
def __repr__(self):
49+
return 'ItemsView({})'.format(dict(self.__iter__()))
50+
51+
class _CaseInsensitiveDict(CaseInsensitiveDict):
52+
"""Overriding default requests dict so we can unify
53+
to not raise if users pass in incorrect items to contains.
54+
Instead, we return False
55+
"""
56+
57+
def items(self):
58+
"""Return a new view of the dictionary's items."""
59+
return _ItemsView(self)
60+
61+
3262
if TYPE_CHECKING:
3363
from typing import Iterator, Optional
3464

@@ -43,7 +73,7 @@ class _RestRequestsTransportResponseBase(_HttpResponseBase):
4373
def __init__(self, **kwargs):
4474
super(_RestRequestsTransportResponseBase, self).__init__(**kwargs)
4575
self.status_code = self._internal_response.status_code
46-
self.headers = self._internal_response.headers
76+
self.headers = _CaseInsensitiveDict(self._internal_response.headers)
4777
self.reason = self._internal_response.reason
4878
self.content_type = self._internal_response.headers.get('content-type')
4979

sdk/core/azure-core/azure/core/rest/_rest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ class HttpResponse(_HttpResponseBase): # pylint: disable=too-many-instance-attr
306306
:keyword request: The request that resulted in this response.
307307
:paramtype request: ~azure.core.rest.HttpRequest
308308
:ivar int status_code: The status code of this response
309-
:ivar mapping headers: The response headers
309+
:ivar mapping headers: The case-insensitive response headers.
310+
While looking up headers is case-insensitive, when looking up
311+
keys in `header.keys()`, we recommend using lowercase.
310312
:ivar str reason: The reason phrase for this response
311313
:ivar bytes content: The response content in bytes.
312314
:ivar str url: The URL that resulted in this response

sdk/core/azure-core/azure/core/rest/_rest_py3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def __init__(
227227
self.request = request
228228
self._internal_response = kwargs.pop("internal_response")
229229
self.status_code = None
230-
self.headers = {} # type: HeadersType
230+
self.headers = _case_insensitive_dict({})
231231
self.reason = None
232232
self.is_closed = False
233233
self.is_stream_consumed = False
@@ -284,7 +284,7 @@ def json(self) -> Any:
284284
"""
285285
# this will trigger errors if response is not read in
286286
self.content # pylint: disable=pointless-statement
287-
if not self._json:
287+
if self._json is None:
288288
self._json = loads(self.text())
289289
return self._json
290290

@@ -320,7 +320,9 @@ class HttpResponse(_HttpResponseBase):
320320
:keyword request: The request that resulted in this response.
321321
:paramtype request: ~azure.core.rest.HttpRequest
322322
:ivar int status_code: The status code of this response
323-
:ivar mapping headers: The response headers
323+
:ivar mapping headers: The case-insensitive response headers.
324+
While looking up headers is case-insensitive, when looking up
325+
keys in `header.keys()`, we recommend using lowercase.
324326
:ivar str reason: The reason phrase for this response
325327
:ivar bytes content: The response content in bytes.
326328
:ivar str url: The URL that resulted in this response

0 commit comments

Comments
 (0)