Skip to content

Commit 1a9b633

Browse files
authored
[rest] switch base responses to ABCs (Azure#20448)
* switch to protocol * update changelog * add initial tests * switch from protocol to abc * improve HttpResponse docstrings * lint * HeadersType -> MutableMapping[str, str] * remove iter_text and iter_lines * update tests * improve docstrings * have base impls handle more code * add set_read_checks * commit to restart pipelines * address xiang's comments * lint * clear json cache when encoding is updated * make sure content type is empty string if doesn't exist * update content_type to be None if there is no content type header * fix passing encoding to text method error * update is_stream_consumed docs * remove erroneous committed code
1 parent cc7e454 commit 1a9b633

20 files changed

+919
-567
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
- The `text` property on `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` has changed to a method, which also takes
2323
an `encoding` parameter.
2424
- Removed `iter_text` and `iter_lines` from `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse`
25+
- `azure.core.rest.HttpResponse` and `azure.core.rest.AsyncHttpResponse` are now abstract base classes. They should not be initialized directly, instead
26+
your transport responses should inherit from them and implement them.
2527

2628
### Bugs Fixed
2729

sdk/core/azure-core/azure/core/_pipeline_client_async.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# --------------------------------------------------------------------------
2626

2727
import logging
28-
from collections.abc import Iterable
28+
import collections.abc
2929
from typing import Any, Awaitable
3030
from .configuration import Configuration
3131
from .pipeline import AsyncPipeline
@@ -62,6 +62,26 @@
6262

6363
_LOGGER = logging.getLogger(__name__)
6464

65+
class _AsyncContextManager(collections.abc.Awaitable):
66+
67+
def __init__(self, wrapped: collections.abc.Awaitable):
68+
super().__init__()
69+
self.wrapped = wrapped
70+
self.response = None
71+
72+
def __await__(self):
73+
return self.wrapped.__await__()
74+
75+
async def __aenter__(self):
76+
self.response = await self
77+
return self.response
78+
79+
async def __aexit__(self, *args):
80+
await self.response.__aexit__(*args)
81+
82+
async def close(self):
83+
await self.response.close()
84+
6585

6686
class AsyncPipelineClient(PipelineClientBase):
6787
"""Service client core methods.
@@ -125,7 +145,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
125145
config.proxy_policy,
126146
ContentDecodePolicy(**kwargs)
127147
]
128-
if isinstance(per_call_policies, Iterable):
148+
if isinstance(per_call_policies, collections.abc.Iterable):
129149
policies.extend(per_call_policies)
130150
else:
131151
policies.append(per_call_policies)
@@ -134,7 +154,7 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
134154
config.retry_policy,
135155
config.authentication_policy,
136156
config.custom_hook_policy])
137-
if isinstance(per_retry_policies, Iterable):
157+
if isinstance(per_retry_policies, collections.abc.Iterable):
138158
policies.extend(per_retry_policies)
139159
else:
140160
policies.append(per_retry_policies)
@@ -143,13 +163,13 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
143163
DistributedTracingPolicy(**kwargs),
144164
config.http_logging_policy or HttpLoggingPolicy(**kwargs)])
145165
else:
146-
if isinstance(per_call_policies, Iterable):
166+
if isinstance(per_call_policies, collections.abc.Iterable):
147167
per_call_policies_list = list(per_call_policies)
148168
else:
149169
per_call_policies_list = [per_call_policies]
150170
per_call_policies_list.extend(policies)
151171
policies = per_call_policies_list
152-
if isinstance(per_retry_policies, Iterable):
172+
if isinstance(per_retry_policies, collections.abc.Iterable):
153173
per_retry_policies_list = list(per_retry_policies)
154174
else:
155175
per_retry_policies_list = [per_retry_policies]
@@ -188,7 +208,7 @@ async def _make_pipeline_call(self, request, **kwargs):
188208
# the body is loaded. instead of doing response.read(), going to set the body
189209
# to the internal content
190210
rest_response._content = response.body() # pylint: disable=protected-access
191-
await rest_response.close()
211+
await rest_response._set_read_checks() # pylint: disable=protected-access
192212
except Exception as exc:
193213
await rest_response.close()
194214
raise exc
@@ -222,6 +242,5 @@ def send_request(
222242
:return: The response of your network call. Does not do error handling on your response.
223243
:rtype: ~azure.core.rest.AsyncHttpResponse
224244
"""
225-
from .rest._rest_py3 import _AsyncContextManager
226245
wrapped = self._make_pipeline_call(request, stream=stream, **kwargs)
227246
return _AsyncContextManager(wrapped=wrapped)

sdk/core/azure-core/azure/core/pipeline/_tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,20 @@ def to_rest_request(pipeline_transport_request):
4646
def to_rest_response(pipeline_transport_response):
4747
from .transport._requests_basic import RequestsTransportResponse
4848
from ..rest._requests_basic import RestRequestsTransportResponse
49-
from ..rest import HttpResponse
5049
if isinstance(pipeline_transport_response, RequestsTransportResponse):
5150
response_type = RestRequestsTransportResponse
5251
else:
53-
response_type = HttpResponse
52+
raise ValueError("Unknown transport response")
5453
response = response_type(
5554
request=to_rest_request(pipeline_transport_response.request),
5655
internal_response=pipeline_transport_response.internal_response,
56+
block_size=pipeline_transport_response.block_size
5757
)
58-
response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access
5958
return response
6059

6160
def get_block_size(response):
6261
try:
63-
return response._connection_data_block_size # pylint: disable=protected-access
62+
return response._block_size # pylint: disable=protected-access
6463
except AttributeError:
6564
return response.block_size
6665

sdk/core/azure-core/azure/core/pipeline/_tools_async.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,13 @@ def _get_response_type(pipeline_transport_response):
5555
return RestTrioRequestsTransportResponse
5656
except ImportError:
5757
pass
58-
from ..rest import AsyncHttpResponse
59-
return AsyncHttpResponse
58+
raise ValueError("Unknown transport response")
6059

6160
def to_rest_response(pipeline_transport_response):
6261
response_type = _get_response_type(pipeline_transport_response)
6362
response = response_type(
6463
request=to_rest_request(pipeline_transport_response.request),
6564
internal_response=pipeline_transport_response.internal_response,
65+
block_size=pipeline_transport_response.block_size,
6666
)
67-
response._connection_data_block_size = pipeline_transport_response.block_size # pylint: disable=protected-access
6867
return response

sdk/core/azure-core/azure/core/rest/_aiohttp.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
from itertools import groupby
2929
from typing import AsyncIterator
3030
from multidict import CIMultiDict
31-
from . import HttpRequest, AsyncHttpResponse
32-
from ._helpers_py3 import iter_raw_helper, iter_bytes_helper
31+
from ._http_response_impl_async import AsyncHttpResponseImpl
3332
from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator
3433

3534
class _ItemsView(collections.abc.ItemsView):
@@ -115,42 +114,26 @@ def get(self, key, default=None):
115114
values = ", ".join(values)
116115
return values or default
117116

118-
class RestAioHttpTransportResponse(AsyncHttpResponse):
117+
class RestAioHttpTransportResponse(AsyncHttpResponseImpl):
119118
def __init__(
120119
self,
121120
*,
122-
request: HttpRequest,
123121
internal_response,
122+
decompress: bool = True,
123+
**kwargs
124124
):
125-
super().__init__(request=request, internal_response=internal_response)
126-
self.status_code = internal_response.status
127-
self.headers = _CIMultiDict(internal_response.headers) # type: ignore
128-
self.reason = internal_response.reason
129-
self.content_type = internal_response.headers.get('content-type')
130-
131-
async def iter_raw(self) -> AsyncIterator[bytes]:
132-
"""Asynchronously iterates over the response's bytes. Will not decompress in the process
133-
134-
:return: An async iterator of bytes from the response
135-
:rtype: AsyncIterator[bytes]
136-
"""
137-
async for part in iter_raw_helper(AioHttpStreamDownloadGenerator, self):
138-
yield part
139-
await self.close()
140-
141-
async def iter_bytes(self) -> AsyncIterator[bytes]:
142-
"""Asynchronously iterates over the response's bytes. Will decompress in the process
143-
144-
:return: An async iterator of bytes from the response
145-
:rtype: AsyncIterator[bytes]
146-
"""
147-
async for part in iter_bytes_helper(
148-
AioHttpStreamDownloadGenerator,
149-
self,
150-
content=self._content
151-
):
152-
yield part
153-
await self.close()
125+
headers = _CIMultiDict(internal_response.headers)
126+
super().__init__(
127+
internal_response=internal_response,
128+
status_code=internal_response.status,
129+
headers=headers,
130+
content_type=headers.get('content-type'),
131+
reason=internal_response.reason,
132+
stream_download_generator=AioHttpStreamDownloadGenerator,
133+
content=None,
134+
**kwargs
135+
)
136+
self._decompress = decompress
154137

155138
def __getstate__(self):
156139
state = self.__dict__.copy()
@@ -165,6 +148,7 @@ async def close(self) -> None:
165148
:return: None
166149
:rtype: None
167150
"""
168-
self.is_closed = True
169-
self._internal_response.close()
170-
await asyncio.sleep(0)
151+
if not self.is_closed:
152+
self._is_closed = True
153+
self._internal_response.close()
154+
await asyncio.sleep(0)

sdk/core/azure-core/azure/core/rest/_helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@
3535
Union,
3636
Mapping,
3737
Sequence,
38-
List,
3938
Tuple,
4039
IO,
4140
Any,
4241
Dict,
4342
Iterable,
43+
MutableMapping,
4444
)
4545
import xml.etree.ElementTree as ET
4646
import six
@@ -66,8 +66,6 @@
6666

6767
ParamsType = Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]]
6868

69-
HeadersType = Mapping[str, str]
70-
7169
FileContent = Union[str, bytes, IO[str], IO[bytes]]
7270
FileType = Union[
7371
Tuple[Optional[str], FileContent],
@@ -129,8 +127,8 @@ def set_xml_body(content):
129127
return headers, body
130128

131129
def _shared_set_content_body(content):
132-
# type: (Any) -> Tuple[HeadersType, Optional[ContentTypeBase]]
133-
headers = {} # type: HeadersType
130+
# type: (Any) -> Tuple[MutableMapping[str, str], Optional[ContentTypeBase]]
131+
headers = {} # type: MutableMapping[str, str]
134132

135133
if isinstance(content, ET.Element):
136134
# XML body

sdk/core/azure-core/azure/core/rest/_helpers_py3.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,14 @@
3030
Iterable,
3131
Tuple,
3232
Union,
33-
Callable,
34-
Optional,
35-
AsyncIterator as AsyncIteratorType
33+
MutableMapping,
3634
)
37-
from ..exceptions import StreamConsumedError, StreamClosedError
3835

39-
from ._helpers import (
40-
_shared_set_content_body,
41-
HeadersType
42-
)
36+
from ._helpers import _shared_set_content_body
4337
ContentType = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
4438

4539
def set_content_body(content: ContentType) -> Tuple[
46-
HeadersType, ContentType
40+
MutableMapping[str, str], ContentType
4741
]:
4842
headers, body = _shared_set_content_body(content)
4943
if body is not None:
@@ -54,48 +48,3 @@ def set_content_body(content: ContentType) -> Tuple[
5448
"Unexpected type for 'content': '{}'. ".format(type(content)) +
5549
"We expect 'content' to either be str, bytes, or an Iterable / AsyncIterable"
5650
)
57-
58-
def _stream_download_helper(
59-
decompress: bool,
60-
stream_download_generator: Callable,
61-
response,
62-
) -> AsyncIteratorType[bytes]:
63-
if response.is_stream_consumed:
64-
raise StreamConsumedError(response)
65-
if response.is_closed:
66-
raise StreamClosedError(response)
67-
68-
response.is_stream_consumed = True
69-
return stream_download_generator(
70-
pipeline=None,
71-
response=response,
72-
decompress=decompress,
73-
)
74-
75-
async def iter_bytes_helper(
76-
stream_download_generator: Callable,
77-
response,
78-
content: Optional[bytes],
79-
) -> AsyncIteratorType[bytes]:
80-
if content:
81-
chunk_size = response._connection_data_block_size # pylint: disable=protected-access
82-
for i in range(0, len(content), chunk_size):
83-
yield content[i : i + chunk_size]
84-
else:
85-
async for part in _stream_download_helper(
86-
decompress=True,
87-
stream_download_generator=stream_download_generator,
88-
response=response,
89-
):
90-
yield part
91-
92-
async def iter_raw_helper(
93-
stream_download_generator: Callable,
94-
response,
95-
) -> AsyncIteratorType[bytes]:
96-
async for part in _stream_download_helper(
97-
decompress=False,
98-
stream_download_generator=stream_download_generator,
99-
response=response,
100-
):
101-
yield part

0 commit comments

Comments
 (0)