Skip to content

Commit ae427fd

Browse files
authored
Fix ManagedIdentityCredential token caching (Azure#17323)
1 parent 621efa3 commit ae427fd

File tree

4 files changed

+97
-4
lines changed

4 files changed

+97
-4
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Release History
22

33
## 1.6.0b3 (Unreleased)
4-
4+
### Fixed
5+
- ManagedIdentityCredential caches tokens correctly
56

67
## 1.6.0b2 (2021-03-09)
78
### Breaking Changes

sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _process_response(self, response, request_time):
7979

8080
# caching is the final step because TokenCache.add mutates its "event"
8181
self._cache.add(
82-
event={"response": content, "scope": content["resource"]}, now=request_time,
82+
event={"response": content, "scope": [content["resource"]]}, now=request_time,
8383
)
8484

8585
return token
@@ -89,8 +89,9 @@ def get_cached_token(self, *scopes):
8989
resource = _scopes_to_resource(*scopes)
9090
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource])
9191
for token in tokens:
92-
if token["expires_on"] > time.time():
93-
return AccessToken(token["secret"], token["expires_on"])
92+
expires_on = int(token["expires_on"])
93+
if expires_on > time.time():
94+
return AccessToken(token["secret"], expires_on)
9495
return None
9596

9697
@abc.abstractmethod
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import time
6+
7+
from azure.core.pipeline.transport import HttpRequest
8+
from azure.identity._internal.managed_identity_client import ManagedIdentityClient
9+
10+
from helpers import mock_response, Request, validating_transport
11+
12+
13+
def test_caching():
14+
scope = "scope"
15+
expected_expires_on = int(time.time() + 3600)
16+
expected_token = "*"
17+
transport = validating_transport(
18+
requests=[Request(url="http://localhost")],
19+
responses=[
20+
mock_response(
21+
json_payload={
22+
"access_token": expected_token,
23+
"expires_in": 3600,
24+
"expires_on": expected_expires_on,
25+
"resource": scope,
26+
"token_type": "Bearer",
27+
}
28+
)
29+
],
30+
)
31+
client = ManagedIdentityClient(
32+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=transport
33+
)
34+
35+
token = client.get_cached_token(scope)
36+
assert not token
37+
38+
token = client.request_token(scope)
39+
assert token.expires_on == expected_expires_on
40+
assert token.token == expected_token
41+
42+
token = client.get_cached_token(scope)
43+
assert token.expires_on == expected_expires_on
44+
assert token.token == expected_token
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import time
6+
7+
from azure.core.pipeline.transport import HttpRequest
8+
from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient
9+
import pytest
10+
11+
from helpers import mock_response, Request
12+
from helpers_async import async_validating_transport
13+
14+
15+
@pytest.mark.asyncio
16+
async def test_caching():
17+
scope = "scope"
18+
expected_expires_on = int(time.time() + 3600)
19+
expected_token = "*"
20+
transport = async_validating_transport(
21+
requests=[Request(url="http://localhost")],
22+
responses=[
23+
mock_response(
24+
json_payload={
25+
"access_token": expected_token,
26+
"expires_in": 3600,
27+
"expires_on": expected_expires_on,
28+
"resource": scope,
29+
"token_type": "Bearer",
30+
}
31+
)
32+
],
33+
)
34+
client = AsyncManagedIdentityClient(
35+
request_factory=lambda _, __: HttpRequest("GET", "http://localhost"), transport=transport
36+
)
37+
38+
token = client.get_cached_token(scope)
39+
assert not token
40+
41+
token = await client.request_token(scope)
42+
assert token.expires_on == expected_expires_on
43+
assert token.token == expected_token
44+
45+
token = client.get_cached_token(scope)
46+
assert token.expires_on == expected_expires_on
47+
assert token.token == expected_token

0 commit comments

Comments
 (0)