Skip to content

Commit ad49f21

Browse files
Merge sync/async code paths (Azure#36010)
1 parent b6e493b commit ad49f21

File tree

3 files changed

+66
-179
lines changed

3 files changed

+66
-179
lines changed

sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2890,14 +2890,7 @@ internal async Task<Response> StagedDownloadAsync(
28902890
{
28912891
ClientSideDecryptor.BeginContentEncryptionKeyCaching();
28922892
}
2893-
if (async)
2894-
{
2895-
return await downloader.DownloadToAsync(destination, conditions, cancellationToken).ConfigureAwait(false);
2896-
}
2897-
else
2898-
{
2899-
return downloader.DownloadTo(destination, conditions, cancellationToken);
2900-
}
2893+
return await downloader.DownloadToInternal(destination, conditions, async, cancellationToken).ConfigureAwait(false);
29012894
}
29022895
#endregion Parallel Download
29032896

sdk/storage/Azure.Storage.Blobs/src/PartitionedDownloader.cs

Lines changed: 64 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ public PartitionedDownloader(
125125
}
126126
}
127127

128-
public async Task<Response> DownloadToAsync(
128+
public async Task<Response> DownloadToInternal(
129129
Stream destination,
130130
BlobRequestConditions conditions,
131+
bool async,
131132
CancellationToken cancellationToken)
132133
{
133134
// Wrap the download range calls in a Download span for distributed
@@ -142,20 +143,18 @@ public async Task<Response> DownloadToAsync(
142143
// a large blob, we'll get its full size in Content-Range and
143144
// can keep downloading it in segments.
144145
var initialRange = new HttpRange(0, _initialRangeSize);
145-
Task<Response<BlobDownloadStreamingResult>> initialResponseTask =
146-
_client.DownloadStreamingInternal(
146+
Response<BlobDownloadStreamingResult> initialResponse;
147+
148+
try
149+
{
150+
initialResponse = await _client.DownloadStreamingInternal(
147151
initialRange,
148152
conditions,
149153
ValidationOptions,
150154
_progress,
151155
_innerOperationName,
152-
async: true,
153-
cancellationToken).AsTask();
154-
155-
Response<BlobDownloadStreamingResult> initialResponse = null;
156-
try
157-
{
158-
initialResponse = await initialResponseTask.ConfigureAwait(false);
156+
async,
157+
cancellationToken).ConfigureAwait(false);
159158
}
160159
catch (RequestFailedException ex) when (ex.ErrorCode == BlobErrorCode.InvalidRange)
161160
{
@@ -165,9 +164,8 @@ public async Task<Response> DownloadToAsync(
165164
ValidationOptions,
166165
_progress,
167166
_innerOperationName,
168-
async: true,
169-
cancellationToken)
170-
.ConfigureAwait(false);
167+
async,
168+
cancellationToken).ConfigureAwait(false);
171169
}
172170

173171
// If the initial request returned no content (i.e., a 304),
@@ -187,7 +185,7 @@ public async Task<Response> DownloadToAsync(
187185
new ClientSideDecryptor(_client.ClientSideEncryption)).DecryptWholeBlobWriteInternal(
188186
destination,
189187
initialResponse.Value.Details.Metadata,
190-
async: true,
188+
async,
191189
cancellationToken).ConfigureAwait(false);
192190
}
193191
}
@@ -201,11 +199,11 @@ public async Task<Response> DownloadToAsync(
201199
await CopyToInternal(
202200
initialResponse,
203201
destination,
204-
async: true,
202+
async,
205203
cancellationToken)
206204
.ConfigureAwait(false);
207205

208-
await FlushFinalIfNecessaryInternal(destination, async: true, cancellationToken).ConfigureAwait(false);
206+
await FlushFinalIfNecessaryInternal(destination, async, cancellationToken).ConfigureAwait(false);
209207
return initialResponse.GetRawResponse();
210208
}
211209

@@ -215,53 +213,70 @@ await CopyToInternal(
215213
ETag etag = initialResponse.Value.Details.ETag;
216214
BlobRequestConditions conditionsWithEtag = conditions?.WithIfMatch(etag) ?? new BlobRequestConditions { IfMatch = etag };
217215

218-
// Create a queue of tasks that will each download one segment
219-
// of the blob. The queue maintains the order of the segments
220-
// so we can keep appending to the end of the destination
221-
// stream when each segment finishes.
222-
var runningTasks = new Queue<Task<Response<BlobDownloadStreamingResult>>>();
223-
runningTasks.Enqueue(initialResponseTask);
224-
if (_maxWorkerCount <= 1)
216+
#pragma warning disable AZC0110 // DO NOT use await keyword in possibly synchronous scope.
217+
// Rule checker cannot understand this section, but this
218+
// massively reduces code duplication.
219+
Queue<Task<Response<BlobDownloadStreamingResult>>> runningTasks = null;
220+
int effectiveWorkerCount = async ? _maxWorkerCount : 1;
221+
if (effectiveWorkerCount > 1)
225222
{
226-
// consume initial task immediately if _maxWorkerCount is 1 (or less to be safe). Otherwise loop below would have 2 concurrent tasks.
227-
await ConsumeQueuedTask().ConfigureAwait(false);
223+
runningTasks = new();
224+
runningTasks.Enqueue(Task.FromResult(initialResponse));
225+
}
226+
else
227+
{
228+
await CopyToInternal(initialResponse, destination, async, cancellationToken).ConfigureAwait(false);
228229
}
229230

230231
// Fill the queue with tasks to download each of the remaining
231232
// ranges in the blob
232233
foreach (HttpRange httpRange in GetRanges(initialLength, totalLength))
233234
{
234-
// Add the next Task (which will start the download but
235-
// return before it's completed downloading)
236-
runningTasks.Enqueue(_client.DownloadStreamingInternal(
237-
httpRange,
238-
conditionsWithEtag,
239-
ValidationOptions,
240-
_progress,
241-
_innerOperationName,
242-
async: true,
243-
cancellationToken).AsTask());
244-
245-
// If we have fewer tasks than alotted workers, then just
246-
// continue adding tasks until we have _maxWorkerCount
247-
// running in parallel
248-
if (runningTasks.Count < _maxWorkerCount)
235+
ValueTask<Response<BlobDownloadStreamingResult>> responseValueTask = _client
236+
.DownloadStreamingInternal(
237+
httpRange,
238+
conditionsWithEtag,
239+
ValidationOptions,
240+
_progress,
241+
_innerOperationName,
242+
async,
243+
cancellationToken);
244+
if (runningTasks != null)
249245
{
250-
continue;
246+
// Add the next Task (which will start the download but
247+
// return before it's completed downloading)
248+
runningTasks.Enqueue(responseValueTask.AsTask());
249+
250+
// If we have fewer tasks than alotted workers, then just
251+
// continue adding tasks until we have effectiveWorkerCount
252+
// running in parallel
253+
if (runningTasks.Count < effectiveWorkerCount)
254+
{
255+
continue;
256+
}
257+
258+
// Once all the workers are busy, wait for the first
259+
// segment to finish downloading before we create more work
260+
await ConsumeQueuedTask().ConfigureAwait(false);
261+
}
262+
else
263+
{
264+
Response<BlobDownloadStreamingResult> result = await responseValueTask.ConfigureAwait(false);
265+
await CopyToInternal(result, destination, async, cancellationToken).ConfigureAwait(false);
251266
}
252-
253-
// Once all the workers are busy, wait for the first
254-
// segment to finish downloading before we create more work
255-
await ConsumeQueuedTask().ConfigureAwait(false);
256267
}
257268

258269
// Wait for all of the remaining segments to download
259-
while (runningTasks.Count > 0)
270+
if (runningTasks != null)
260271
{
261-
await ConsumeQueuedTask().ConfigureAwait(false);
272+
while (runningTasks.Count > 0)
273+
{
274+
await ConsumeQueuedTask().ConfigureAwait(false);
275+
}
262276
}
277+
#pragma warning restore AZC0110 // DO NOT use await keyword in possibly synchronous scope.
263278

264-
await FlushFinalIfNecessaryInternal(destination, async: true, cancellationToken).ConfigureAwait(false);
279+
await FlushFinalIfNecessaryInternal(destination, async, cancellationToken).ConfigureAwait(false);
265280
return initialResponse.GetRawResponse();
266281

267282
// Wait for the first segment in the queue of tasks to complete
@@ -280,7 +295,7 @@ async Task ConsumeQueuedTask()
280295
await CopyToInternal(
281296
response,
282297
destination,
283-
async: true,
298+
async,
284299
cancellationToken)
285300
.ConfigureAwait(false);
286301
}
@@ -296,120 +311,6 @@ await CopyToInternal(
296311
}
297312
}
298313

299-
public Response DownloadTo(
300-
Stream destination,
301-
BlobRequestConditions conditions,
302-
CancellationToken cancellationToken)
303-
{
304-
// Wrap the download range calls in a Download span for distributed
305-
// tracing
306-
DiagnosticScope scope = _client.ClientConfiguration.ClientDiagnostics.CreateScope(_operationName);
307-
try
308-
{
309-
scope.Start();
310-
311-
// Just start downloading using an initial range. If it's a
312-
// small blob, we'll get the whole thing in one shot. If it's
313-
// a large blob, we'll get its full size in Content-Range and
314-
// can keep downloading it in segments.
315-
var initialRange = new HttpRange(0, _initialRangeSize);
316-
Response<BlobDownloadStreamingResult> initialResponse;
317-
318-
try
319-
{
320-
initialResponse = _client.DownloadStreamingInternal(
321-
initialRange,
322-
conditions,
323-
ValidationOptions,
324-
_progress,
325-
_innerOperationName,
326-
async: false,
327-
cancellationToken).EnsureCompleted();
328-
}
329-
catch (RequestFailedException ex) when (ex.ErrorCode == BlobErrorCode.InvalidRange)
330-
{
331-
initialResponse = _client.DownloadStreamingInternal(
332-
range: default,
333-
conditions,
334-
ValidationOptions,
335-
_progress,
336-
_innerOperationName,
337-
async: false,
338-
cancellationToken).EnsureCompleted();
339-
}
340-
341-
// If the initial request returned no content (i.e., a 304),
342-
// we'll pass that back to the user immediately
343-
if (initialResponse.IsUnavailable())
344-
{
345-
return initialResponse.GetRawResponse();
346-
}
347-
348-
// We deferred client-side encryption, so now we must handle it before anything
349-
// is written to destination
350-
if (_client.UsingClientSideEncryption)
351-
{
352-
if (initialResponse.Value.Details.Metadata.TryGetValue(Constants.ClientSideEncryption.EncryptionDataKey, out string rawEncryptiondata))
353-
{
354-
destination = new BlobClientSideDecryptor(
355-
new ClientSideDecryptor(_client.ClientSideEncryption)).DecryptWholeBlobWriteInternal(
356-
destination,
357-
initialResponse.Value.Details.Metadata,
358-
async: false,
359-
cancellationToken).EnsureCompleted();
360-
}
361-
}
362-
363-
// Copy the first segment to the destination stream
364-
CopyToInternal(initialResponse, destination, async: false, cancellationToken).EnsureCompleted();
365-
366-
// If the first segment was the entire blob, we're finished now
367-
long initialLength = initialResponse.Value.Details.ContentLength;
368-
long totalLength = ParseRangeTotalLength(initialResponse.Value.Details.ContentRange);
369-
if (initialLength == totalLength)
370-
{
371-
FlushFinalIfNecessaryInternal(destination, async: false, cancellationToken).EnsureCompleted();
372-
return initialResponse.GetRawResponse();
373-
}
374-
375-
// Capture the etag from the first segment and construct
376-
// conditions to ensure the blob doesn't change while we're
377-
// downloading the remaining segments
378-
ETag etag = initialResponse.Value.Details.ETag;
379-
BlobRequestConditions conditionsWithEtag = conditions?.WithIfMatch(etag) ?? new BlobRequestConditions { IfMatch = etag };
380-
381-
// Download each of the remaining ranges in the blob
382-
foreach (HttpRange httpRange in GetRanges(initialLength, totalLength))
383-
{
384-
// Don't need to worry about 304s here because the ETag
385-
// condition will turn into a 412 and throw a proper
386-
// RequestFailedException
387-
Response<BlobDownloadStreamingResult> result = _client.DownloadStreamingInternal(
388-
httpRange,
389-
conditionsWithEtag,
390-
ValidationOptions,
391-
_progress,
392-
_innerOperationName,
393-
async: false,
394-
cancellationToken).EnsureCompleted();
395-
CopyToInternal(result, destination, async: false, cancellationToken).EnsureCompleted();
396-
}
397-
398-
FlushFinalIfNecessaryInternal(destination, async: false, cancellationToken).EnsureCompleted();
399-
400-
return initialResponse.GetRawResponse();
401-
}
402-
catch (Exception ex)
403-
{
404-
scope.Failed(ex);
405-
throw;
406-
}
407-
finally
408-
{
409-
scope.Dispose();
410-
}
411-
}
412-
413314
private static long ParseRangeTotalLength(string range)
414315
{
415316
if (range == null)

sdk/storage/Azure.Storage.Blobs/tests/PartitionedDownloaderTests.cs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,7 @@ private void SetupDownload(Mock<BlobBaseClient> blockClient, MockDataSource data
229229

230230
private async Task<Response> InvokeDownloadToAsync(PartitionedDownloader downloader, Stream stream)
231231
{
232-
if (_async)
233-
{
234-
return await downloader.DownloadToAsync(stream, s_conditions, s_cancellationToken);
235-
}
236-
else
237-
{
238-
return downloader.DownloadTo(stream, s_conditions, s_cancellationToken);
239-
}
232+
return await downloader.DownloadToInternal(stream, s_conditions, _async, s_cancellationToken);
240233
}
241234

242235
private class MockDataSource

0 commit comments

Comments
 (0)