|
5 | 5 | # -------------------------------------------------------------------------- |
6 | 6 | import logging |
7 | 7 | import urllib.parse as url_parse |
| 8 | +from functools import wraps |
8 | 9 |
|
9 | 10 | from azure.core.exceptions import ResourceNotFoundError |
10 | 11 | from azure.core.pipeline.policies import ContentDecodePolicy |
11 | 12 | from azure.core.pipeline.transport import AioHttpTransport |
12 | 13 |
|
| 14 | +try: |
| 15 | + import httpx |
| 16 | + |
| 17 | + AsyncHTTPXTransport = httpx.AsyncHTTPTransport |
| 18 | +except ImportError: |
| 19 | + httpx = None |
| 20 | + AsyncHTTPXTransport = None |
| 21 | + |
13 | 22 | from ..helpers import is_live_and_not_recording, trim_kwargs_from_test_function |
14 | 23 | from ..proxy_testcase import ( |
| 24 | + RecordedTransport, |
| 25 | + _transform_args, |
| 26 | + _transform_httpx_args, |
15 | 27 | get_test_id, |
16 | 28 | start_record_or_playback, |
17 | | - transform_request, |
| 29 | + restore_httpx_response_url, |
18 | 30 | stop_record_or_playback, |
19 | 31 | ) |
20 | 32 |
|
21 | 33 |
|
22 | | -def recorded_by_proxy_async(test_func): |
23 | | - """Decorator that redirects network requests to target the azure-sdk-tools test proxy. Use with recorded tests. |
24 | | -
|
25 | | - For more details and usage examples, refer to |
26 | | - https://github.com/Azure/azure-sdk-for-python/blob/main/doc/dev/tests.md#write-or-run-tests |
| 34 | +def recorded_by_proxy_async(*transports): |
27 | 35 | """ |
| 36 | + Decorator for recording and playing back test proxy sessions in async tests. |
28 | 37 |
|
29 | | - async def record_wrap(*args, **kwargs): |
30 | | - def transform_args(*args, **kwargs): |
31 | | - copied_positional_args = list(args) |
32 | | - request = copied_positional_args[1] |
| 38 | + Args: |
| 39 | + *transports: Which transport(s) to record. Pass one or more comma separated RecordedTransport enum values. |
| 40 | + - No args (default): Record AioHttpTransport.send calls (azure.core). |
| 41 | + - RecordedTransport.AZURE_CORE: Record AioHttpTransport.send calls. Same as the default above. |
| 42 | + - RecordedTransport.HTTPX: Record AsyncHTTPXTransport.handle_async_request calls. |
| 43 | + - RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX: Record both transports. |
33 | 44 |
|
34 | | - transform_request(request, recording_id) |
| 45 | + Usages: |
35 | 46 |
|
36 | | - return tuple(copied_positional_args), kwargs |
| 47 | + from devtools_testutils.aio import recorded_by_proxy_async |
| 48 | + from devtools_testutils import RecordedTransport |
37 | 49 |
|
38 | | - trimmed_kwargs = {k: v for k, v in kwargs.items()} |
39 | | - trim_kwargs_from_test_function(test_func, trimmed_kwargs) |
| 50 | + # If your test uses azure.core only network calls (default) |
| 51 | + @recorded_by_proxy_async |
| 52 | + async def test(...): ... |
40 | 53 |
|
41 | | - if is_live_and_not_recording(): |
42 | | - return await test_func(*args, **trimmed_kwargs) |
| 54 | + # Explicitly enable azure.core recordings only (equivalent to the above) |
| 55 | + @recorded_by_proxy_async(RecordedTransport.AZURE_CORE) |
| 56 | + async def test(...): ... |
43 | 57 |
|
44 | | - test_id = get_test_id() |
45 | | - recording_id, variables = start_record_or_playback(test_id) |
46 | | - original_transport_func = AioHttpTransport.send |
| 58 | + # If your test uses httpx only for network calls |
| 59 | + @recorded_by_proxy_async(RecordedTransport.HTTPX) |
| 60 | + async def test(...): ... |
47 | 61 |
|
48 | | - async def combined_call(*args, **kwargs): |
49 | | - adjusted_args, adjusted_kwargs = transform_args(*args, **kwargs) |
50 | | - result = await original_transport_func(*adjusted_args, **adjusted_kwargs) |
51 | | - |
52 | | - # make the x-recording-upstream-base-uri the URL of the request |
53 | | - # this makes the request look like it was made to the original endpoint instead of to the proxy |
54 | | - # without this, things like LROPollers can get broken by polling the wrong endpoint |
55 | | - parsed_result = url_parse.urlparse(result.request.url) |
56 | | - upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"]) |
57 | | - upstream_uri_dict = { |
58 | | - "scheme": upstream_uri.scheme, |
59 | | - "netloc": upstream_uri.netloc, |
60 | | - } |
61 | | - original_target = parsed_result._replace(**upstream_uri_dict).geturl() |
62 | | - |
63 | | - result.request.url = original_target |
64 | | - return result |
| 62 | + # If your test uses both azure.core and httpx for network calls |
| 63 | + @recorded_by_proxy_async(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX) |
| 64 | + async def test(...): ... |
| 65 | + """ |
65 | 66 |
|
66 | | - AioHttpTransport.send = combined_call |
| 67 | + # Bare decorator usage: @recorded_by_proxy_async |
| 68 | + if len(transports) == 1 and callable(transports[0]): |
| 69 | + test_func = transports[0] |
| 70 | + transport_list = [(AioHttpTransport, "send")] |
| 71 | + return _make_proxy_decorator_async(transport_list)(test_func) |
| 72 | + |
| 73 | + # Parameterized decorator usage: @recorded_by_proxy_async(...) |
| 74 | + # Determine which transports to use |
| 75 | + transport_list = [] |
| 76 | + |
| 77 | + # If no transports specified, default to azure.core |
| 78 | + transport_set = set(transports) if transports else {RecordedTransport.AZURE_CORE} |
| 79 | + |
| 80 | + # Add transports based on what's in the set |
| 81 | + for transport in transport_set: |
| 82 | + if transport == RecordedTransport.AZURE_CORE or ( |
| 83 | + isinstance(transport, str) and transport == RecordedTransport.AZURE_CORE.value |
| 84 | + ): |
| 85 | + transport_list.append((AioHttpTransport, "send")) |
| 86 | + elif transport == RecordedTransport.HTTPX or ( |
| 87 | + isinstance(transport, str) and transport == RecordedTransport.HTTPX.value |
| 88 | + ): |
| 89 | + if AsyncHTTPXTransport is not None: |
| 90 | + transport_list.append((AsyncHTTPXTransport, "handle_async_request")) |
| 91 | + |
| 92 | + # If still no transports, fall back to azure.core |
| 93 | + if not transport_list: |
| 94 | + transport_list = [(AioHttpTransport, "send")] |
| 95 | + |
| 96 | + # Return a decorator function that will be applied to the test function |
| 97 | + return lambda test_func: _make_proxy_decorator_async(transport_list)(test_func) |
| 98 | + |
| 99 | + |
| 100 | +def _make_proxy_decorator_async(transports): |
| 101 | + def _decorator(test_func): |
| 102 | + @wraps(test_func) |
| 103 | + async def record_wrap(*args, **kwargs): |
| 104 | + # ---- your existing trimming/early-exit logic ---- |
| 105 | + trimmed_kwargs = {k: v for k, v in kwargs.items()} |
| 106 | + trim_kwargs_from_test_function(test_func, trimmed_kwargs) |
| 107 | + |
| 108 | + if is_live_and_not_recording(): |
| 109 | + return await test_func(*args, **trimmed_kwargs) |
| 110 | + |
| 111 | + test_id = get_test_id() |
| 112 | + recording_id, variables = start_record_or_playback(test_id) |
| 113 | + |
| 114 | + # Build a wrapper factory so each patched method closes over its own original |
| 115 | + def make_combined_call(original_transport_func, is_httpx=False): |
| 116 | + async def combined_call(*call_args, **call_kwargs): |
| 117 | + if is_httpx: |
| 118 | + adjusted_args, adjusted_kwargs = _transform_httpx_args(recording_id, *call_args, **call_kwargs) |
| 119 | + result = await original_transport_func(*adjusted_args, **adjusted_kwargs) |
| 120 | + restore_httpx_response_url(result) |
| 121 | + else: |
| 122 | + adjusted_args, adjusted_kwargs = _transform_args(recording_id, *call_args, **call_kwargs) |
| 123 | + result = await original_transport_func(*adjusted_args, **adjusted_kwargs) |
| 124 | + # rewrite request.url to the original upstream for LROs, etc. |
| 125 | + parsed_result = url_parse.urlparse(result.request.url) |
| 126 | + upstream_uri = url_parse.urlparse(result.request.headers["x-recording-upstream-base-uri"]) |
| 127 | + upstream_uri_dict = {"scheme": upstream_uri.scheme, "netloc": upstream_uri.netloc} |
| 128 | + original_target = parsed_result._replace(**upstream_uri_dict).geturl() |
| 129 | + result.request.url = original_target |
| 130 | + return result |
| 131 | + |
| 132 | + return combined_call |
| 133 | + |
| 134 | + # Patch multiple transports and ensure restoration |
| 135 | + test_variables = None |
| 136 | + test_run = False |
| 137 | + originals = [] |
| 138 | + # monkeypatch all requested transports |
| 139 | + for owner, name in transports: |
| 140 | + original = getattr(owner, name) |
| 141 | + # Check if this is an httpx transport by comparing with httpx transport classes |
| 142 | + is_httpx_transport = (AsyncHTTPXTransport is not None and owner is AsyncHTTPXTransport) or ( |
| 143 | + httpx is not None and owner.__module__.startswith("httpx") |
| 144 | + ) |
| 145 | + setattr(owner, name, make_combined_call(original, is_httpx=is_httpx_transport)) |
| 146 | + originals.append((owner, name, original)) |
67 | 147 |
|
68 | | - # call the modified function |
69 | | - # we define test_variables before invoking the test so the variable is defined in case of an exception |
70 | | - test_variables = None |
71 | | - # this tracks whether the test has been run yet; used when calling the test function with/without `variables` |
72 | | - # running without `variables` in the `except` block leads to unnecessary exceptions in test execution output |
73 | | - test_run = False |
74 | | - try: |
75 | 148 | try: |
76 | | - test_variables = await test_func(*args, variables=variables, **trimmed_kwargs) |
77 | | - test_run = True |
78 | | - except TypeError as error: |
79 | | - if "unexpected keyword argument" in str(error) and "variables" in str(error): |
80 | | - logger = logging.getLogger() |
81 | | - logger.info( |
82 | | - "This test can't accept variables as input. The test method should accept `**kwargs` and/or a " |
83 | | - "`variables` parameter to make use of recorded test variables." |
84 | | - ) |
85 | | - else: |
86 | | - raise error |
87 | | - # if the test couldn't accept `variables`, run the test without passing them |
88 | | - if not test_run: |
89 | | - test_variables = await test_func(*args, **trimmed_kwargs) |
90 | | - |
91 | | - except ResourceNotFoundError as error: |
92 | | - error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response) |
93 | | - message = error_body.get("message") or error_body.get("Message") |
94 | | - error_with_message = ResourceNotFoundError(message=message, response=error.response) |
95 | | - raise error_with_message from error |
96 | | - |
97 | | - finally: |
98 | | - AioHttpTransport.send = original_transport_func |
99 | | - stop_record_or_playback(test_id, recording_id, test_variables) |
100 | | - |
101 | | - return test_variables |
102 | | - |
103 | | - return record_wrap |
| 149 | + try: |
| 150 | + test_variables = await test_func(*args, variables=variables, **trimmed_kwargs) |
| 151 | + test_run = True |
| 152 | + except TypeError as error: |
| 153 | + if "unexpected keyword argument" in str(error) and "variables" in str(error): |
| 154 | + logger = logging.getLogger() |
| 155 | + logger.info( |
| 156 | + "This test can't accept variables as input. " |
| 157 | + "Accept `**kwargs` and/or a `variables` parameter to use recorded variables." |
| 158 | + ) |
| 159 | + else: |
| 160 | + raise |
| 161 | + |
| 162 | + if not test_run: |
| 163 | + test_variables = await test_func(*args, **trimmed_kwargs) |
| 164 | + |
| 165 | + except ResourceNotFoundError as error: |
| 166 | + error_body = ContentDecodePolicy.deserialize_from_http_generics(error.response) |
| 167 | + troubleshoot = ( |
| 168 | + "Playback failure -- for help resolving, see https://aka.ms/azsdk/python/test-proxy/troubleshoot." |
| 169 | + ) |
| 170 | + message = error_body.get("message") or error_body.get("Message") |
| 171 | + error_with_message = ResourceNotFoundError( |
| 172 | + message=f"{troubleshoot} Error details:\n{message}", |
| 173 | + response=error.response, |
| 174 | + ) |
| 175 | + raise error_with_message from error |
| 176 | + |
| 177 | + finally: |
| 178 | + # restore in reverse order |
| 179 | + for owner, name, original in reversed(originals): |
| 180 | + setattr(owner, name, original) |
| 181 | + stop_record_or_playback(test_id, recording_id, test_variables) |
| 182 | + |
| 183 | + return test_variables |
| 184 | + |
| 185 | + return record_wrap |
| 186 | + |
| 187 | + return _decorator |
0 commit comments