Skip to content

Commit 68d502e

Browse files
authored
Core raw streaming (Azure#17920)
* add raw streaming support
1 parent db9cde5 commit 68d502e

File tree

9 files changed

+138
-39
lines changed

9 files changed

+138
-39
lines changed

sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ class HttpResponse(object):
279279
def text(self, encoding=None):
280280
"""Return the whole body as a string."""
281281

282-
def stream_download(self, chunk_size=None, callback=None):
282+
def stream_download(self, pipeline, **kwargs):
283283
"""Generator for streaming request body data.
284284
Should be implemented by sub-classes if streaming download
285285
is supported.

sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
CONTENT_CHUNK_SIZE = 10 * 1024
4747
_LOGGER = logging.getLogger(__name__)
4848

49-
5049
class AioHttpTransport(AsyncHttpTransport):
5150
"""AioHttp HTTP sender implementation.
5251
@@ -89,7 +88,8 @@ async def open(self):
8988
self.session = aiohttp.ClientSession(
9089
loop=self._loop,
9190
trust_env=self._use_env_settings,
92-
cookie_jar=jar
91+
cookie_jar=jar,
92+
auto_decompress=False,
9393
)
9494
if self.session is not None:
9595
await self.session.__aenter__()
@@ -191,22 +191,24 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
191191
raise ServiceResponseError(err, error=err) from err
192192
return response
193193

194-
195194
class AioHttpStreamDownloadGenerator(AsyncIterator):
196195
"""Streams the response body data.
197196
198197
:param pipeline: The pipeline object
199198
:param response: The client response object.
200-
:param block_size: block size of data sent over connection.
201-
:type block_size: int
199+
:keyword bool decompress: If True which is default, will attempt to decode the body based
200+
on the ‘content-encoding’ header.
202201
"""
203-
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
202+
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
204203
self.pipeline = pipeline
205204
self.request = response.request
206205
self.response = response
207206
self.block_size = response.block_size
207+
self._decompress = kwargs.pop("decompress", True)
208+
if len(kwargs) > 0:
209+
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
208210
self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
209-
self.downloaded = 0
211+
self._decompressor = None
210212

211213
def __len__(self):
212214
return self.content_length
@@ -216,6 +218,18 @@ async def __anext__(self):
216218
chunk = await self.response.internal_response.content.read(self.block_size)
217219
if not chunk:
218220
raise _ResponseStopIteration()
221+
if not self._decompress:
222+
return chunk
223+
enc = self.response.internal_response.headers.get('Content-Encoding')
224+
if not enc:
225+
return chunk
226+
enc = enc.lower()
227+
if enc in ("gzip", "deflate"):
228+
if not self._decompressor:
229+
import zlib
230+
zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
231+
self._decompressor = zlib.decompressobj(wbits=zlib_mode)
232+
chunk = self._decompressor.decompress(chunk)
219233
return chunk
220234
except _ResponseStopIteration:
221235
self.response.internal_response.close()
@@ -269,13 +283,15 @@ async def load_body(self) -> None:
269283
"""Load in memory the body, so it could be accessible from sync methods."""
270284
self._body = await self.internal_response.read()
271285

272-
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
286+
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
273287
"""Generator for streaming response body data.
274288
275289
:param pipeline: The pipeline object
276-
:type pipeline: azure.core.pipeline
290+
:type pipeline: azure.core.pipeline.Pipeline
291+
:keyword bool decompress: If True which is default, will attempt to decode the body based
292+
on the ‘content-encoding’ header.
277293
"""
278-
return AioHttpStreamDownloadGenerator(pipeline, self)
294+
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)
279295

280296
def __getstate__(self):
281297
# Be sure body is loaded in memory, otherwise not pickable and let it throw

sdk/core/azure-core/azure/core/pipeline/transport/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,8 @@ def __repr__(self):
580580

581581

582582
class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
583-
def stream_download(self, pipeline):
584-
# type: (PipelineType) -> Iterator[bytes]
583+
def stream_download(self, pipeline, **kwargs):
584+
# type: (PipelineType, **Any) -> Iterator[bytes]
585585
"""Generator for streaming request body data.
586586
587587
Should be implemented by sub-classes if streaming download

sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,16 @@ class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
124124
Allows for the asynchronous streaming of data from the response.
125125
"""
126126

127-
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
127+
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
128128
"""Generator for streaming response body data.
129129
130130
Should be implemented by sub-classes if streaming download
131131
is supported. Will return an asynchronous generator.
132132
133133
:param pipeline: The pipeline object
134-
:type pipeline: azure.core.pipeline
134+
:type pipeline: azure.core.pipeline.Pipeline
135+
:keyword bool decompress: If True which is default, will attempt to decode the body based
136+
on the ‘content-encoding’ header.
135137
"""
136138

137139
def parts(self) -> AsyncIterator:

sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
AsyncHttpResponse,
4343
_ResponseStopIteration,
4444
_iterate_response_content)
45-
from ._requests_basic import RequestsTransportResponse
45+
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
4646
from ._base_requests_async import RequestsAsyncTransportBase
4747

4848

@@ -138,17 +138,22 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):
138138
139139
:param pipeline: The pipeline object
140140
:param response: The response object.
141-
:param generator iter_content_func: Iterator for response data.
142-
:param int content_length: size of body in bytes.
141+
:keyword bool decompress: If True which is default, will attempt to decode the body based
142+
on the ‘content-encoding’ header.
143143
"""
144-
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
144+
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
145145
self.pipeline = pipeline
146146
self.request = response.request
147147
self.response = response
148148
self.block_size = response.block_size
149-
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
149+
decompress = kwargs.pop("decompress", True)
150+
if len(kwargs) > 0:
151+
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
152+
if decompress:
153+
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
154+
else:
155+
self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
150156
self.content_length = int(response.headers.get('Content-Length', 0))
151-
self.downloaded = 0
152157

153158
def __len__(self):
154159
return self.content_length
@@ -178,6 +183,6 @@ async def __anext__(self):
178183
class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
179184
"""Asynchronous streaming of data from the response.
180185
"""
181-
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore
186+
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore
182187
"""Generator for streaming request body data."""
183-
return AsyncioStreamDownloadGenerator(pipeline, self) # type: ignore
188+
return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) # type: ignore

sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from typing import Iterator, Optional, Any, Union, TypeVar
2929
import urllib3 # type: ignore
3030
from urllib3.util.retry import Retry # type: ignore
31+
from urllib3.exceptions import (
32+
DecodeError, ReadTimeoutError, ProtocolError
33+
)
3134
import requests
3235

3336
from azure.core.configuration import ConnectionConfiguration
@@ -48,6 +51,25 @@
4851

4952
_LOGGER = logging.getLogger(__name__)
5053

54+
def _read_raw_stream(response, chunk_size=1):
55+
# Special case for urllib3.
56+
if hasattr(response.raw, 'stream'):
57+
try:
58+
for chunk in response.raw.stream(chunk_size, decode_content=False):
59+
yield chunk
60+
except ProtocolError as e:
61+
raise requests.exceptions.ChunkedEncodingError(e)
62+
except DecodeError as e:
63+
raise requests.exceptions.ContentDecodingError(e)
64+
except ReadTimeoutError as e:
65+
raise requests.exceptions.ConnectionError(e)
66+
else:
67+
# Standard file-like object.
68+
while True:
69+
chunk = response.raw.read(chunk_size)
70+
if not chunk:
71+
break
72+
yield chunk
5173

5274
class _RequestsTransportResponseBase(_HttpResponseBase):
5375
"""Base class for accessing response data.
@@ -98,13 +120,21 @@ class StreamDownloadGenerator(object):
98120
99121
:param pipeline: The pipeline object
100122
:param response: The response object.
123+
:keyword bool decompress: If True which is default, will attempt to decode the body based
124+
on the ‘content-encoding’ header.
101125
"""
102-
def __init__(self, pipeline, response):
126+
def __init__(self, pipeline, response, **kwargs):
103127
self.pipeline = pipeline
104128
self.request = response.request
105129
self.response = response
106130
self.block_size = response.block_size
107-
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
131+
decompress = kwargs.pop("decompress", True)
132+
if len(kwargs) > 0:
133+
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
134+
if decompress:
135+
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
136+
else:
137+
self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
108138
self.content_length = int(response.headers.get('Content-Length', 0))
109139

110140
def __len__(self):
@@ -134,10 +164,10 @@ def __next__(self):
134164
class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase):
135165
"""Streaming of data from the response.
136166
"""
137-
def stream_download(self, pipeline):
138-
# type: (PipelineType) -> Iterator[bytes]
167+
def stream_download(self, pipeline, **kwargs):
168+
# type: (PipelineType, **Any) -> Iterator[bytes]
139169
"""Generator for streaming request body data."""
140-
return StreamDownloadGenerator(pipeline, self)
170+
return StreamDownloadGenerator(pipeline, self, **kwargs)
141171

142172

143173
class RequestsTransport(HttpTransport):

sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
AsyncHttpResponse,
4343
_ResponseStopIteration,
4444
_iterate_response_content)
45-
from ._requests_basic import RequestsTransportResponse
45+
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
4646
from ._base_requests_async import RequestsAsyncTransportBase
4747

4848

@@ -54,15 +54,22 @@ class TrioStreamDownloadGenerator(AsyncIterator):
5454
5555
:param pipeline: The pipeline object
5656
:param response: The response object.
57+
:keyword bool decompress: If True which is default, will attempt to decode the body based
58+
on the ‘content-encoding’ header.
5759
"""
58-
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
60+
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
5961
self.pipeline = pipeline
6062
self.request = response.request
6163
self.response = response
6264
self.block_size = response.block_size
63-
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
65+
decompress = kwargs.pop("decompress", True)
66+
if len(kwargs) > 0:
67+
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
68+
if decompress:
69+
self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
70+
else:
71+
self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
6472
self.content_length = int(response.headers.get('Content-Length', 0))
65-
self.downloaded = 0
6673

6774
def __len__(self):
6875
return self.content_length
@@ -95,10 +102,10 @@ async def __anext__(self):
95102
class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
96103
"""Asynchronous streaming of data from the response.
97104
"""
98-
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore
105+
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore
99106
"""Generator for streaming response data.
100107
"""
101-
return TrioStreamDownloadGenerator(pipeline, self)
108+
return TrioStreamDownloadGenerator(pipeline, self, **kwargs)
102109

103110

104111
class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore

sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717

1818
@pytest.mark.asyncio
1919
async def test_connection_error_response():
20+
class MockSession(object):
21+
def __init__(self):
22+
self.auto_decompress = True
23+
24+
@property
25+
def auto_decompress(self):
26+
return self.auto_decompress
27+
2028
class MockTransport(AsyncHttpTransport):
2129
def __init__(self):
2230
self._count = 0
31+
self.session = MockSession
2332

2433
async def __aexit__(self, exc_type, exc_val, exc_tb):
2534
pass
@@ -60,7 +69,7 @@ async def __call__(self, *args, **kwargs):
6069
pipeline = AsyncPipeline(MockTransport())
6170
http_response = AsyncHttpResponse(http_request, None)
6271
http_response.internal_response = MockInternalResponse()
63-
stream = AioHttpStreamDownloadGenerator(pipeline, http_response)
72+
stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False)
6473
with mock.patch('asyncio.sleep', new_callable=AsyncMock):
6574
with pytest.raises(ConnectionError):
6675
await stream.__anext__()
@@ -75,6 +84,8 @@ async def test_response_streaming_error_behavior():
7584

7685
class FakeStreamWithConnectionError:
7786
# fake object for urllib3.response.HTTPResponse
87+
def __init__(self):
88+
self.total_response_size = 500
7889

7990
def stream(self, chunk_size, decode_content=False):
8091
assert chunk_size == block_size
@@ -86,6 +97,15 @@ def stream(self, chunk_size, decode_content=False):
8697
left -= len(data)
8798
yield data
8899

100+
def read(self, chunk_size, decode_content=False):
101+
assert chunk_size == block_size
102+
if self.total_response_size > 0:
103+
if self.total_response_size <= block_size:
104+
raise requests.exceptions.ConnectionError()
105+
data = b"X" * min(chunk_size, self.total_response_size)
106+
self.total_response_size -= len(data)
107+
return data
108+
89109
def close(self):
90110
pass
91111

0 commit comments

Comments
 (0)