Skip to content

Commit 4f191fb

Browse files
authored
Add AsyncPageable scope validation (Azure#17635)
Fixes: Azure#17633
1 parent 3338036 commit 4f191fb

File tree

18 files changed

+605
-374
lines changed

18 files changed

+605
-374
lines changed

eng/Packages.Data.props

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@
8484
<PackageReference Update="Microsoft.Azure.Amqp" Version="2.4.9" />
8585
<PackageReference Update="Microsoft.Identity.Client" Version="4.22.0" />
8686
<PackageReference Update="Microsoft.Identity.Client.Extensions.Msal" Version="2.16.5" />
87-
8887
<!-- TODO: Make sure this package is arch-board approved -->
8988
<PackageReference Update="System.IdentityModel.Tokens.Jwt" Version="5.4.0" />
9089
</ItemGroup>

sdk/appconfiguration/Azure.Data.AppConfiguration/src/ConfigurationClient.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ private Request CreateGetRequest(string key, string label, DateTimeOffset accept
705705
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
706706
private async Task<Page<ConfigurationSetting>> GetConfigurationSettingsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
707707
{
708-
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetConfigurationSettingsPage)}");
708+
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetConfigurationSettings)}");
709709
scope.Start();
710710

711711
try
@@ -738,7 +738,7 @@ private async Task<Page<ConfigurationSetting>> GetConfigurationSettingsPageAsync
738738
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
739739
private Page<ConfigurationSetting> GetConfigurationSettingsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
740740
{
741-
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetConfigurationSettingsPage)}");
741+
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetConfigurationSettings)}");
742742
scope.Start();
743743

744744
try
@@ -787,7 +787,7 @@ private Request CreateBatchRequest(SettingSelector selector, string pageLink)
787787
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
788788
private async Task<Page<ConfigurationSetting>> GetRevisionsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
789789
{
790-
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetRevisionsPage)}");
790+
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetRevisions)}");
791791
scope.Start();
792792

793793
try
@@ -820,7 +820,7 @@ private async Task<Page<ConfigurationSetting>> GetRevisionsPageAsync(SettingSele
820820
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
821821
private Page<ConfigurationSetting> GetRevisionsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
822822
{
823-
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetRevisionsPage)}");
823+
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ConfigurationClient)}.{nameof(GetRevisions)}");
824824
scope.Start();
825825

826826
try

sdk/communication/Azure.Communication.Administration/src/PhoneNumberAdministrationClient.cs

Lines changed: 112 additions & 114 deletions
Large diffs are not rendered by default.

sdk/communication/Azure.Communication.Chat/src/ChatThreadClient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ public virtual AsyncPageable<ChatThreadMember> GetMembersAsync(CancellationToken
382382
{
383383
async Task<Page<ChatThreadMember>> FirstPageFunc(int? pageSizeHint)
384384
{
385-
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ChatThreadClient)}.{nameof(GetMessages)}");
385+
using DiagnosticScope scope = _clientDiagnostics.CreateScope($"{nameof(ChatThreadClient)}.{nameof(GetMembers)}");
386386
scope.Start();
387387

388388
try

sdk/core/Azure.Core.TestFramework/src/DiagnosticScopeValidatingInterceptor.cs

Lines changed: 172 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Linq;
67
using System.Reflection;
8+
using System.Threading;
79
using System.Threading.Tasks;
810
using Azure.Core.Tests;
911
using Castle.DynamicProxy;
@@ -12,94 +14,198 @@ namespace Azure.Core.TestFramework
1214
{
1315
public class DiagnosticScopeValidatingInterceptor : IInterceptor
1416
{
17+
private static readonly MethodInfo ValidateDiagnosticScopeMethod = typeof(DiagnosticScopeValidatingInterceptor).GetMethod(nameof(AwaitAndValidateDiagnosticScope), BindingFlags.NonPublic | BindingFlags.Static);
1518
public void Intercept(IInvocation invocation)
1619
{
1720
var methodName = invocation.Method.Name;
18-
if (methodName.EndsWith("Async"))
19-
{
20-
Type declaringType = invocation.Method.DeclaringType;
21-
var ns = declaringType.Namespace;
22-
var expectedName = declaringType.Name + "." + methodName.Substring(0, methodName.Length - 5);
23-
using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure."), asyncLocal: true);
24-
invocation.Proceed();
2521

26-
bool expectFailure = false;
27-
bool skipChecks = false;
22+
Type declaringType = invocation.Method.DeclaringType;
23+
var expectedName = declaringType.Name + "." + methodName.Substring(0, methodName.Length - 5);
24+
bool strict = !invocation.Method.GetCustomAttributes(true).Any(a => a.GetType().FullName == "Azure.Core.ForwardsClientCallsAttribute");
2825

29-
bool strict = !invocation.Method.GetCustomAttributes(true).Any(a => a.GetType().FullName == "Azure.Core.ForwardsClientCallsAttribute");
30-
if (invocation.Method.ReturnType.Name.Contains("Pageable") ||
31-
invocation.Method.ReturnType.Name.Contains("IAsyncEnumerable"))
26+
if (invocation.Method.ReturnType is {IsGenericType: true} genericType &&
27+
genericType.GetGenericTypeDefinition() == typeof(AsyncPageable<>))
28+
{
29+
invocation.Proceed();
30+
invocation.ReturnValue = Activator.CreateInstance(typeof(DiagnosticScopeValidatingAsyncEnumerable<>).MakeGenericType(genericType.GenericTypeArguments[0]), invocation.ReturnValue, expectedName, methodName, strict);
31+
}
32+
else if (methodName.EndsWith("Async") && !invocation.Method.ReturnType.Name.Contains("IAsyncEnumerable"))
33+
{
34+
Type genericArgument = typeof(object);
35+
Type awaitableType = invocation.Method.ReturnType;
36+
if (invocation.Method.ReturnType is {IsGenericType: true, GenericTypeArguments: {Length: 1} genericTypeArguments})
3237
{
33-
return;
38+
genericArgument = genericTypeArguments[0];
39+
awaitableType = invocation.Method.ReturnType.GetGenericTypeDefinition();
3440
}
41+
ValidateDiagnosticScopeMethod.MakeGenericMethod(genericArgument)
42+
.Invoke(null, new object[]{awaitableType, invocation, expectedName, methodName, strict});
43+
}
44+
else
45+
{
46+
invocation.Proceed();
47+
}
48+
}
3549

36-
try
50+
internal static void AwaitAndValidateDiagnosticScope<T>(Type genericType, IInvocation invocation, string expectedName, string methodName, bool strict)
51+
{
52+
// All this ceremony is not to await the returned Task/ValueTask syncronously
53+
// instead we are replacing the invocation.ReturnValue with the ValidateDiagnosticScope task
54+
// but we need to make sure the types match
55+
if (genericType == typeof(Task<>))
56+
{
57+
invocation.ReturnValue = ValidateDiagnosticScope(async () =>
3758
{
38-
object returnValue = invocation.ReturnValue;
39-
if (returnValue is Task t)
40-
{
41-
t.GetAwaiter().GetResult();
42-
}
43-
else
44-
{
45-
// Await ValueTask
46-
Type returnType = returnValue.GetType();
47-
MethodInfo getAwaiterMethod = returnType.GetMethod("GetAwaiter", BindingFlags.Instance | BindingFlags.Public);
48-
MethodInfo getResultMethod = getAwaiterMethod.ReturnType.GetMethod("GetResult", BindingFlags.Instance | BindingFlags.Public);
49-
50-
getResultMethod.Invoke(
51-
getAwaiterMethod.Invoke(returnValue, Array.Empty<object>()),
52-
Array.Empty<object>());
53-
}
54-
}
55-
catch (Exception ex)
59+
invocation.Proceed();
60+
return (await (Task<T>)invocation.ReturnValue, false);
61+
}, expectedName, methodName, strict).AsTask();
62+
}
63+
else if (genericType == typeof(Task))
64+
{
65+
invocation.ReturnValue = ValidateDiagnosticScope<object>(async () =>
66+
{
67+
invocation.Proceed();
68+
await (Task)invocation.ReturnValue;
69+
return default;
70+
}, expectedName, methodName, strict).AsTask();
71+
}
72+
else if (genericType == typeof(ValueTask<>))
73+
{
74+
invocation.ReturnValue = ValidateDiagnosticScope(async () =>
5675
{
57-
expectFailure = true;
76+
invocation.Proceed();
77+
return (await (ValueTask<T>)invocation.ReturnValue, false);
78+
}, expectedName, methodName, strict);
79+
}
80+
else if (genericType == typeof(ValueTask))
81+
{
82+
invocation.ReturnValue = new ValueTask(ValidateDiagnosticScope<object>(async () =>
83+
{
84+
invocation.Proceed();
85+
await (ValueTask)invocation.ReturnValue;;
86+
return default;
87+
}, expectedName, methodName, strict).AsTask());
88+
}
89+
else
90+
{
91+
ValidateDiagnosticScope<object>(() =>
92+
{
93+
invocation.Proceed();
94+
return default;
95+
}, expectedName, methodName, strict).GetAwaiter().GetResult();
96+
}
97+
}
5898

59-
if (ex is ArgumentException)
60-
{
61-
// Don't expect scope for argument validation failures
62-
skipChecks = true;
63-
}
99+
internal static async ValueTask<T> ValidateDiagnosticScope<T>(Func<ValueTask<(T Result, bool SkipChecks)>> action, string expectedName, string methodName, bool strict)
100+
{
101+
bool expectFailure = false;
102+
bool skipChecks = false;
103+
T result;
104+
105+
using ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure."), asyncLocal: true);
106+
try
107+
{
108+
(result, skipChecks) = await action();
109+
}
110+
catch (Exception ex)
111+
{
112+
expectFailure = true;
113+
114+
if (ex is ArgumentException)
115+
{
116+
// Don't expect scope for argument validation failures
117+
skipChecks = true;
64118
}
65-
finally
119+
120+
throw;
121+
}
122+
finally
123+
{
124+
// Remove subscribers before enumerating events.
125+
diagnosticListener.Dispose();
126+
if (!skipChecks)
66127
{
67-
// Remove subscribers before enumerating events.
68-
diagnosticListener.Dispose();
69-
if (!skipChecks)
128+
if (strict)
70129
{
71-
if (strict)
130+
ClientDiagnosticListener.ProducedDiagnosticScope e = diagnosticListener.Scopes.FirstOrDefault(e => e.Name == expectedName);
131+
132+
if (e == default)
133+
{
134+
throw new InvalidOperationException($"Expected diagnostic scope not created {expectedName} {Environment.NewLine}" +
135+
$" created {diagnosticListener.Scopes.Count} scopes [{string.Join(", ", diagnosticListener.Scopes)}] {Environment.NewLine}" +
136+
$" You may have forgotten to add clientDiagnostics.CreateScope(...), set your operationId to {expectedName} in {methodName} or applied the Azure.Core.ForwardsClientCallsAttribute to {methodName}.");
137+
}
138+
139+
if (!e.Activity.Tags.Any(tag => tag.Key == "az.namespace"))
72140
{
73-
ClientDiagnosticListener.ProducedDiagnosticScope e = diagnosticListener.Scopes.FirstOrDefault(e => e.Name == expectedName);
74-
75-
if (e == default)
76-
{
77-
throw new InvalidOperationException($"Expected diagnostic scope not created {expectedName} {Environment.NewLine} created scopes {string.Join(", ", diagnosticListener.Scopes)} {Environment.NewLine} You may have forgotten to set your operationId to {expectedName} in {methodName} or applied the Azure.Core.ForwardsClientCallsAttribute to {methodName}.");
78-
}
79-
80-
if (!e.Activity.Tags.Any(tag => tag.Key == "az.namespace"))
81-
{
82-
throw new InvalidOperationException($"All diagnostic scopes should have 'az.namespace' attribute, make sure the assembly containing **ClientOptions type is marked with the AzureResourceProviderNamespace attribute specifying the appropriate provider. This attribute should be included in AssemblyInfo, and can be included by pulling in AzureResourceProviderNamespaceAttribute.cs using the AzureCoreSharedSources alias.");
83-
}
84-
85-
if (expectFailure && !e.IsFailed)
86-
{
87-
throw new InvalidOperationException($"Expected scope {expectedName} to be marked as failed but it succeeded");
88-
}
141+
throw new InvalidOperationException($"All diagnostic scopes should have 'az.namespace' attribute, make sure the assembly containing **ClientOptions type is marked with the AzureResourceProviderNamespace attribute specifying the appropriate provider. This attribute should be included in AssemblyInfo, and can be included by pulling in AzureResourceProviderNamespaceAttribute.cs using the AzureCoreSharedSources alias.");
89142
}
90-
else
143+
144+
if (expectFailure && !e.IsFailed)
145+
{
146+
throw new InvalidOperationException($"Expected scope {expectedName} to be marked as failed but it succeeded");
147+
}
148+
}
149+
else
150+
{
151+
if (!diagnosticListener.Scopes.Any())
91152
{
92-
if (!diagnosticListener.Scopes.Any())
93-
{
94-
throw new InvalidOperationException($"Expected some diagnostic scopes to be created, found none");
95-
}
153+
throw new InvalidOperationException($"Expected some diagnostic scopes to be created, found none");
96154
}
97155
}
98156
}
99157
}
100-
else
158+
159+
return result;
160+
}
161+
162+
internal class DiagnosticScopeValidatingAsyncEnumerable<T> : AsyncPageable<T>
163+
{
164+
private readonly AsyncPageable<T> _pageable;
165+
private readonly string _expectedName;
166+
private readonly string _methodName;
167+
private readonly bool _strict;
168+
private readonly bool _overridesGetAsyncEnumerator;
169+
170+
public DiagnosticScopeValidatingAsyncEnumerable(AsyncPageable<T> pageable, string expectedName, string methodName, bool strict)
101171
{
102-
invocation.Proceed();
172+
if (pageable == null) throw new ArgumentNullException(nameof(pageable), "Operations returning [Async]Pageable should never return null.");
173+
174+
// If AsyncPageable overrides GetAsyncEnumerator we have to pass the call through to it
175+
// this would effectively disable the validation so avoid doing it as much as possible
176+
var getAsyncEnumeratorMethod = pageable.GetType().GetMethod("GetAsyncEnumerator", BindingFlags.Public | BindingFlags.Instance);
177+
_overridesGetAsyncEnumerator = getAsyncEnumeratorMethod.DeclaringType is {IsGenericType: true} genericType &&
178+
genericType.GetGenericTypeDefinition() != typeof(AsyncPageable<>);
179+
180+
_pageable = pageable;
181+
_expectedName = expectedName;
182+
_methodName = methodName;
183+
_strict = strict;
184+
}
185+
186+
public override IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
187+
{
188+
if (_overridesGetAsyncEnumerator)
189+
{
190+
return _pageable.GetAsyncEnumerator(cancellationToken);
191+
}
192+
193+
return base.GetAsyncEnumerator(cancellationToken);
194+
}
195+
196+
public override async IAsyncEnumerable<Page<T>> AsPages(string continuationToken = null, int? pageSizeHint = null)
197+
{
198+
await using var enumerator = _pageable.AsPages(continuationToken, pageSizeHint).GetAsyncEnumerator();
199+
200+
while (await ValidateDiagnosticScope(async () =>
201+
{
202+
bool movedNext = await enumerator.MoveNextAsync();
203+
// Don't expect the MoveNextAsync call that returns false to create scope
204+
return (movedNext, !movedNext);
205+
}, _expectedName, $"AsPages() implementation returned from {_methodName}", _strict))
206+
{
207+
yield return enumerator.Current;
208+
}
103209
}
104210
}
105211
}

0 commit comments

Comments
 (0)