22// Licensed under the MIT License.
33
44using System ;
5+ using System . Collections . Generic ;
56using System . Linq ;
67using System . Reflection ;
8+ using System . Threading ;
79using System . Threading . Tasks ;
810using Azure . Core . Tests ;
911using Castle . DynamicProxy ;
@@ -12,94 +14,198 @@ namespace Azure.Core.TestFramework
1214{
1315 public class DiagnosticScopeValidatingInterceptor : IInterceptor
1416 {
17+ private static readonly MethodInfo ValidateDiagnosticScopeMethod = typeof ( DiagnosticScopeValidatingInterceptor ) . GetMethod ( nameof ( AwaitAndValidateDiagnosticScope ) , BindingFlags . NonPublic | BindingFlags . Static ) ;
1518 public void Intercept ( IInvocation invocation )
1619 {
1720 var methodName = invocation . Method . Name ;
18- if ( methodName . EndsWith ( "Async" ) )
19- {
20- Type declaringType = invocation . Method . DeclaringType ;
21- var ns = declaringType . Namespace ;
22- var expectedName = declaringType . Name + "." + methodName . Substring ( 0 , methodName . Length - 5 ) ;
23- using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener ( s => s . StartsWith ( "Azure." ) , asyncLocal : true ) ;
24- invocation . Proceed ( ) ;
2521
26- bool expectFailure = false ;
27- bool skipChecks = false ;
22+ Type declaringType = invocation . Method . DeclaringType ;
23+ var expectedName = declaringType . Name + "." + methodName . Substring ( 0 , methodName . Length - 5 ) ;
24+ bool strict = ! invocation . Method . GetCustomAttributes ( true ) . Any ( a => a . GetType ( ) . FullName == "Azure.Core.ForwardsClientCallsAttribute" ) ;
2825
29- bool strict = ! invocation . Method . GetCustomAttributes ( true ) . Any ( a => a . GetType ( ) . FullName == "Azure.Core.ForwardsClientCallsAttribute" ) ;
30- if ( invocation . Method . ReturnType . Name . Contains ( "Pageable" ) ||
31- invocation . Method . ReturnType . Name . Contains ( "IAsyncEnumerable" ) )
26+ if ( invocation . Method . ReturnType is { IsGenericType : true } genericType &&
27+ genericType . GetGenericTypeDefinition ( ) == typeof ( AsyncPageable < > ) )
28+ {
29+ invocation . Proceed ( ) ;
30+ invocation . ReturnValue = Activator . CreateInstance ( typeof ( DiagnosticScopeValidatingAsyncEnumerable < > ) . MakeGenericType ( genericType . GenericTypeArguments [ 0 ] ) , invocation . ReturnValue , expectedName , methodName , strict ) ;
31+ }
32+ else if ( methodName . EndsWith ( "Async" ) && ! invocation . Method . ReturnType . Name . Contains ( "IAsyncEnumerable" ) )
33+ {
34+ Type genericArgument = typeof ( object ) ;
35+ Type awaitableType = invocation . Method . ReturnType ;
36+ if ( invocation . Method . ReturnType is { IsGenericType : true , GenericTypeArguments : { Length : 1 } genericTypeArguments } )
3237 {
33- return ;
38+ genericArgument = genericTypeArguments [ 0 ] ;
39+ awaitableType = invocation . Method . ReturnType . GetGenericTypeDefinition ( ) ;
3440 }
41+ ValidateDiagnosticScopeMethod . MakeGenericMethod ( genericArgument )
42+ . Invoke ( null , new object [ ] { awaitableType , invocation , expectedName , methodName , strict } ) ;
43+ }
44+ else
45+ {
46+ invocation . Proceed ( ) ;
47+ }
48+ }
3549
36- try
50+ internal static void AwaitAndValidateDiagnosticScope < T > ( Type genericType , IInvocation invocation , string expectedName , string methodName , bool strict )
51+ {
52+ // All this ceremony is not to await the returned Task/ValueTask syncronously
53+ // instead we are replacing the invocation.ReturnValue with the ValidateDiagnosticScope task
54+ // but we need to make sure the types match
55+ if ( genericType == typeof ( Task < > ) )
56+ {
57+ invocation . ReturnValue = ValidateDiagnosticScope ( async ( ) =>
3758 {
38- object returnValue = invocation . ReturnValue ;
39- if ( returnValue is Task t )
40- {
41- t . GetAwaiter ( ) . GetResult ( ) ;
42- }
43- else
44- {
45- // Await ValueTask
46- Type returnType = returnValue . GetType ( ) ;
47- MethodInfo getAwaiterMethod = returnType . GetMethod ( "GetAwaiter" , BindingFlags . Instance | BindingFlags . Public ) ;
48- MethodInfo getResultMethod = getAwaiterMethod . ReturnType . GetMethod ( "GetResult" , BindingFlags . Instance | BindingFlags . Public ) ;
49-
50- getResultMethod . Invoke (
51- getAwaiterMethod . Invoke ( returnValue , Array . Empty < object > ( ) ) ,
52- Array . Empty < object > ( ) ) ;
53- }
54- }
55- catch ( Exception ex )
59+ invocation . Proceed ( ) ;
60+ return ( await ( Task < T > ) invocation . ReturnValue , false ) ;
61+ } , expectedName , methodName , strict ) . AsTask ( ) ;
62+ }
63+ else if ( genericType == typeof ( Task ) )
64+ {
65+ invocation . ReturnValue = ValidateDiagnosticScope < object > ( async ( ) =>
66+ {
67+ invocation . Proceed ( ) ;
68+ await ( Task ) invocation . ReturnValue ;
69+ return default ;
70+ } , expectedName , methodName , strict ) . AsTask ( ) ;
71+ }
72+ else if ( genericType == typeof ( ValueTask < > ) )
73+ {
74+ invocation . ReturnValue = ValidateDiagnosticScope ( async ( ) =>
5675 {
57- expectFailure = true ;
76+ invocation . Proceed ( ) ;
77+ return ( await ( ValueTask < T > ) invocation . ReturnValue , false ) ;
78+ } , expectedName , methodName , strict ) ;
79+ }
80+ else if ( genericType == typeof ( ValueTask ) )
81+ {
82+ invocation . ReturnValue = new ValueTask ( ValidateDiagnosticScope < object > ( async ( ) =>
83+ {
84+ invocation . Proceed ( ) ;
85+ await ( ValueTask ) invocation . ReturnValue ; ;
86+ return default ;
87+ } , expectedName , methodName , strict ) . AsTask ( ) ) ;
88+ }
89+ else
90+ {
91+ ValidateDiagnosticScope < object > ( ( ) =>
92+ {
93+ invocation . Proceed ( ) ;
94+ return default ;
95+ } , expectedName , methodName , strict ) . GetAwaiter ( ) . GetResult ( ) ;
96+ }
97+ }
5898
59- if ( ex is ArgumentException )
60- {
61- // Don't expect scope for argument validation failures
62- skipChecks = true ;
63- }
99+ internal static async ValueTask < T > ValidateDiagnosticScope < T > ( Func < ValueTask < ( T Result , bool SkipChecks ) > > action , string expectedName , string methodName , bool strict )
100+ {
101+ bool expectFailure = false ;
102+ bool skipChecks = false ;
103+ T result ;
104+
105+ using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener ( s => s . StartsWith ( "Azure." ) , asyncLocal : true ) ;
106+ try
107+ {
108+ ( result , skipChecks ) = await action ( ) ;
109+ }
110+ catch ( Exception ex )
111+ {
112+ expectFailure = true ;
113+
114+ if ( ex is ArgumentException )
115+ {
116+ // Don't expect scope for argument validation failures
117+ skipChecks = true ;
64118 }
65- finally
119+
120+ throw ;
121+ }
122+ finally
123+ {
124+ // Remove subscribers before enumerating events.
125+ diagnosticListener . Dispose ( ) ;
126+ if ( ! skipChecks )
66127 {
67- // Remove subscribers before enumerating events.
68- diagnosticListener . Dispose ( ) ;
69- if ( ! skipChecks )
128+ if ( strict )
70129 {
71- if ( strict )
130+ ClientDiagnosticListener . ProducedDiagnosticScope e = diagnosticListener . Scopes . FirstOrDefault ( e => e . Name == expectedName ) ;
131+
132+ if ( e == default )
133+ {
134+ throw new InvalidOperationException ( $ "Expected diagnostic scope not created { expectedName } { Environment . NewLine } " +
135+ $ " created { diagnosticListener . Scopes . Count } scopes [{ string . Join ( ", " , diagnosticListener . Scopes ) } ] { Environment . NewLine } " +
136+ $ " You may have forgotten to add clientDiagnostics.CreateScope(...), set your operationId to { expectedName } in { methodName } or applied the Azure.Core.ForwardsClientCallsAttribute to { methodName } .") ;
137+ }
138+
139+ if ( ! e . Activity . Tags . Any ( tag => tag . Key == "az.namespace" ) )
72140 {
73- ClientDiagnosticListener . ProducedDiagnosticScope e = diagnosticListener . Scopes . FirstOrDefault ( e => e . Name == expectedName ) ;
74-
75- if ( e == default )
76- {
77- throw new InvalidOperationException ( $ "Expected diagnostic scope not created { expectedName } { Environment . NewLine } created scopes { string . Join ( ", " , diagnosticListener . Scopes ) } { Environment . NewLine } You may have forgotten to set your operationId to { expectedName } in { methodName } or applied the Azure.Core.ForwardsClientCallsAttribute to { methodName } .") ;
78- }
79-
80- if ( ! e . Activity . Tags . Any ( tag => tag . Key == "az.namespace" ) )
81- {
82- throw new InvalidOperationException ( $ "All diagnostic scopes should have 'az.namespace' attribute, make sure the assembly containing **ClientOptions type is marked with the AzureResourceProviderNamespace attribute specifying the appropriate provider. This attribute should be included in AssemblyInfo, and can be included by pulling in AzureResourceProviderNamespaceAttribute.cs using the AzureCoreSharedSources alias.") ;
83- }
84-
85- if ( expectFailure && ! e . IsFailed )
86- {
87- throw new InvalidOperationException ( $ "Expected scope { expectedName } to be marked as failed but it succeeded") ;
88- }
141+ throw new InvalidOperationException ( $ "All diagnostic scopes should have 'az.namespace' attribute, make sure the assembly containing **ClientOptions type is marked with the AzureResourceProviderNamespace attribute specifying the appropriate provider. This attribute should be included in AssemblyInfo, and can be included by pulling in AzureResourceProviderNamespaceAttribute.cs using the AzureCoreSharedSources alias.") ;
89142 }
90- else
143+
144+ if ( expectFailure && ! e . IsFailed )
145+ {
146+ throw new InvalidOperationException ( $ "Expected scope { expectedName } to be marked as failed but it succeeded") ;
147+ }
148+ }
149+ else
150+ {
151+ if ( ! diagnosticListener . Scopes . Any ( ) )
91152 {
92- if ( ! diagnosticListener . Scopes . Any ( ) )
93- {
94- throw new InvalidOperationException ( $ "Expected some diagnostic scopes to be created, found none") ;
95- }
153+ throw new InvalidOperationException ( $ "Expected some diagnostic scopes to be created, found none") ;
96154 }
97155 }
98156 }
99157 }
100- else
158+
159+ return result ;
160+ }
161+
162+ internal class DiagnosticScopeValidatingAsyncEnumerable < T > : AsyncPageable < T >
163+ {
164+ private readonly AsyncPageable < T > _pageable ;
165+ private readonly string _expectedName ;
166+ private readonly string _methodName ;
167+ private readonly bool _strict ;
168+ private readonly bool _overridesGetAsyncEnumerator ;
169+
170+ public DiagnosticScopeValidatingAsyncEnumerable ( AsyncPageable < T > pageable , string expectedName , string methodName , bool strict )
101171 {
102- invocation . Proceed ( ) ;
172+ if ( pageable == null ) throw new ArgumentNullException ( nameof ( pageable ) , "Operations returning [Async]Pageable should never return null." ) ;
173+
174+ // If AsyncPageable overrides GetAsyncEnumerator we have to pass the call through to it
175+ // this would effectively disable the validation so avoid doing it as much as possible
176+ var getAsyncEnumeratorMethod = pageable . GetType ( ) . GetMethod ( "GetAsyncEnumerator" , BindingFlags . Public | BindingFlags . Instance ) ;
177+ _overridesGetAsyncEnumerator = getAsyncEnumeratorMethod . DeclaringType is { IsGenericType : true } genericType &&
178+ genericType . GetGenericTypeDefinition ( ) != typeof ( AsyncPageable < > ) ;
179+
180+ _pageable = pageable ;
181+ _expectedName = expectedName ;
182+ _methodName = methodName ;
183+ _strict = strict ;
184+ }
185+
186+ public override IAsyncEnumerator < T > GetAsyncEnumerator ( CancellationToken cancellationToken = default )
187+ {
188+ if ( _overridesGetAsyncEnumerator )
189+ {
190+ return _pageable . GetAsyncEnumerator ( cancellationToken ) ;
191+ }
192+
193+ return base . GetAsyncEnumerator ( cancellationToken ) ;
194+ }
195+
196+ public override async IAsyncEnumerable < Page < T > > AsPages ( string continuationToken = null , int ? pageSizeHint = null )
197+ {
198+ await using var enumerator = _pageable . AsPages ( continuationToken , pageSizeHint ) . GetAsyncEnumerator ( ) ;
199+
200+ while ( await ValidateDiagnosticScope ( async ( ) =>
201+ {
202+ bool movedNext = await enumerator . MoveNextAsync ( ) ;
203+ // Don't expect the MoveNextAsync call that returns false to create scope
204+ return ( movedNext , ! movedNext ) ;
205+ } , _expectedName , $ "AsPages() implementation returned from { _methodName } ", _strict ) )
206+ {
207+ yield return enumerator . Current ;
208+ }
103209 }
104210 }
105211 }
0 commit comments