Skip to content

Commit 837fef3

Browse files
authored
Azure OpenAI adjusting unit tests (Azure#35200)
1 parent 239672d commit 837fef3

File tree

49 files changed

+2240
-190
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2240
-190
lines changed

sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAIAsyncClientTest.java

Lines changed: 102 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
import com.azure.ai.openai.models.ChatCompletions;
77
import com.azure.ai.openai.models.ChatCompletionsOptions;
8-
import com.azure.ai.openai.models.ChatRole;
98
import com.azure.ai.openai.models.Completions;
109
import com.azure.ai.openai.models.CompletionsOptions;
10+
import com.azure.ai.openai.models.CompletionsUsage;
1111
import com.azure.ai.openai.models.Embeddings;
12+
import com.azure.core.exception.ClientAuthenticationException;
13+
import com.azure.core.exception.HttpResponseException;
1214
import com.azure.core.http.HttpClient;
1315
import com.azure.core.http.rest.RequestOptions;
1416
import com.azure.core.util.BinaryData;
@@ -17,10 +19,13 @@
1719
import org.junit.jupiter.params.provider.MethodSource;
1820
import reactor.test.StepVerifier;
1921

22+
import java.util.ArrayList;
23+
2024
import static com.azure.ai.openai.TestUtils.DISPLAY_NAME_WITH_ARGUMENTS;
2125
import static org.junit.jupiter.api.Assertions.assertEquals;
22-
import static org.junit.jupiter.api.Assertions.assertFalse;
26+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
2327
import static org.junit.jupiter.api.Assertions.assertNotNull;
28+
import static org.junit.jupiter.api.Assertions.assertTrue;
2429

2530
public class NonAzureOpenAIAsyncClientTest extends OpenAIClientTestBase {
2631
private OpenAIAsyncClient client;
@@ -33,12 +38,12 @@ private OpenAIAsyncClient getNonAzureOpenAIAsyncClient(HttpClient httpClient) {
3338

3439
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
3540
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
36-
public void getCompletions(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
41+
public void testGetCompletions(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
3742
client = getNonAzureOpenAIAsyncClient(httpClient);
3843
getCompletionsRunner((modelId, prompt) -> {
3944
StepVerifier.create(client.getCompletions(modelId, new CompletionsOptions(prompt)))
4045
.assertNext(resultCompletions -> {
41-
assertCompletions(new int[]{0}, null, null, resultCompletions);
46+
assertCompletions(1, resultCompletions);
4247
})
4348
.verifyComplete();
4449
});
@@ -49,57 +54,121 @@ public void getCompletions(HttpClient httpClient, OpenAIServiceVersion serviceVe
4954
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
5055
public void testGetCompletionsStream(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
5156
client = getNonAzureOpenAIAsyncClient(httpClient);
52-
getCompletionsRunner((modelId, prompt) -> {
53-
StepVerifier.create(client.getCompletionsStream(modelId, new CompletionsOptions(prompt)).last())
54-
.assertNext(completions -> {
55-
assertNotNull(completions.getId());
56-
assertNotNull(completions.getChoices());
57-
assertFalse(completions.getChoices().isEmpty());
58-
assertNotNull(completions.getChoices().get(0).getText());
57+
getCompletionsRunner((deploymentId, prompt) -> {
58+
StepVerifier.create(client.getCompletionsStream(deploymentId, new CompletionsOptions(prompt)))
59+
.recordWith(ArrayList::new)
60+
.thenConsumeWhile(chatCompletions -> {
61+
assertCompletionsStream(chatCompletions);
62+
return true;
5963
})
64+
.consumeRecordedWith(messageList -> assertTrue(messageList.size() > 1))
6065
.verifyComplete();
6166
});
6267
}
6368

6469
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
6570
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
66-
public void getCompletionsFromPrompt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
71+
public void testGetCompletionsFromPrompt(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
6772
client = getNonAzureOpenAIAsyncClient(httpClient);
6873
getCompletionsFromSinglePromptRunner((modelId, prompt) -> {
6974
StepVerifier.create(client.getCompletions(modelId, prompt))
7075
.assertNext(resultCompletions -> {
71-
assertCompletions(new int[]{0}, null, null, resultCompletions);
76+
assertCompletions(1, resultCompletions);
7277
})
7378
.verifyComplete();
7479
});
7580
}
7681

7782
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
7883
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
79-
public void getCompletionsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
84+
public void testGetCompletionsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
8085
client = getNonAzureOpenAIAsyncClient(httpClient);
8186
getCompletionsRunner((modelId, prompt) -> {
8287
StepVerifier.create(client.getCompletionsWithResponse(modelId,
8388
BinaryData.fromObject(new CompletionsOptions(prompt)),
8489
new RequestOptions()))
8590
.assertNext(response -> {
86-
assertEquals(200, response.getStatusCode());
87-
Completions resultCompletions = response.getValue().toObject(Completions.class);
88-
assertCompletions(new int[]{0}, null, null, resultCompletions);
91+
Completions resultCompletions = assertAndGetValueFromResponse(response, Completions.class, 200);
92+
assertCompletions(1, resultCompletions);
93+
})
94+
.verifyComplete();
95+
});
96+
}
97+
98+
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
99+
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
100+
public void testGetCompletionsBadSecretKey(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
101+
client = getNonAzureOpenAIAsyncClient(httpClient);
102+
getCompletionsRunner((modelId, prompt) -> {
103+
StepVerifier.create(client.getCompletionsWithResponse(modelId,
104+
BinaryData.fromObject(new CompletionsOptions(prompt)),
105+
new RequestOptions()))
106+
.verifyErrorSatisfies(throwable -> {
107+
assertInstanceOf(ClientAuthenticationException.class, throwable);
108+
assertEquals(401, ((ClientAuthenticationException) throwable).getResponse().getStatusCode());
109+
});
110+
});
111+
}
112+
113+
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
114+
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
115+
public void testGetCompletionsExpiredSecretKey(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
116+
client = getNonAzureOpenAIAsyncClient(httpClient);
117+
getCompletionsRunner((modelId, prompt) -> {
118+
StepVerifier.create(client.getCompletionsWithResponse(modelId,
119+
BinaryData.fromObject(new CompletionsOptions(prompt)),
120+
new RequestOptions()))
121+
.verifyErrorSatisfies(throwable -> {
122+
assertInstanceOf(HttpResponseException.class, throwable);
123+
assertEquals(429, ((HttpResponseException) throwable).getResponse().getStatusCode());
124+
});
125+
});
126+
}
127+
128+
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
129+
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
130+
public void testGetCompletionsUsageField(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
131+
client = getNonAzureOpenAIAsyncClient(httpClient);
132+
getCompletionsRunner((modelId, prompt) -> {
133+
CompletionsOptions completionsOptions = new CompletionsOptions(prompt);
134+
completionsOptions.setMaxTokens(1024);
135+
completionsOptions.setN(3);
136+
completionsOptions.setLogprobs(1);
137+
StepVerifier.create(client.getCompletions(modelId, completionsOptions))
138+
.assertNext(resultCompletions -> {
139+
CompletionsUsage usage = resultCompletions.getUsage();
140+
assertCompletions(completionsOptions.getN() * completionsOptions.getPrompt().size(), resultCompletions);
141+
assertNotNull(usage);
142+
assertTrue(usage.getTotalTokens() > 0);
143+
assertEquals(usage.getCompletionTokens() + usage.getPromptTokens(), usage.getTotalTokens());
89144
})
90145
.verifyComplete();
91146
});
92147
}
93148

94149
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
95150
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
96-
public void getChatCompletions(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
151+
public void testGetCompletionsTokenCutoff(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
152+
client = getNonAzureOpenAIAsyncClient(httpClient);
153+
getCompletionsRunner((modelId, prompt) -> {
154+
CompletionsOptions completionsOptions = new CompletionsOptions(prompt);
155+
completionsOptions.setMaxTokens(3);
156+
StepVerifier.create(client.getCompletions(modelId, completionsOptions))
157+
.assertNext(resultCompletions ->
158+
assertCompletions(1, "length", resultCompletions))
159+
.verifyComplete();
160+
});
161+
}
162+
163+
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
164+
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
165+
public void testGetChatCompletions(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
97166
client = getNonAzureOpenAIAsyncClient(httpClient);
98167
getChatCompletionsForNonAzureRunner((modelId, chatMessages) -> {
99168
StepVerifier.create(client.getChatCompletions(modelId, new ChatCompletionsOptions(chatMessages)))
100169
.assertNext(resultChatCompletions -> {
101170
assertNotNull(resultChatCompletions.getUsage());
102-
assertChatCompletions(new int[]{0}, new ChatRole[]{ChatRole.ASSISTANT}, resultChatCompletions);
171+
assertChatCompletions(1, resultChatCompletions);
103172
})
104173
.verifyComplete();
105174
});
@@ -110,39 +179,36 @@ public void getChatCompletions(HttpClient httpClient, OpenAIServiceVersion servi
110179
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
111180
public void testGetChatCompletionsStream(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
112181
client = getNonAzureOpenAIAsyncClient(httpClient);
113-
getChatCompletionsForNonAzureRunner((modelId, chatMessages) -> {
114-
StepVerifier.create(client.getChatCompletionsStream(modelId, new ChatCompletionsOptions(chatMessages)).last())
115-
.assertNext(chatCompletions -> {
116-
assertNotNull(chatCompletions.getId());
117-
assertNotNull(chatCompletions.getChoices());
118-
assertFalse(chatCompletions.getChoices().isEmpty());
119-
assertNotNull(chatCompletions.getChoices().get(0).getDelta());
120-
})
121-
.verifyComplete();
122-
182+
getChatCompletionsRunner((deploymentId, chatMessages) -> {
183+
StepVerifier.create(client.getChatCompletionsStream(deploymentId, new ChatCompletionsOptions(chatMessages)))
184+
.recordWith(ArrayList::new)
185+
.thenConsumeWhile(chatCompletions -> true)
186+
.consumeRecordedWith(messageList -> {
187+
assertTrue(messageList.size() > 1);
188+
messageList.forEach(OpenAIClientTestBase::assertChatCompletionsStream);
189+
}).verifyComplete();
123190
});
124191
}
125192

126193
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
127194
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
128-
public void getChatCompletionsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
195+
public void testGetChatCompletionsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
129196
client = getNonAzureOpenAIAsyncClient(httpClient);
130197
getChatCompletionsForNonAzureRunner((modelId, chatMessages) -> {
131198
StepVerifier.create(client.getChatCompletionsWithResponse(modelId,
132199
BinaryData.fromObject(new ChatCompletionsOptions(chatMessages)),
133200
new RequestOptions()))
134201
.assertNext(response -> {
135-
assertEquals(200, response.getStatusCode());
136-
ChatCompletions resultChatCompletions = response.getValue().toObject(ChatCompletions.class);
137-
assertChatCompletions(new int[]{0}, new ChatRole[]{ChatRole.ASSISTANT}, resultChatCompletions);
202+
ChatCompletions resultChatCompletions = assertAndGetValueFromResponse(response, ChatCompletions.class, 200);
203+
assertChatCompletions(1, resultChatCompletions);
138204
})
139205
.verifyComplete();
140206
});
141207
}
142208

143209
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
144210
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
145-
public void getEmbeddings(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
211+
public void testGetEmbeddings(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
146212
client = getNonAzureOpenAIAsyncClient(httpClient);
147213
getEmbeddingNonAzureRunner((modelId, embeddingsOptions) -> {
148214
StepVerifier.create(client.getEmbeddings(modelId, embeddingsOptions))
@@ -153,15 +219,14 @@ public void getEmbeddings(HttpClient httpClient, OpenAIServiceVersion serviceVer
153219

154220
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
155221
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
156-
public void getEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
222+
public void testGetEmbeddingsWithResponse(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
157223
client = getNonAzureOpenAIAsyncClient(httpClient);
158224
getEmbeddingNonAzureRunner((modelId, embeddingsOptions) -> {
159225
StepVerifier.create(client.getEmbeddingsWithResponse(modelId,
160226
BinaryData.fromObject(embeddingsOptions),
161227
new RequestOptions()))
162228
.assertNext(response -> {
163-
assertEquals(200, response.getStatusCode());
164-
Embeddings resultEmbeddings = response.getValue().toObject(Embeddings.class);
229+
Embeddings resultEmbeddings = assertAndGetValueFromResponse(response, Embeddings.class, 200);
165230
assertEmbeddings(resultEmbeddings);
166231
})
167232
.verifyComplete();

0 commit comments

Comments
 (0)