|
6 | 6 | import com.azure.core.http.HttpClient; |
7 | 7 | import com.azure.core.http.HttpHeaders; |
8 | 8 | import com.azure.core.http.HttpMethod; |
| 9 | +import com.azure.core.http.HttpPipeline; |
9 | 10 | import com.azure.core.http.HttpPipelineBuilder; |
10 | 11 | import com.azure.core.http.HttpRequest; |
11 | 12 | import com.azure.core.http.MockHttpResponse; |
| 13 | +import com.azure.core.http.policy.RetryPolicy; |
12 | 14 | import com.azure.core.http.rest.Response; |
13 | 15 | import com.azure.core.http.rest.SimpleResponse; |
14 | 16 | import com.azure.core.implementation.serializer.DefaultJsonSerializer; |
|
19 | 21 | import org.junit.jupiter.api.BeforeEach; |
20 | 22 | import org.junit.jupiter.api.Test; |
21 | 23 | import org.junit.jupiter.params.ParameterizedTest; |
| 24 | +import org.junit.jupiter.params.provider.MethodSource; |
22 | 25 | import org.junit.jupiter.params.provider.ValueSource; |
23 | 26 | import org.mockito.ArgumentCaptor; |
24 | 27 | import org.mockito.Mock; |
|
28 | 31 | import reactor.test.StepVerifier; |
29 | 32 |
|
30 | 33 | import java.time.Duration; |
| 34 | +import java.util.concurrent.atomic.AtomicInteger; |
31 | 35 | import java.util.function.Supplier; |
| 36 | +import java.util.stream.Stream; |
32 | 37 |
|
33 | 38 | import static org.junit.jupiter.api.Assertions.assertEquals; |
34 | 39 | import static org.mockito.ArgumentMatchers.any; |
@@ -492,6 +497,64 @@ public void pollingStrategyPassContextToHttpClient() { |
492 | 497 | assertEquals(3, activationCallCount[0]); |
493 | 498 | } |
494 | 499 |
|
| 500 | + @ParameterizedTest |
| 501 | + @MethodSource("statusCodeProvider") |
| 502 | + public void retryPollingOperationWithPostActivationOperation(int[] args) { |
| 503 | + int[] activationCallCount = new int[1]; |
| 504 | + activationCallCount[0] = 0; |
| 505 | + String mockPollUrl = "http://localhost/poll"; |
| 506 | + String finalResultUrl = "http://localhost/final"; |
| 507 | + when(activationOperation.get()).thenReturn(Mono.defer(() -> { |
| 508 | + activationCallCount[0]++; |
| 509 | + SimpleResponse<PollResult> response = new SimpleResponse<>( |
| 510 | + new HttpRequest(HttpMethod.POST, "http://localhost"), |
| 511 | + 200, |
| 512 | + new HttpHeaders().set("Operation-Location", mockPollUrl).set("Location", finalResultUrl), |
| 513 | + new PollResult("InProgress")); |
| 514 | + return Mono.just(response); |
| 515 | + })); |
| 516 | + |
| 517 | + HttpRequest pollRequest = new HttpRequest(HttpMethod.GET, mockPollUrl); |
| 518 | + AtomicInteger attemptCount = new AtomicInteger(); |
| 519 | + HttpPipeline pipeline = new HttpPipelineBuilder() |
| 520 | + .policies(new RetryPolicy()) |
| 521 | + .httpClient(request -> { |
| 522 | + int count = attemptCount.getAndIncrement(); |
| 523 | + if (mockPollUrl.equals(request.getUrl().toString()) && count == 0) { |
| 524 | + return Mono.just(new MockHttpResponse(pollRequest, args[0], |
| 525 | + new HttpHeaders().set("Location", finalResultUrl), |
| 526 | + new PollResult("Retry"))); |
| 527 | + } else if (mockPollUrl.equals(request.getUrl().toString()) && count == 1) { |
| 528 | + return Mono.just(new MockHttpResponse(pollRequest, args[1], |
| 529 | + new HttpHeaders().set("Location", finalResultUrl), |
| 530 | + new PollResult("Succeeded"))); |
| 531 | + } else if (finalResultUrl.equals(request.getUrl().toString())) { |
| 532 | + return Mono.just(new MockHttpResponse(pollRequest, args[2], new HttpHeaders(), |
| 533 | + new PollResult("final-state"))); |
| 534 | + } else { |
| 535 | + return Mono.error(new IllegalArgumentException("Unknown request URL " + request.getUrl())); |
| 536 | + } |
| 537 | + }) |
| 538 | + .build(); |
| 539 | + PollerFlux<PollResult, PollResult> pollerFlux = PollerFlux.create( |
| 540 | + Duration.ofSeconds(1), |
| 541 | + () -> activationOperation.get(), |
| 542 | + new OperationResourcePollingStrategy<>(pipeline), |
| 543 | + new TypeReference<PollResult>() { }, new TypeReference<PollResult>() { }); |
| 544 | + |
| 545 | + StepVerifier.create(pollerFlux.takeUntil(apr -> apr.getStatus().isComplete()).last().flatMap(AsyncPollResponse::getFinalResult)) |
| 546 | + .expectNextMatches(pollResult -> "final-state".equals(pollResult.getStatus())) |
| 547 | + .verifyComplete(); |
| 548 | + assertEquals(args[3], attemptCount.get()); |
| 549 | + assertEquals(1, activationCallCount[0]); |
| 550 | + } |
| 551 | + |
| 552 | + static Stream<int[]> statusCodeProvider() { |
| 553 | + return Stream.of( |
| 554 | + new int[]{500, 200, 200, 3}, |
| 555 | + new int[]{200, 500, 200, 2}); |
| 556 | + } |
| 557 | + |
495 | 558 | public static class PollResult { |
496 | 559 | private String status; |
497 | 560 | private String resourceLocation; |
|
0 commit comments