Skip to content

Commit fff4d58

Browse files
authored
Token exchange support for ManagedIdentityCredential (Azure#19902)
1 parent 63088bb commit fff4d58

16 files changed

+380
-62
lines changed

sdk/identity/azure-identity/azure/identity/_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ class EnvironmentVariables:
4646
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
4747
AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION = "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION"
4848
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"
49+
50+
TOKEN_FILE_PATH = "TOKEN_FILE_PATH"
51+
TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, TOKEN_FILE_PATH)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import TYPE_CHECKING
6+
7+
from .._internal import AadClient
8+
from .._internal.get_token_mixin import GetTokenMixin
9+
10+
if TYPE_CHECKING:
11+
from typing import Any, Callable, Optional
12+
from azure.core.credentials import AccessToken
13+
14+
15+
class ClientAssertionCredential(GetTokenMixin):
16+
def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
17+
# type: (str, str, Callable[[], str], **Any) -> None
18+
"""Authenticates a service principal with a JWT assertion.
19+
20+
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
21+
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
22+
23+
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
24+
:param str client_id: the principal's client ID
25+
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
26+
acquires a new token.
27+
:paramtype get_assertion: Callable[[], str]
28+
29+
:keyword str authority: authority of an Azure Active Directory endpoint, for example
30+
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
31+
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
32+
"""
33+
self._get_assertion = get_assertion
34+
self._client = AadClient(tenant_id, client_id, **kwargs)
35+
super(ClientAssertionCredential, self).__init__(**kwargs)
36+
37+
def _acquire_token_silently(self, *scopes, **kwargs):
38+
# type: (*str, **Any) -> Optional[AccessToken]
39+
return self._client.get_cached_access_token(scopes, **kwargs)
40+
41+
def _request_token(self, *scopes, **kwargs):
42+
# type: (*str, **Any) -> AccessToken
43+
assertion = self._get_assertion()
44+
token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
45+
return token

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ def __init__(self, **kwargs):
6565
from .azure_arc import AzureArcCredential
6666

6767
self._credential = AzureArcCredential(**kwargs)
68+
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
69+
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
70+
from .token_exchange import TokenExchangeCredential
71+
72+
self._credential = TokenExchangeCredential(
73+
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
74+
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
75+
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
76+
**kwargs
77+
)
6878
else:
6979
from .imds import ImdsCredential
7080

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import time
6+
from typing import TYPE_CHECKING
7+
8+
from .client_assertion import ClientAssertionCredential
9+
10+
if TYPE_CHECKING:
11+
# pylint:disable=unused-import,ungrouped-imports
12+
from typing import Any
13+
14+
15+
class TokenFileMixin(object):
16+
def __init__(self, token_file_path, **_):
17+
# type: (str, **Any) -> None
18+
super(TokenFileMixin, self).__init__()
19+
self._jwt = ""
20+
self._last_read_time = 0
21+
self._token_file_path = token_file_path
22+
23+
def get_service_account_token(self):
24+
# type: () -> str
25+
now = int(time.time())
26+
if now - self._last_read_time > 300:
27+
with open(self._token_file_path) as f:
28+
self._jwt = f.read()
29+
self._last_read_time = now
30+
return self._jwt
31+
32+
33+
class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
34+
def __init__(self, tenant_id, client_id, token_file_path, **kwargs):
35+
# type: (str, str, str, **Any) -> None
36+
super(TokenExchangeCredential, self).__init__(
37+
tenant_id=tenant_id,
38+
client_id=client_id,
39+
get_assertion=self.get_service_account_token,
40+
token_file_path=token_file_path,
41+
**kwargs
42+
)

sdk/identity/azure-identity/azure/identity/_internal/aad_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
4040
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
4141
return self._process_response(response, now)
4242

43+
def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
44+
# type: (Iterable[str], str, **Any) -> AccessToken
45+
request = self._get_jwt_assertion_request(scopes, assertion)
46+
now = int(time.time())
47+
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
48+
return self._process_response(response, now)
49+
4350
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
4451
# type: (Iterable[str], str, **Any) -> AccessToken
4552
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def get_cached_refresh_tokens(self, scopes):
8080
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
8181
pass
8282

83+
@abc.abstractmethod
84+
def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
85+
pass
86+
8387
@abc.abstractmethod
8488
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
8589
pass
@@ -165,10 +169,8 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None,
165169
request = self._post(data, **kwargs)
166170
return request
167171

168-
def _get_client_certificate_request(self, scopes, certificate, **kwargs):
169-
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
170-
audience = self._get_token_url(**kwargs)
171-
assertion = self._get_jwt_assertion(certificate, audience)
172+
def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
173+
# type: (Iterable[str], str, **Any) -> HttpRequest
172174
data = {
173175
"client_assertion": assertion,
174176
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
@@ -180,19 +182,8 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs):
180182
request = self._post(data, **kwargs)
181183
return request
182184

183-
def _get_client_secret_request(self, scopes, secret, **kwargs):
184-
# type: (Iterable[str], str, **Any) -> HttpRequest
185-
data = {
186-
"client_id": self._client_id,
187-
"client_secret": secret,
188-
"grant_type": "client_credentials",
189-
"scope": " ".join(scopes),
190-
}
191-
request = self._post(data, **kwargs)
192-
return request
193-
194-
def _get_jwt_assertion(self, certificate, audience):
195-
# type: (AadClientCertificate, str) -> str
185+
def _get_client_certificate_request(self, scopes, certificate, **kwargs):
186+
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
196187
now = int(time.time())
197188
header = six.ensure_binary(
198189
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
@@ -201,7 +192,7 @@ def _get_jwt_assertion(self, certificate, audience):
201192
json.dumps(
202193
{
203194
"jti": str(uuid4()),
204-
"aud": audience,
195+
"aud": self._get_token_url(**kwargs),
205196
"iss": self._client_id,
206197
"sub": self._client_id,
207198
"nbf": now,
@@ -213,8 +204,20 @@ def _get_jwt_assertion(self, certificate, audience):
213204
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
214205
signature = certificate.sign(jws)
215206
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
207+
assertion = jwt_bytes.decode("utf-8")
216208

217-
return jwt_bytes.decode("utf-8")
209+
return self._get_jwt_assertion_request(scopes, assertion, **kwargs)
210+
211+
def _get_client_secret_request(self, scopes, secret, **kwargs):
212+
# type: (Iterable[str], str, **Any) -> HttpRequest
213+
data = {
214+
"client_id": self._client_id,
215+
"client_secret": secret,
216+
"grant_type": "client_credentials",
217+
"scope": " ".join(scopes),
218+
}
219+
request = self._post(data, **kwargs)
220+
return request
218221

219222
def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
220223
# type: (Iterable[str], str, **Any) -> HttpRequest
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import TYPE_CHECKING
6+
7+
from .._internal import AadClient, AsyncContextManager
8+
from .._internal.get_token_mixin import GetTokenMixin
9+
10+
if TYPE_CHECKING:
11+
from typing import Any, Callable, Optional
12+
from azure.core.credentials import AccessToken
13+
14+
15+
class ClientAssertionCredential(AsyncContextManager, GetTokenMixin):
16+
def __init__(self, tenant_id: str, client_id: str, get_assertion: "Callable[[], str]", **kwargs: "Any") -> None:
17+
"""Authenticates a service principal with a JWT assertion.
18+
19+
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
20+
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
21+
22+
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
23+
:param str client_id: the principal's client ID
24+
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
25+
acquires a new token.
26+
:paramtype get_assertion: Callable[[], str]
27+
28+
:keyword str authority: authority of an Azure Active Directory endpoint, for example
29+
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
30+
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
31+
"""
32+
self._get_assertion = get_assertion
33+
self._client = AadClient(tenant_id, client_id, **kwargs)
34+
super().__init__(**kwargs)
35+
36+
async def __aenter__(self):
37+
await self._client.__aenter__()
38+
return self
39+
40+
async def close(self) -> None:
41+
"""Close the credential's transport session."""
42+
await self._client.close()
43+
44+
async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
45+
return self._client.get_cached_access_token(scopes, **kwargs)
46+
47+
async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
48+
assertion = self._get_assertion()
49+
token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
50+
return token

sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ def __init__(self, **kwargs: "Any") -> None:
6262
from .azure_arc import AzureArcCredential
6363

6464
self._credential = AzureArcCredential(**kwargs)
65+
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
66+
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
67+
from .token_exchange import TokenExchangeCredential
68+
69+
self._credential = TokenExchangeCredential(
70+
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
71+
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
72+
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
73+
**kwargs
74+
)
6575
else:
6676
from .imds import ImdsCredential
6777

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
from typing import TYPE_CHECKING
6+
7+
from .client_assertion import ClientAssertionCredential
8+
from ..._credentials.token_exchange import TokenFileMixin
9+
10+
if TYPE_CHECKING:
11+
# pylint:disable=unused-import,ungrouped-imports
12+
from typing import Any
13+
14+
15+
class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
16+
def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs: "Any") -> None:
17+
super().__init__(
18+
tenant_id=tenant_id,
19+
client_id=client_id,
20+
get_assertion=self.get_service_account_token,
21+
token_file_path=token_file_path,
22+
**kwargs
23+
)

sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ async def obtain_token_by_authorization_code(
4848
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
4949
return self._process_response(response, now)
5050

51-
async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
52-
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
51+
async def obtain_token_by_client_certificate(
52+
self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any"
53+
) -> "AccessToken":
5354
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
5455
now = int(time.time())
5556
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
@@ -63,6 +64,14 @@ async def obtain_token_by_client_secret(
6364
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
6465
return self._process_response(response, now)
6566

67+
async def obtain_token_by_jwt_assertion(
68+
self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any"
69+
) -> "AccessToken":
70+
request = self._get_jwt_assertion_request(scopes, assertion)
71+
now = int(time.time())
72+
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
73+
return self._process_response(response, now)
74+
6675
async def obtain_token_by_refresh_token(
6776
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
6877
) -> "AccessToken":

0 commit comments

Comments
 (0)