33using System . CommandLine ;
44using System . Linq ;
55using System . Net . Http ;
6+ using System . Text . RegularExpressions ;
67using System . Threading . Tasks ;
78using GitCredentialManager ;
89using GitCredentialManager . Authentication ;
@@ -74,10 +75,9 @@ public bool IsSupported(HttpResponseMessage response)
7475
7576 public async Task < ICredential > GetCredentialAsync ( InputArguments input )
7677 {
77- Uri remoteUri = input . GetRemoteUri ( ) ;
78-
7978 if ( UsePersonalAccessTokens ( ) )
8079 {
80+ Uri remoteUri = input . GetRemoteUri ( ) ;
8181 string service = GetServiceName ( remoteUri ) ;
8282 string account = GetAccountNameForCredentialQuery ( input ) ;
8383
@@ -104,7 +104,7 @@ public async Task<ICredential> GetCredentialAsync(InputArguments input)
104104 {
105105 // Include the username request here so that we may use it as an override
106106 // for user account lookups when getting Azure Access Tokens.
107- var azureResult = await GetAzureAccessTokenAsync ( remoteUri , input . UserName ) ;
107+ var azureResult = await GetAzureAccessTokenAsync ( input ) ;
108108 return new GitCredential ( azureResult . AccountUpn , azureResult . AccessToken ) ;
109109 }
110110 }
@@ -222,8 +222,11 @@ private async Task<ICredential> GeneratePersonalAccessTokenAsync(InputArguments
222222 return new GitCredential ( result . AccountUpn , pat ) ;
223223 }
224224
225- private async Task < IMicrosoftAuthenticationResult > GetAzureAccessTokenAsync ( Uri remoteUri , string userName )
225+ private async Task < IMicrosoftAuthenticationResult > GetAzureAccessTokenAsync ( InputArguments input )
226226 {
227+ Uri remoteUri = input . GetRemoteUri ( ) ;
228+ string userName = input . UserName ;
229+
227230 // We should not allow unencrypted communication and should inform the user
228231 if ( StringComparer . OrdinalIgnoreCase . Equals ( remoteUri . Scheme , "http" ) )
229232 {
@@ -234,14 +237,27 @@ private async Task<IMicrosoftAuthenticationResult> GetAzureAccessTokenAsync(Uri
234237 Uri orgUri = UriHelpers . CreateOrganizationUri ( remoteUri , out string orgName ) ;
235238
236239 _context . Trace . WriteLine ( $ "Determining Microsoft Authentication authority for Azure DevOps organization '{ orgName } '...") ;
237- string authAuthority = _authorityCache . GetAuthority ( orgName ) ;
238- if ( authAuthority is null )
240+ if ( TryGetAuthorityFromHeaders ( input . WwwAuth , out string authAuthority ) )
241+ {
242+ _context . Trace . WriteLine ( "Authority was found in WWW-Authenticate headers from Git input." ) ;
243+ }
244+ else
239245 {
240- // If there is no cached value we must query for it and cache it for future use
241- _context . Trace . WriteLine ( $ "No cached authority value - querying { orgUri } for authority...") ;
242- authAuthority = await _azDevOps . GetAuthorityAsync ( orgUri ) ;
243- _authorityCache . UpdateAuthority ( orgName , authAuthority ) ;
246+ // Try to get the authority from the cache
247+ authAuthority = _authorityCache . GetAuthority ( orgName ) ;
248+ if ( authAuthority is null )
249+ {
250+ // If there is no cached value we must query for it and cache it for future use
251+ _context . Trace . WriteLine ( $ "No cached authority value - querying { orgUri } for authority...") ;
252+ authAuthority = await _azDevOps . GetAuthorityAsync ( orgUri ) ;
253+ _authorityCache . UpdateAuthority ( orgName , authAuthority ) ;
254+ }
255+ else
256+ {
257+ _context . Trace . WriteLine ( "Authority was found in cache." ) ;
258+ }
244259 }
260+
245261 _context . Trace . WriteLine ( $ "Authority is '{ authAuthority } '.") ;
246262
247263 //
@@ -284,6 +300,30 @@ private async Task<IMicrosoftAuthenticationResult> GetAzureAccessTokenAsync(Uri
284300 return result ;
285301 }
286302
303+ internal /* for testing purposes */ static bool TryGetAuthorityFromHeaders ( IEnumerable < string > headers , out string authority )
304+ {
305+ authority = null ;
306+
307+ if ( headers is null )
308+ {
309+ return false ;
310+ }
311+
312+ var regex = new Regex ( @"authorization_uri=""?(?<authority>.+)""?" , RegexOptions . Compiled | RegexOptions . IgnoreCase ) ;
313+
314+ foreach ( string header in headers )
315+ {
316+ Match match = regex . Match ( header ) ;
317+ if ( match . Success )
318+ {
319+ authority = match . Groups [ "authority" ] . Value . Trim ( new [ ] { '"' , '\' ' } ) ;
320+ return true ;
321+ }
322+ }
323+
324+ return false ;
325+ }
326+
287327 private string GetClientId ( )
288328 {
289329 // Check for developer override value
0 commit comments