Skip to content

Commit 30b196e

Browse files
authored
[rest] add backcompat mixin to rest requests (Azure#20599)
1 parent 4b3397d commit 30b196e

40 files changed

+1873
-854
lines changed

sdk/core/azure-core/azure/core/_pipeline_client.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,6 @@
6565

6666
_LOGGER = logging.getLogger(__name__)
6767

68-
def _prepare_request(request):
69-
# returns the request ready to run through pipelines
70-
# and a bool telling whether we ended up converting it
71-
rest_request = False
72-
try:
73-
request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access
74-
rest_request = True
75-
except AttributeError:
76-
request_to_run = request
77-
return rest_request, request_to_run
78-
7968
class PipelineClient(PipelineClientBase):
8069
"""Service client core methods.
8170
@@ -204,9 +193,9 @@ def send_request(self, request, **kwargs):
204193
:return: The response of your network call. Does not do error handling on your response.
205194
:rtype: ~azure.core.rest.HttpResponse
206195
# """
207-
rest_request, request_to_run = _prepare_request(request)
196+
rest_request = hasattr(request, "content")
208197
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
209-
pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access
198+
pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access
210199
response = pipeline_response.http_response
211200
if rest_request:
212201
response = _to_rest_response(response)

sdk/core/azure-core/azure/core/_pipeline_client_async.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
RequestIdPolicy,
3838
AsyncRetryPolicy,
3939
)
40-
from ._pipeline_client import _prepare_request
4140
from .pipeline._tools_async import to_rest_response as _to_rest_response
4241

4342
try:
@@ -175,10 +174,10 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
175174
return AsyncPipeline(transport, policies)
176175

177176
async def _make_pipeline_call(self, request, **kwargs):
178-
rest_request, request_to_run = _prepare_request(request)
177+
rest_request = hasattr(request, "content")
179178
return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
180179
pipeline_response = await self._pipeline.run(
181-
request_to_run, **kwargs # pylint: disable=protected-access
180+
request, **kwargs # pylint: disable=protected-access
182181
)
183182
response = pipeline_response.http_response
184183
if rest_request:

sdk/core/azure-core/azure/core/pipeline/transport/_base.py

Lines changed: 10 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from io import BytesIO
3535
import json
3636
import logging
37-
import os
3837
import time
3938
import copy
4039

@@ -50,7 +49,6 @@
5049
TYPE_CHECKING,
5150
Generic,
5251
TypeVar,
53-
cast,
5452
IO,
5553
List,
5654
Union,
@@ -63,7 +61,7 @@
6361
Type
6462
)
6563

66-
from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse
64+
from six.moves.http_client import HTTPResponse as _HTTPResponse
6765

6866
from azure.core.exceptions import HttpResponseError
6967
from azure.core.pipeline import (
@@ -75,6 +73,12 @@
7573
)
7674
from .._tools import await_result as _await_result
7775
from ...utils._utils import _case_insensitive_dict
76+
from ...utils._pipeline_transport_rest_shared import (
77+
_format_parameters_helper,
78+
_prepare_multipart_body_helper,
79+
_serialize_request,
80+
_format_data_helper,
81+
)
7882

7983

8084
if TYPE_CHECKING:
@@ -127,36 +131,6 @@ def _urljoin(base_url, stub_url):
127131
parsed = parsed._replace(path=parsed.path.rstrip("/") + "/" + stub_url)
128132
return parsed.geturl()
129133

130-
131-
class _HTTPSerializer(HTTPConnection, object):
132-
"""Hacking the stdlib HTTPConnection to serialize HTTP request as strings.
133-
"""
134-
135-
def __init__(self, *args, **kwargs):
136-
self.buffer = b""
137-
kwargs.setdefault("host", "fakehost")
138-
super(_HTTPSerializer, self).__init__(*args, **kwargs)
139-
140-
def putheader(self, header, *values):
141-
if header in ["Host", "Accept-Encoding"]:
142-
return
143-
super(_HTTPSerializer, self).putheader(header, *values)
144-
145-
def send(self, data):
146-
self.buffer += data
147-
148-
149-
def _serialize_request(http_request):
150-
serializer = _HTTPSerializer()
151-
serializer.request(
152-
method=http_request.method,
153-
url=http_request.url,
154-
body=http_request.body,
155-
headers=http_request.headers,
156-
)
157-
return serializer.buffer
158-
159-
160134
class HttpTransport(
161135
AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]
162136
): # type: ignore
@@ -253,16 +227,7 @@ def _format_data(data):
253227
:param data: The request field data.
254228
:type data: str or file-like object.
255229
"""
256-
if hasattr(data, "read"):
257-
data = cast(IO, data)
258-
data_name = None
259-
try:
260-
if data.name[0] != "<" and data.name[-1] != ">":
261-
data_name = os.path.basename(data.name)
262-
except (AttributeError, TypeError):
263-
pass
264-
return (data_name, data, "application/octet-stream")
265-
return (None, cast(str, data))
230+
return _format_data_helper(data)
266231

267232
def format_parameters(self, params):
268233
# type: (Dict[str, str]) -> None
@@ -272,26 +237,7 @@ def format_parameters(self, params):
272237
273238
:param dict params: A dictionary of parameters.
274239
"""
275-
query = urlparse(self.url).query
276-
if query:
277-
self.url = self.url.partition("?")[0]
278-
existing_params = {
279-
p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]
280-
}
281-
params.update(existing_params)
282-
query_params = []
283-
for k, v in params.items():
284-
if isinstance(v, list):
285-
for w in v:
286-
if w is None:
287-
raise ValueError("Query parameter {} cannot be None".format(k))
288-
query_params.append("{}={}".format(k, w))
289-
else:
290-
if v is None:
291-
raise ValueError("Query parameter {} cannot be None".format(k))
292-
query_params.append("{}={}".format(k, v))
293-
query = "?" + "&".join(query_params)
294-
self.url = self.url + query
240+
return _format_parameters_helper(self, params)
295241

296242
def set_streamed_data_body(self, data):
297243
"""Set a streamable data body.
@@ -416,54 +362,7 @@ def prepare_multipart_body(self, content_index=0):
416362
:returns: The updated index after all parts in this request have been added.
417363
:rtype: int
418364
"""
419-
if not self.multipart_mixed_info:
420-
return 0
421-
422-
requests = self.multipart_mixed_info[0] # type: List[HttpRequest]
423-
boundary = self.multipart_mixed_info[2] # type: Optional[str]
424-
425-
# Update the main request with the body
426-
main_message = Message()
427-
main_message.add_header("Content-Type", "multipart/mixed")
428-
if boundary:
429-
main_message.set_boundary(boundary)
430-
431-
for req in requests:
432-
part_message = Message()
433-
if req.multipart_mixed_info:
434-
content_index = req.prepare_multipart_body(content_index=content_index)
435-
part_message.add_header("Content-Type", req.headers['Content-Type'])
436-
payload = req.serialize()
437-
# We need to remove the ~HTTP/1.1 prefix along with the added content-length
438-
payload = payload[payload.index(b'--'):]
439-
else:
440-
part_message.add_header("Content-Type", "application/http")
441-
part_message.add_header("Content-Transfer-Encoding", "binary")
442-
part_message.add_header("Content-ID", str(content_index))
443-
payload = req.serialize()
444-
content_index += 1
445-
part_message.set_payload(payload)
446-
main_message.attach(part_message)
447-
448-
try:
449-
from email.policy import HTTP
450-
451-
full_message = main_message.as_bytes(policy=HTTP)
452-
eol = b"\r\n"
453-
except ImportError: # Python 2.7
454-
# Right now we decide to not support Python 2.7 on serialization, since
455-
# it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it)
456-
raise NotImplementedError(
457-
"Multipart request are not supported on Python 2.7"
458-
)
459-
# full_message = main_message.as_string()
460-
# eol = b'\n'
461-
_, _, body = full_message.split(eol, 2)
462-
self.set_bytes_body(body)
463-
self.headers["Content-Type"] = (
464-
"multipart/mixed; boundary=" + main_message.get_boundary()
465-
)
466-
return content_index
365+
return _prepare_multipart_body_helper(self, content_index)
467366

468367
def serialize(self):
469368
# type: () -> bytes

0 commit comments

Comments
 (0)