Skip to content

Commit 775284a

Browse files
[Container Registry] Acr token cache (Azure#18107)
1 parent b95e58b commit 775284a

File tree

2 files changed

+41
-26
lines changed

2 files changed

+41
-26
lines changed

sdk/containerregistry/azure-containerregistry/azure/containerregistry/_exchange_client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Licensed under the MIT License.
55
# ------------------------------------
66
import re
7+
import time
78
from typing import TYPE_CHECKING
89

910
from azure.core.pipeline.policies import SansIOHTTPPolicy
@@ -47,29 +48,38 @@ def __init__(self, endpoint, credential, **kwargs):
4748
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
4849
endpoint = "https://" + endpoint
4950
self._endpoint = endpoint
50-
self._credential_scopes = "https://management.core.windows.net/.default"
51+
self.credential_scope = "https://management.core.windows.net/.default"
5152
self._client = ContainerRegistry(
5253
credential=credential,
5354
url=endpoint,
5455
sdk_moniker=USER_AGENT,
5556
authentication_policy=ExchangeClientAuthenticationPolicy(),
56-
credential_scopes=kwargs.pop("credential_scopes", self._credential_scopes),
57+
credential_scopes=kwargs.pop("credential_scopes", self.credential_scope),
5758
**kwargs
5859
)
5960
self._credential = credential
61+
self._refresh_token = None
62+
self._last_refresh_time = 0
6063

6164
def get_acr_access_token(self, challenge, **kwargs):
6265
# type: (str, Dict[str, Any]) -> str
6366
parsed_challenge = _parse_challenge(challenge)
64-
refresh_token = self.exchange_aad_token_for_refresh_token(service=parsed_challenge["service"])
67+
refresh_token = self.get_refresh_token(parsed_challenge["service"], **kwargs)
6568
return self.exchange_refresh_token_for_access_token(
6669
refresh_token, service=parsed_challenge["service"], scope=parsed_challenge["scope"], **kwargs
6770
)
6871

72+
def get_refresh_token(self, service, **kwargs):
73+
# type: (str, **Any) -> str
74+
if not self._refresh_token or time.time() - self._last_refresh_time > 300:
75+
self._refresh_token = self.exchange_aad_token_for_refresh_token(service, **kwargs)
76+
self._last_refresh_time = time.time()
77+
return self._refresh_token
78+
6979
def exchange_aad_token_for_refresh_token(self, service=None, **kwargs):
7080
# type: (str, Dict[str, Any]) -> str
7181
refresh_token = self._client.authentication.exchange_aad_access_token_for_acr_refresh_token(
72-
service=service, access_token=self._credential.get_token(self._credential_scopes).token, **kwargs
82+
service=service, access_token=self._credential.get_token(self.credential_scope).token, **kwargs
7383
)
7484
return refresh_token.refresh_token
7585

sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_exchange_client.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,27 @@
44
# Licensed under the MIT License.
55
# ------------------------------------
66
import re
7-
from typing import TYPE_CHECKING
7+
import time
8+
from typing import TYPE_CHECKING, Dict, List, Any
89

10+
from azure.core.pipeline import PipelineRequest, PipelineResponse
911
from azure.core.pipeline.policies import SansIOHTTPPolicy
1012

1113
from .._generated.aio import ContainerRegistry
1214
from .._helpers import _parse_challenge
1315
from .._user_agent import USER_AGENT
1416

1517
if TYPE_CHECKING:
16-
from typing import Dict, List, Any
17-
from azure.core.credentials import TokenCredential
18-
from azure.core.pipeline import PipelineRequest, PipelineResponse
18+
from azure.core.credentials_async import AsyncTokenCredential
1919

2020

2121
class ExchangeClientAuthenticationPolicy(SansIOHTTPPolicy):
2222
"""Authentication policy for exchange client that does not modify the request"""
2323

24-
def on_request(self, request):
25-
# type: (PipelineRequest) -> None
24+
def on_request(self, request: PipelineRequest) -> None:
2625
pass
2726

28-
def on_response(self, request, response):
29-
# type: (PipelineRequest, PipelineResponse) -> None
27+
def on_response(self, request: PipelineRequest, response: PipelineResponse) -> None:
3028
pass
3129

3230

@@ -42,40 +40,48 @@ class ACRExchangeClient(object):
4240
BEARER = "Bearer"
4341
AUTHENTICATION_CHALLENGE_PARAMS_PATTERN = re.compile('(?:(\\w+)="([^""]*)")+')
4442

45-
def __init__(self, endpoint, credential, **kwargs):
46-
# type: (str, TokenCredential, Dict[str, Any]) -> None
43+
def __init__(
44+
self, endpoint: str, credential: "AsyncTokencredential", **kwargs: Dict[str, Any]
45+
) -> None:
4746
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
4847
endpoint = "https://" + endpoint
4948
self._endpoint = endpoint
50-
self._credential_scopes = "https://management.core.windows.net/.default"
49+
self._credential_scope = "https://management.core.windows.net/.default"
5150
self._client = ContainerRegistry(
5251
credential=credential,
5352
url=endpoint,
5453
sdk_moniker=USER_AGENT,
5554
authentication_policy=ExchangeClientAuthenticationPolicy(),
56-
credential_scopes=kwargs.pop("credential_scopes", self._credential_scopes),
55+
credential_scopes=kwargs.pop("credential_scopes", self._credential_scope),
5756
**kwargs
5857
)
5958
self._credential = credential
59+
self._refresh_token = None
60+
self._last_refresh_time = None
6061

61-
async def get_acr_access_token(self, challenge, **kwargs):
62-
# type: (str) -> str
62+
async def get_acr_access_token(self, challenge: str, **kwargs: Dict[str, Any]) -> str:
6363
parsed_challenge = _parse_challenge(challenge)
64-
refresh_token = await self.exchange_aad_token_for_refresh_token(service=parsed_challenge["service"])
64+
refresh_token = await self.get_refresh_token(parsed_challenge["service"], **kwargs)
6565
return await self.exchange_refresh_token_for_access_token(
6666
refresh_token, service=parsed_challenge["service"], scope=parsed_challenge["scope"], **kwargs
6767
)
6868

69-
async def exchange_aad_token_for_refresh_token(self, service=None, **kwargs):
70-
# type: (str, Dict[str, Any]) -> str
71-
token = await self._credential.get_token(self._credential_scopes)
69+
async def get_refresh_token(self, service: str, **kwargs: Dict[str, Any]) -> str:
70+
if not self._refresh_token or time.time() - self._last_refresh_time > 300:
71+
self._refresh_token = await self.exchange_aad_token_for_refresh_token(service, **kwargs)
72+
self._last_refresh_time = time.time()
73+
return self._refresh_token
74+
75+
async def exchange_aad_token_for_refresh_token(self, service: str = None, **kwargs: Dict[str, Any]) -> str:
76+
token = await self._credential.get_token(self._credential_scope)
7277
refresh_token = await self._client.authentication.exchange_aad_access_token_for_acr_refresh_token(
7378
service, token.token, **kwargs
7479
)
7580
return refresh_token.refresh_token
7681

77-
async def exchange_refresh_token_for_access_token(self, refresh_token, service=None, scope=None, **kwargs):
78-
# type: (str, str, str, Dict[str, Any]) -> str
82+
async def exchange_refresh_token_for_access_token(
83+
self, refresh_token: str, service: str = None, scope: str = None, **kwargs: Dict[str, Any]
84+
) -> str:
7985
access_token = await self._client.authentication.exchange_acr_refresh_token_for_acr_access_token(
8086
service, scope, refresh_token, **kwargs
8187
)
@@ -88,8 +94,7 @@ async def __aenter__(self):
8894
async def __aexit__(self, *args):
8995
self._client.__aexit__(*args)
9096

91-
async def close(self):
92-
# type: () -> None
97+
async def close(self) -> None:
9398
"""Close sockets opened by the client.
9499
Calling this method is unnecessary when using the client as a context manager.
95100
"""

0 commit comments

Comments
 (0)