Skip to content

Commit 5a6d5a8

Browse files
authored
Improve recording decorators (#44180)
1 parent 06a83c8 commit 5a6d5a8

File tree

10 files changed

+401
-185
lines changed

10 files changed

+401
-185
lines changed

eng/tools/azure-sdk-tools/devtools_testutils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .exceptions import AzureTestError, ReservedResourceNameError
2121
from .proxy_fixtures import environment_variables, recorded_test, variable_recorder
2222
from .proxy_startup import start_test_proxy, stop_test_proxy, test_proxy
23-
from .proxy_testcase import recorded_by_proxy
23+
from .proxy_testcase import recorded_by_proxy, RecordedTransport
2424
from .sanitizers import (
2525
add_api_version_transform,
2626
add_batch_sanitizers,
@@ -105,6 +105,7 @@
105105
"EnvironmentVariableLoader",
106106
"environment_variables",
107107
"recorded_by_proxy",
108+
"RecordedTransport",
108109
"recorded_test",
109110
"test_proxy",
110111
"trim_kwargs_from_test_function",

eng/tools/azure-sdk-tools/devtools_testutils/aio/proxy_testcase_async.py

Lines changed: 156 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,99 +5,183 @@
55
# --------------------------------------------------------------------------
66
import logging
77
import urllib.parse as url_parse
8+
from functools import wraps
89

910
from azure.core.exceptions import ResourceNotFoundError
1011
from azure.core.pipeline.policies import ContentDecodePolicy
1112
from azure.core.pipeline.transport import AioHttpTransport
1213

14+
try:
15+
import httpx
16+
17+
AsyncHTTPXTransport = httpx.AsyncHTTPTransport
18+
except ImportError:
19+
httpx = None
20+
AsyncHTTPXTransport = None
21+
1322
from ..helpers import is_live_and_not_recording, trim_kwargs_from_test_function
1423
from ..proxy_testcase import (
24+
RecordedTransport,
25+
_transform_args,
26+
_transform_httpx_args,
1527
get_test_id,
1628
start_record_or_playback,
17-
transform_request,
29+
restore_httpx_response_url,
1830
stop_record_or_playback,
1931
)
2032

2133

22-
def recorded_by_proxy_async(test_func):
23-
"""Decorator that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests.
24-
25-
For more details and usage examples, refer to
26-
https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/tests.md#write-or-run-tests
34+
def recorded_by_proxy_async(*transports):
2735
"""
36+
Decorator for recording and playing back test proxy sessions in async tests.
2837
29-
async def record_wrap(*args, **kwargs):
30-
def transform_args(*args, **kwargs):
31-
copied_positional_args = list(args)
32-
request = copied_positional_args[1]
38+
Args:
39+
*transports: Which transport(s) to record. Pass one or more comma separated RecordedTransport enum values.
40+
- No args (default): Record AioHttpTransport.send calls (azure.core).
41+
- RecordedTransport.AZURE_CORE: Record AioHttpTransport.send calls. Same as the default above.
42+
- RecordedTransport.HTTPX: Record AsyncHTTPXTransport.handle_async_request calls.
43+
- RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX: Record both transports.
3344
34-
transform_request(request, recording_id)
45+
Usages:
3546
36-
return tuple(copied_positional_args), kwargs
47+
from devtools_testutils.aio import recorded_by_proxy_async
48+
from devtools_testutils import RecordedTransport
3749
38-
trimmed_kwargs = {k: v for k, v in kwargs.items()}
39-
trim_kwargs_from_test_function(test_func, trimmed_kwargs)
50+
# If your test uses azure.core only network calls (default)
51+
@recorded_by_proxy_async
52+
async def test(...): ...
4053
41-
if is_live_and_not_recording():
42-
return await test_func(*args, **trimmed_kwargs)
54+
# Explicitly enable azure.core recordings only (equivalent to the above)
55+
@recorded_by_proxy_async(RecordedTransport.AZURE_CORE)
56+
async def test(...): ...
4357
44-
test_id = get_test_id()
45-
recording_id, variables = start_record_or_playback(test_id)
46-
original_transport_func = AioHttpTransport.send
58+
# If your test uses httpx only for network calls
59+
@recorded_by_proxy_async(RecordedTransport.HTTPX)
60+
async def test(...): ...
4761
48-
async def combined_call(*args, **kwargs):
49-
adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs)
50-
result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
51-
52-
# make the x-recording-upstream-base-uri the URL of the request
53-
# this makes the request look like it was made to the original endpoint instead of to the proxy
54-
# without this, things like LROPollers can get broken by polling the wrong endpoint
55-
parsed_result = url_parse.urlparse(result.request.url)
56-
upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
57-
upstream_uri_dict = {
58-
"scheme": upstream_uri.scheme,
59-
"netloc": upstream_uri.netloc,
60-
}
61-
original_target = parsed_result._replace(**upstream_uri_dict).geturl()
62-
63-
result.request.url = original_target
64-
return result
62+
# If your test uses both azure.core and httpx for network calls
63+
@recorded_by_proxy_async(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX)
64+
async def test(...): ...
65+
"""
6566

66-
AioHttpTransport.send = combined_call
67+
# Bare decorator usage: @recorded_by_proxy_async
68+
if len(transports) == 1 and callable(transports[0]):
69+
test_func = transports[0]
70+
transport_list = [(AioHttpTransport, "send")]
71+
return _make_proxy_decorator_async(transport_list)(test_func)
72+
73+
# Parameterized decorator usage: @recorded_by_proxy_async(...)
74+
# Determine which transports to use
75+
transport_list = []
76+
77+
# If no transports specified, default to azure.core
78+
transport_set = set(transports) if transports else {RecordedTransport.AZURE_CORE}
79+
80+
# Add transports based on what's in the set
81+
for transport in transport_set:
82+
if transport == RecordedTransport.AZURE_CORE or (
83+
isinstance(transport, str) and transport == RecordedTransport.AZURE_CORE.value
84+
):
85+
transport_list.append((AioHttpTransport, "send"))
86+
elif transport == RecordedTransport.HTTPX or (
87+
isinstance(transport, str) and transport == RecordedTransport.HTTPX.value
88+
):
89+
if AsyncHTTPXTransport is not None:
90+
transport_list.append((AsyncHTTPXTransport, "handle_async_request"))
91+
92+
# If still no transports, fall back to azure.core
93+
if not transport_list:
94+
transport_list = [(AioHttpTransport, "send")]
95+
96+
# Return a decorator function that will be applied to the test function
97+
return lambda test_func: _make_proxy_decorator_async(transport_list)(test_func)
98+
99+
100+
def _make_proxy_decorator_async(transports):
101+
def _decorator(test_func):
102+
@wraps(test_func)
103+
async def record_wrap(*args, **kwargs):
104+
# ---- your existing trimming/early-exit logic ----
105+
trimmed_kwargs = {k: v for k, v in kwargs.items()}
106+
trim_kwargs_from_test_function(test_func, trimmed_kwargs)
107+
108+
if is_live_and_not_recording():
109+
return await test_func(*args, **trimmed_kwargs)
110+
111+
test_id = get_test_id()
112+
recording_id, variables = start_record_or_playback(test_id)
113+
114+
# Build a wrapper factory so each patched method closes over its own original
115+
def make_combined_call(original_transport_func, is_httpx=False):
116+
async def combined_call(*call_args, **call_kwargs):
117+
if is_httpx:
118+
adjusted_args, adjusted_kwargs = _transform_httpx_args(recording_id, *call_args, **call_kwargs)
119+
result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
120+
restore_httpx_response_url(result)
121+
else:
122+
adjusted_args, adjusted_kwargs = _transform_args(recording_id, *call_args, **call_kwargs)
123+
result = await original_transport_func(*adjusted_args, **adjusted_kwargs)
124+
# rewrite request.url to the original upstream for LROs, etc.
125+
parsed_result = url_parse.urlparse(result.request.url)
126+
upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"])
127+
upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc}
128+
original_target = parsed_result._replace(**upstream_uri_dict).geturl()
129+
result.request.url = original_target
130+
return result
131+
132+
return combined_call
133+
134+
# Patch multiple transports and ensure restoration
135+
test_variables = None
136+
test_run = False
137+
originals = []
138+
# monkeypatch all requested transports
139+
for owner, name in transports:
140+
original = getattr(owner, name)
141+
# Check if this is an httpx transport by comparing with httpx transport classes
142+
is_httpx_transport = (AsyncHTTPXTransport is not None and owner is AsyncHTTPXTransport) or (
143+
httpx is not None and owner.__module__.startswith("httpx")
144+
)
145+
setattr(owner, name, make_combined_call(original, is_httpx=is_httpx_transport))
146+
originals.append((owner, name, original))
67147

68-
# call the modified function
69-
# we define test_variables before invoking the test so the variable is defined in case of an exception
70-
test_variables = None
71-
# this tracks whether the test has been run yet; used when calling the test function with/without `variables`
72-
# running without `variables` in the `except` block leads to unnecessary exceptions in test execution output
73-
test_run = False
74-
try:
75148
try:
76-
test_variables = await test_func(*args, variables=variables, **trimmed_kwargs)
77-
test_run = True
78-
except TypeError as error:
79-
if "unexpected keyword argument" in str(error) and "variables" in str(error):
80-
logger = logging.getLogger()
81-
logger.info(
82-
"This test can't accept variables as input. The test method should accept `**kwargs` and/or a "
83-
"`variables` parameter to make use of recorded test variables."
84-
)
85-
else:
86-
raise error
87-
# if the test couldn't accept `variables`, run the test without passing them
88-
if not test_run:
89-
test_variables = await test_func(*args, **trimmed_kwargs)
90-
91-
except ResourceNotFoundError as error:
92-
error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
93-
message = error_body.get("message") or error_body.get("Message")
94-
error_with_message = ResourceNotFoundError(message=message, response=error.response)
95-
raise error_with_message from error
96-
97-
finally:
98-
AioHttpTransport.send = original_transport_func
99-
stop_record_or_playback(test_id, recording_id, test_variables)
100-
101-
return test_variables
102-
103-
return record_wrap
149+
try:
150+
test_variables = await test_func(*args, variables=variables, **trimmed_kwargs)
151+
test_run = True
152+
except TypeError as error:
153+
if "unexpected keyword argument" in str(error) and "variables" in str(error):
154+
logger = logging.getLogger()
155+
logger.info(
156+
"This test can't accept variables as input. "
157+
"Accept `**kwargs` and/or a `variables` parameter to use recorded variables."
158+
)
159+
else:
160+
raise
161+
162+
if not test_run:
163+
test_variables = await test_func(*args, **trimmed_kwargs)
164+
165+
except ResourceNotFoundError as error:
166+
error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response)
167+
troubleshoot = (
168+
"Playback failure -- for help resolving, see https://aka.ms/azsdk/python/test-proxy/troubleshoot."
169+
)
170+
message = error_body.get("message") or error_body.get("Message")
171+
error_with_message = ResourceNotFoundError(
172+
message=f"{troubleshoot} Error details:\n{message}",
173+
response=error.response,
174+
)
175+
raise error_with_message from error
176+
177+
finally:
178+
# restore in reverse order
179+
for owner, name, original in reversed(originals):
180+
setattr(owner, name, original)
181+
stop_record_or_playback(test_id, recording_id, test_variables)
182+
183+
return test_variables
184+
185+
return record_wrap
186+
187+
return _decorator

0 commit comments

Comments
 (0)