44# Licensed under the MIT License.
55# ------------------------------------
66import re
7- from typing import TYPE_CHECKING
7+ import time
8+ from typing import TYPE_CHECKING , Dict , List , Any
89
10+ from azure .core .pipeline import PipelineRequest , PipelineResponse
911from azure .core .pipeline .policies import SansIOHTTPPolicy
1012
1113from .._generated .aio import ContainerRegistry
1214from .._helpers import _parse_challenge
1315from .._user_agent import USER_AGENT
1416
1517if TYPE_CHECKING :
16- from typing import Dict , List , Any
17- from azure .core .credentials import TokenCredential
18- from azure .core .pipeline import PipelineRequest , PipelineResponse
18+ from azure .core .credentials_async import AsyncTokenCredential
1919
2020
2121class ExchangeClientAuthenticationPolicy (SansIOHTTPPolicy ):
2222 """Authentication policy for exchange client that does not modify the request"""
2323
24- def on_request (self , request ):
25- # type: (PipelineRequest) -> None
24+ def on_request (self , request : PipelineRequest ) -> None :
2625 pass
2726
28- def on_response (self , request , response ):
29- # type: (PipelineRequest, PipelineResponse) -> None
27+ def on_response (self , request : PipelineRequest , response : PipelineResponse ) -> None :
3028 pass
3129
3230
@@ -42,40 +40,48 @@ class ACRExchangeClient(object):
4240 BEARER = "Bearer"
4341 AUTHENTICATION_CHALLENGE_PARAMS_PATTERN = re .compile ('(?:(\\ w+)="([^""]*)")+' )
4442
45- def __init__ (self , endpoint , credential , ** kwargs ):
46- # type: (str, TokenCredential, Dict[str, Any]) -> None
43+ def __init__ (
44+ self , endpoint : str , credential : "AsyncTokencredential" , ** kwargs : Dict [str , Any ]
45+ ) -> None :
4746 if not endpoint .startswith ("https://" ) and not endpoint .startswith ("http://" ):
4847 endpoint = "https://" + endpoint
4948 self ._endpoint = endpoint
50- self ._credential_scopes = "https://management.core.windows.net/.default"
49+ self ._credential_scope = "https://management.core.windows.net/.default"
5150 self ._client = ContainerRegistry (
5251 credential = credential ,
5352 url = endpoint ,
5453 sdk_moniker = USER_AGENT ,
5554 authentication_policy = ExchangeClientAuthenticationPolicy (),
56- credential_scopes = kwargs .pop ("credential_scopes" , self ._credential_scopes ),
55+ credential_scopes = kwargs .pop ("credential_scopes" , self ._credential_scope ),
5756 ** kwargs
5857 )
5958 self ._credential = credential
59+ self ._refresh_token = None
60+ self ._last_refresh_time = None
6061
61- async def get_acr_access_token (self , challenge , ** kwargs ):
62- # type: (str) -> str
62+ async def get_acr_access_token (self , challenge : str , ** kwargs : Dict [str , Any ]) -> str :
6363 parsed_challenge = _parse_challenge (challenge )
64- refresh_token = await self .exchange_aad_token_for_refresh_token ( service = parsed_challenge ["service" ])
64+ refresh_token = await self .get_refresh_token ( parsed_challenge ["service" ], ** kwargs )
6565 return await self .exchange_refresh_token_for_access_token (
6666 refresh_token , service = parsed_challenge ["service" ], scope = parsed_challenge ["scope" ], ** kwargs
6767 )
6868
69- async def exchange_aad_token_for_refresh_token (self , service = None , ** kwargs ):
70- # type: (str, Dict[str, Any]) -> str
71- token = await self ._credential .get_token (self ._credential_scopes )
69+ async def get_refresh_token (self , service : str , ** kwargs : Dict [str , Any ]) -> str :
70+ if not self ._refresh_token or time .time () - self ._last_refresh_time > 300 :
71+ self ._refresh_token = await self .exchange_aad_token_for_refresh_token (service , ** kwargs )
72+ self ._last_refresh_time = time .time ()
73+ return self ._refresh_token
74+
75+ async def exchange_aad_token_for_refresh_token (self , service : str = None , ** kwargs : Dict [str , Any ]) -> str :
76+ token = await self ._credential .get_token (self ._credential_scope )
7277 refresh_token = await self ._client .authentication .exchange_aad_access_token_for_acr_refresh_token (
7378 service , token .token , ** kwargs
7479 )
7580 return refresh_token .refresh_token
7681
77- async def exchange_refresh_token_for_access_token (self , refresh_token , service = None , scope = None , ** kwargs ):
78- # type: (str, str, str, Dict[str, Any]) -> str
82+ async def exchange_refresh_token_for_access_token (
83+ self , refresh_token : str , service : str = None , scope : str = None , ** kwargs : Dict [str , Any ]
84+ ) -> str :
7985 access_token = await self ._client .authentication .exchange_acr_refresh_token_for_acr_access_token (
8086 service , scope , refresh_token , ** kwargs
8187 )
@@ -88,8 +94,7 @@ async def __aenter__(self):
8894 async def __aexit__ (self , * args ):
8995 self ._client .__aexit__ (* args )
9096
91- async def close (self ):
92- # type: () -> None
97+ async def close (self ) -> None :
9398 """Close sockets opened by the client.
9499 Calling this method is unnecessary when using the client as a context manager.
95100 """
0 commit comments