Skip to content

Commit f1d0d20

Browse files
authored
Centralize credential pipeline construction (Azure#19864)
1 parent cc7b25e commit f1d0d20

File tree

13 files changed

+259
-313
lines changed

13 files changed

+259
-313
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
([#19943](https://github.com/Azure/azure-sdk-for-python/issues/19943))
1717

1818
### Other Changes
19+
- Added `CustomHookPolicy` to credential HTTP pipelines. This allows applications
20+
to initialize credentials with `raw_request_hook` and `raw_response_hook`
21+
keyword arguments. The value of these arguments should be a callback taking a
22+
`PipelineRequest` and `PipelineResponse`, respectively. For example:
23+
`ManagedIdentityCredential(raw_request_hook=lambda request: print(request.http_request.url))`
1924
- Reduced redundant `ChainedTokenCredential` and `DefaultAzureCredential`
2025
logging. On Python 3.7+, credentials invoked by these classes now log debug
2126
rather than info messages.
@@ -25,6 +30,7 @@
2530
fails
2631
([#19989](https://github.com/Azure/azure-sdk-for-python/issues/19989))
2732

33+
2834
## 1.7.0b2 (2021-07-08)
2935
### Features Added
3036
- `InteractiveBrowserCredential` keyword argument `login_hint` enables

sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,16 @@
88

99
from azure.core.exceptions import ClientAuthenticationError
1010
from azure.core.pipeline.transport import HttpRequest
11-
from azure.core.pipeline.policies import (
12-
DistributedTracingPolicy,
13-
HttpLoggingPolicy,
14-
HTTPPolicy,
15-
UserAgentPolicy,
16-
NetworkTraceLoggingPolicy,
17-
)
11+
from azure.core.pipeline.policies import HTTPPolicy
1812

1913
from .. import CredentialUnavailableError
2014
from .._constants import EnvironmentVariables
21-
from .._internal.managed_identity_client import ManagedIdentityClient, _get_configuration
15+
from .._internal.managed_identity_client import ManagedIdentityClient
2216
from .._internal.get_token_mixin import GetTokenMixin
23-
from .._internal.user_agent import USER_AGENT
2417

2518
if TYPE_CHECKING:
2619
# pylint:disable=unused-import,ungrouped-imports
27-
from typing import Any, List, Optional, Union
28-
from azure.core.configuration import Configuration
20+
from typing import Any, Optional, Union
2921
from azure.core.credentials import AccessToken
3022
from azure.core.pipeline import PipelineRequest, PipelineResponse
3123
from azure.core.pipeline.policies import SansIOHTTPPolicy
@@ -42,10 +34,8 @@ def __init__(self, **kwargs):
4234
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
4335
self._available = url and imds
4436
if self._available:
45-
config = _get_configuration()
46-
4737
self._client = ManagedIdentityClient(
48-
policies=_get_policies(config),
38+
_per_retry_policies=[ArcChallengeAuthPolicy()],
4939
request_factory=functools.partial(_get_request, url),
5040
**kwargs
5141
)
@@ -67,19 +57,6 @@ def _request_token(self, *scopes, **kwargs):
6757
return self._client.request_token(*scopes, **kwargs)
6858

6959

70-
def _get_policies(config, **kwargs):
71-
# type: (Configuration, **Any) -> List[PolicyType]
72-
return [
73-
UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs),
74-
config.proxy_policy,
75-
config.retry_policy,
76-
ArcChallengeAuthPolicy(),
77-
NetworkTraceLoggingPolicy(**kwargs),
78-
DistributedTracingPolicy(**kwargs),
79-
HttpLoggingPolicy(**kwargs),
80-
]
81-
82-
8360
def _get_request(url, scope, identity_config):
8461
# type: (str, str, dict) -> HttpRequest
8562
if identity_config:

sdk/identity/azure-identity/azure/identity/_internal/aad_client.py

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,16 @@
55
import time
66
from typing import TYPE_CHECKING
77

8-
from azure.core.configuration import Configuration
9-
from azure.core.pipeline import Pipeline
10-
from azure.core.pipeline.policies import (
11-
NetworkTraceLoggingPolicy,
12-
RetryPolicy,
13-
ProxyPolicy,
14-
UserAgentPolicy,
15-
DistributedTracingPolicy,
16-
HttpLoggingPolicy,
17-
)
18-
198
from .aad_client_base import AadClientBase
20-
from .user_agent import USER_AGENT
9+
from .._internal.pipeline import build_pipeline
2110

2211
if TYPE_CHECKING:
2312
# pylint:disable=unused-import,ungrouped-imports
24-
from typing import Any, Iterable, List, Optional, Union
13+
from typing import Any, Iterable, Optional
2514
from azure.core.credentials import AccessToken
26-
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
27-
from azure.core.pipeline.transport import HttpTransport
15+
from azure.core.pipeline import Pipeline
2816
from .._internal import AadClientCertificate
2917

30-
Policy = Union[HTTPPolicy, SansIOHTTPPolicy]
31-
3218

3319
class AadClient(AadClientBase):
3420
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
@@ -62,30 +48,6 @@ def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
6248
return self._process_response(response, now)
6349

6450
# pylint:disable=no-self-use
65-
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
66-
# type: (Optional[Configuration], Optional[List[Policy]], Optional[HttpTransport], **Any) -> Pipeline
67-
config = config or _create_config(**kwargs)
68-
policies = policies or [
69-
config.user_agent_policy,
70-
config.proxy_policy,
71-
config.retry_policy,
72-
config.logging_policy,
73-
DistributedTracingPolicy(**kwargs),
74-
HttpLoggingPolicy(**kwargs),
75-
]
76-
if not transport:
77-
from azure.core.pipeline.transport import RequestsTransport
78-
79-
transport = RequestsTransport(**kwargs)
80-
81-
return Pipeline(transport=transport, policies=policies)
82-
83-
84-
def _create_config(**kwargs):
85-
# type: (**Any) -> Configuration
86-
config = Configuration(**kwargs)
87-
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
88-
config.retry_policy = RetryPolicy(**kwargs)
89-
config.proxy_policy = ProxyPolicy(**kwargs)
90-
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
91-
return config
51+
def _build_pipeline(self, **kwargs):
52+
# type: (**Any) -> Pipeline
53+
return build_pipeline(**kwargs)

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,14 @@ def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
9393
pass
9494

9595
@abc.abstractmethod
96-
def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):
96+
def _build_pipeline(self, **kwargs):
9797
pass
9898

9999
def _process_response(self, response, request_time):
100100
# type: (PipelineResponse, int) -> AccessToken
101-
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
101+
content = response.context.get(
102+
ContentDecodePolicy.CONTEXT_NAME
103+
) or ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
102104

103105
if response.http_request.body.get("grant_type") == "refresh_token":
104106
if content.get("error") == "invalid_grant":

sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py

Lines changed: 24 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,11 @@
99
from msal import TokenCache
1010
import six
1111

12-
from azure.core.configuration import Configuration
1312
from azure.core.credentials import AccessToken
1413
from azure.core.exceptions import ClientAuthenticationError, DecodeError
15-
from azure.core.pipeline import Pipeline
16-
from azure.core.pipeline.policies import (
17-
ContentDecodePolicy,
18-
DistributedTracingPolicy,
19-
HeadersPolicy,
20-
HttpLoggingPolicy,
21-
UserAgentPolicy,
22-
RetryPolicy,
23-
NetworkTraceLoggingPolicy,
24-
)
25-
from azure.identity._internal import _scopes_to_resource
26-
27-
from .user_agent import USER_AGENT
14+
from azure.core.pipeline.policies import ContentDecodePolicy
15+
from .._internal import _scopes_to_resource
16+
from .._internal.pipeline import build_pipeline
2817

2918
try:
3019
ABC = abc.ABC
@@ -33,10 +22,10 @@
3322

3423
if TYPE_CHECKING:
3524
# pylint:disable=ungrouped-imports
36-
from typing import Any, Callable, Dict, List, Optional, Union
25+
from typing import Any, Callable, Dict, Optional, Union
3726
from azure.core.pipeline import PipelineResponse
3827
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
39-
from azure.core.pipeline.transport import HttpTransport, HttpRequest
28+
from azure.core.pipeline.transport import HttpRequest
4029

4130
PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]
4231

@@ -50,27 +39,27 @@ def __init__(self, request_factory, client_id=None, identity_config=None, **kwar
5039
self._identity_config = identity_config or {}
5140
if client_id:
5241
self._identity_config["client_id"] = client_id
53-
54-
config = kwargs.pop("_config", None) or _get_configuration(**kwargs)
55-
self._pipeline = self._build_pipeline(config, **kwargs)
56-
42+
self._pipeline = self._build_pipeline(**kwargs)
5743
self._request_factory = request_factory
5844

5945
def _process_response(self, response, request_time):
6046
# type: (PipelineResponse, int) -> AccessToken
6147

62-
try:
63-
content = ContentDecodePolicy.deserialize_from_text(
64-
response.http_response.text(), mime_type="application/json"
65-
)
66-
if not content:
67-
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
68-
except DecodeError as ex:
69-
if response.http_response.content_type.startswith("application/json"):
70-
message = "Failed to deserialize JSON from response"
71-
else:
72-
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
73-
six.raise_from(ClientAuthenticationError(message=message, response=response.http_response), ex)
48+
content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
49+
if not content:
50+
try:
51+
content = ContentDecodePolicy.deserialize_from_text(
52+
response.http_response.text(), mime_type="application/json"
53+
)
54+
except DecodeError as ex:
55+
if response.http_response.content_type.startswith("application/json"):
56+
message = "Failed to deserialize JSON from response"
57+
else:
58+
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
59+
six.raise_from(ClientAuthenticationError(message=message, response=response.http_response), ex)
60+
61+
if not content:
62+
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
7463

7564
if "access_token" not in content or not ("expires_in" in content or "expires_on" in content):
7665
if content and "access_token" in content:
@@ -110,7 +99,7 @@ def request_token(self, *scopes, **kwargs):
11099
pass
111100

112101
@abc.abstractmethod
113-
def _build_pipeline(self, config, policies=None, transport=None, **kwargs):
102+
def _build_pipeline(self, **kwargs):
114103
pass
115104

116105

@@ -124,32 +113,5 @@ def request_token(self, *scopes, **kwargs):
124113
token = self._process_response(response, request_time)
125114
return token
126115

127-
def _build_pipeline(self, config, policies=None, transport=None, **kwargs): # pylint:disable=no-self-use
128-
# type: (Configuration, Optional[List[PolicyType]], Optional[HttpTransport], **Any) -> Pipeline
129-
if policies is None: # [] is a valid policy list
130-
policies = _get_policies(config, **kwargs)
131-
if not transport:
132-
from azure.core.pipeline.transport import RequestsTransport
133-
134-
transport = RequestsTransport(**kwargs)
135-
136-
return Pipeline(transport=transport, policies=policies)
137-
138-
139-
def _get_policies(config, **kwargs):
140-
return [
141-
HeadersPolicy(**kwargs),
142-
UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs),
143-
config.proxy_policy,
144-
config.retry_policy,
145-
NetworkTraceLoggingPolicy(**kwargs),
146-
DistributedTracingPolicy(**kwargs),
147-
HttpLoggingPolicy(**kwargs),
148-
]
149-
150-
151-
def _get_configuration(**kwargs):
152-
# type: (**Any) -> Configuration
153-
config = Configuration()
154-
config.retry_policy = RetryPolicy(**kwargs)
155-
return config
116+
def _build_pipeline(self, **kwargs):
117+
return build_pipeline(**kwargs)

sdk/identity/azure-identity/azure/identity/_internal/msal_client.py

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,11 @@
66

77
import six
88

9-
from azure.core.configuration import Configuration
109
from azure.core.exceptions import ClientAuthenticationError
11-
from azure.core.pipeline import Pipeline
12-
from azure.core.pipeline.policies import (
13-
ContentDecodePolicy,
14-
DistributedTracingPolicy,
15-
HttpLoggingPolicy,
16-
NetworkTraceLoggingPolicy,
17-
ProxyPolicy,
18-
RetryPolicy,
19-
UserAgentPolicy,
20-
)
21-
from azure.core.pipeline.transport import HttpRequest, RequestsTransport
22-
23-
from .user_agent import USER_AGENT
10+
from azure.core.pipeline.policies import ContentDecodePolicy
11+
from azure.core.pipeline.transport import HttpRequest
12+
13+
from .pipeline import build_pipeline
2414

2515
try:
2616
from typing import TYPE_CHECKING
@@ -29,12 +19,10 @@
2919

3020
if TYPE_CHECKING:
3121
# pylint:disable=unused-import,ungrouped-imports
32-
from typing import Any, Dict, List, Optional, Union
22+
from typing import Any, Dict, Optional, Union
3323
from azure.core.pipeline import PipelineResponse
34-
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy
35-
from azure.core.pipeline.transport import HttpResponse, HttpTransport
24+
from azure.core.pipeline.transport import HttpResponse
3625

37-
PolicyList = List[Union[HTTPPolicy, SansIOHTTPPolicy]]
3826
RequestData = Union[Dict[str, str], str]
3927

4028

@@ -83,7 +71,7 @@ class MsalClient(object):
8371
def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-parameter-credential
8472
# type: (**Any) -> None
8573
self._local = threading.local()
86-
self._pipeline = _build_pipeline(**kwargs)
74+
self._pipeline = build_pipeline(**kwargs)
8775

8876
def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument
8977
# type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse
@@ -129,34 +117,3 @@ def _store_auth_error(self, response):
129117
content = response.context.get(ContentDecodePolicy.CONTEXT_NAME)
130118
if content and "error" in content:
131119
self._local.error = (content["error"], response.http_response)
132-
133-
134-
def _create_config(**kwargs):
135-
# type: (Any) -> Configuration
136-
config = Configuration(**kwargs)
137-
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
138-
config.retry_policy = RetryPolicy(**kwargs)
139-
config.proxy_policy = ProxyPolicy(**kwargs)
140-
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
141-
return config
142-
143-
144-
def _build_pipeline(config=None, policies=None, transport=None, **kwargs):
145-
# type: (Optional[Configuration], Optional[PolicyList], Optional[HttpTransport], **Any) -> Pipeline
146-
config = config or _create_config(**kwargs)
147-
148-
if policies is None: # [] is a valid policy list
149-
policies = [
150-
ContentDecodePolicy(),
151-
config.user_agent_policy,
152-
config.proxy_policy,
153-
config.retry_policy,
154-
config.logging_policy,
155-
DistributedTracingPolicy(**kwargs),
156-
HttpLoggingPolicy(**kwargs),
157-
]
158-
159-
if not transport:
160-
transport = RequestsTransport(**kwargs)
161-
162-
return Pipeline(transport=transport, policies=policies)

0 commit comments

Comments
 (0)