Skip to content

Commit c4a1d6e

Browse files
authored
[Identity] Fix multi-tenant auth using async AadClient (Azure#21322)
1 parent 25162e9 commit c4a1d6e

File tree

3 files changed

+49
-22
lines changed

3 files changed

+49
-22
lines changed

sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from azure.core.credentials import AccessToken
1515
from azure.core.pipeline import AsyncPipeline
1616
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
17+
from azure.core.pipeline.transport import HttpRequest
1718
from ..._internal import AadClientCertificate
1819

1920
Policy = Union[AsyncHTTPPolicy, SansIOHTTPPolicy]
@@ -44,41 +45,31 @@ async def obtain_token_by_authorization_code(
4445
request = self._get_auth_code_request(
4546
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
4647
)
47-
now = int(time.time())
48-
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
49-
return self._process_response(response, now)
48+
return await self._run_pipeline(request, **kwargs)
5049

5150
async def obtain_token_by_client_certificate(
5251
self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any"
5352
) -> "AccessToken":
5453
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
55-
now = int(time.time())
56-
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
57-
return self._process_response(response, now)
54+
return await self._run_pipeline(request, stream=False, **kwargs)
5855

5956
async def obtain_token_by_client_secret(
6057
self, scopes: "Iterable[str]", secret: str, **kwargs: "Any"
6158
) -> "AccessToken":
6259
request = self._get_client_secret_request(scopes, secret, **kwargs)
63-
now = int(time.time())
64-
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
65-
return self._process_response(response, now)
60+
return await self._run_pipeline(request, **kwargs)
6661

6762
async def obtain_token_by_jwt_assertion(
6863
self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any"
6964
) -> "AccessToken":
7065
request = self._get_jwt_assertion_request(scopes, assertion)
71-
now = int(time.time())
72-
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
73-
return self._process_response(response, now)
66+
return await self._run_pipeline(request, stream=False, **kwargs)
7467

7568
async def obtain_token_by_refresh_token(
7669
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
7770
) -> "AccessToken":
7871
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
79-
now = int(time.time())
80-
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
81-
return self._process_response(response, now)
72+
return await self._run_pipeline(request, **kwargs)
8273

8374
async def obtain_token_on_behalf_of(
8475
self,
@@ -90,10 +81,16 @@ async def obtain_token_on_behalf_of(
9081
request = self._get_on_behalf_of_request(
9182
scopes=scopes, client_credential=client_credential, user_assertion=user_assertion, **kwargs
9283
)
93-
now = int(time.time())
94-
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
95-
return self._process_response(response, now)
84+
return await self._run_pipeline(request, **kwargs)
9685

9786
# pylint:disable=no-self-use
9887
def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline":
9988
return build_async_pipeline(**kwargs)
89+
90+
async def _run_pipeline(self, request: "HttpRequest", **kwargs: "Any") -> "AccessToken":
91+
# remove tenant_id kwarg that could have been passed from credential's get_token method
92+
# tenant_id is already part of `request` at this point
93+
kwargs.pop("tenant_id", None)
94+
now = int(time.time())
95+
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
96+
return self._process_response(response, now)

sdk/identity/azure-identity/tests/test_client_secret_credential.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
from azure.core.exceptions import ClientAuthenticationError
65
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
76
from azure.identity import ClientSecretCredential, TokenCachePersistenceOptions
87
from azure.identity._enums import RegionalAuthority
@@ -207,7 +206,9 @@ def test_multitenant_authentication():
207206
second_tenant = "second-tenant"
208207
second_token = first_token * 2
209208

210-
def send(request, **_):
209+
def send(request, **kwargs):
210+
assert "tenant_id" not in kwargs, "tenant_id kwarg shouldn't get passed to send method"
211+
211212
parsed = urlparse(request.url)
212213
tenant = parsed.path.split("/")[1]
213214
assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant)
@@ -233,6 +234,18 @@ def send(request, **_):
233234
token = credential.get_token("scope")
234235
assert token.token == first_token
235236

237+
238+
def test_live_multitenant_authentication(live_service_principal):
239+
# first create a credential with a non-existent tenant
240+
credential = ClientSecretCredential(
241+
"...", live_service_principal["client_id"], live_service_principal["client_secret"]
242+
)
243+
# then get a valid token for an actual tenant
244+
token = credential.get_token("https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"])
245+
assert token.token
246+
assert token.expires_on
247+
248+
236249
def test_multitenant_authentication_not_allowed():
237250
expected_tenant = "expected-tenant"
238251
expected_token = "***"

sdk/identity/azure-identity/tests/test_client_secret_credential_async.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from urllib.parse import urlparse
88

99
from azure.core.credentials import AccessToken
10-
from azure.core.exceptions import ClientAuthenticationError
1110
from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy
1211
from azure.identity import TokenCachePersistenceOptions
1312
from azure.identity._constants import EnvironmentVariables
@@ -257,10 +256,13 @@ async def test_multitenant_authentication():
257256
second_tenant = "second-tenant"
258257
second_token = first_token * 2
259258

260-
async def send(request, **_):
259+
async def send(request, **kwargs):
260+
assert "tenant_id" not in kwargs, "tenant_id kwarg shouldn't get passed to send method"
261+
261262
parsed = urlparse(request.url)
262263
tenant = parsed.path.split("/")[1]
263264
assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant)
265+
264266
token = first_token if tenant == first_tenant else second_token
265267
return mock_response(json_payload=build_aad_response(access_token=token))
266268

@@ -280,6 +282,21 @@ async def send(request, **_):
280282
token = await credential.get_token("scope")
281283
assert token.token == first_token
282284

285+
286+
@pytest.mark.asyncio
287+
async def test_live_multitenant_authentication(live_service_principal):
288+
# first create a credential with a non-existent tenant
289+
credential = ClientSecretCredential(
290+
"...", live_service_principal["client_id"], live_service_principal["client_secret"]
291+
)
292+
# then get a valid token for an actual tenant
293+
token = await credential.get_token(
294+
"https://vault.azure.net/.default", tenant_id=live_service_principal["tenant_id"]
295+
)
296+
assert token.token
297+
assert token.expires_on
298+
299+
283300
@pytest.mark.asyncio
284301
async def test_multitenant_authentication_not_allowed():
285302
expected_tenant = "expected-tenant"

0 commit comments

Comments
 (0)