diff --git a/httpx/_client.py b/httpx/_client.py index 13cd933673..aabb03897a 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -1726,8 +1726,13 @@ async def _send_single_request(self, request: Request) -> Response: "Attempted to send a sync request with an AsyncClient instance." ) - with request_context(request=request): - response = await transport.handle_async_request(request) + try: + with request_context(request=request): + response = await transport.handle_async_request(request) + except BaseException: + if hasattr(request.stream, "aclose"): + await request.stream.aclose() + raise assert isinstance(response.stream, AsyncByteStream) response.request = request diff --git a/httpx/_content.py b/httpx/_content.py index 6f479a0885..1112870bd1 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -28,15 +28,39 @@ __all__ = ["ByteStream"] +class _ByteStreamIterator: + def __init__(self, data: bytes) -> None: + self._data = data + self._consumed = False + + def __aiter__(self) -> AsyncIterator[bytes]: + return self + + async def __anext__(self) -> bytes: + if self._consumed: + raise StopAsyncIteration + self._consumed = True + return self._data + + async def aclose(self) -> None: + self._consumed = True + + class ByteStream(AsyncByteStream, SyncByteStream): def __init__(self, stream: bytes) -> None: self._stream = stream + self._iterator: _ByteStreamIterator | None = None def __iter__(self) -> Iterator[bytes]: yield self._stream - async def __aiter__(self) -> AsyncIterator[bytes]: - yield self._stream + def __aiter__(self) -> AsyncIterator[bytes]: + self._iterator = _ByteStreamIterator(self._stream) + return self._iterator + + async def aclose(self) -> None: + if self._iterator is not None: + await self._iterator.aclose() class IteratorByteStream(SyncByteStream):