
Commit 7de093f

[azopenai] Errors weren't propagating properly in image generation for OpenAI (Azure#21125)
Hand-written code has to check responses and return ResponseErrors itself. This adds that check for image generation, and adds tests for all of the areas that use hand-written code (ChatCompletions and Completions streaming, and DALL-E integration with OpenAI). Fixes Azure#21120
1 parent: c82eb8a, commit: 7de093f
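
The caller-visible effect of the fix: a failed image-generation call against OpenAI now comes back as an *azcore.ResponseError (the same way the generated operations already report failures), rather than the error body being pushed through the success-path JSON unmarshalling. A minimal caller-side sketch, assuming an OpenAI-style client; the endpoint, key, and prompt values are placeholders, not taken from this commit:

package main

import (
    "context"
    "errors"
    "fmt"
    "log"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
)

func main() {
    // Placeholder credential and endpoint; real values come from your own account.
    cred, err := azopenai.NewKeyCredential("<api-key>")
    if err != nil {
        log.Fatal(err)
    }

    client, err := azopenai.NewClientForOpenAI("https://api.openai.com/v1", cred, nil)
    if err != nil {
        log.Fatal(err)
    }

    _, err = client.CreateImage(context.TODO(), azopenai.ImageGenerationOptions{
        Prompt: to.Ptr("a cat"),
        Size:   to.Ptr(azopenai.ImageSize256x256),
    }, nil)

    // With this fix, HTTP failures from the image endpoint surface as *azcore.ResponseError.
    var respErr *azcore.ResponseError
    if errors.As(err, &respErr) {
        fmt.Println("status:", respErr.StatusCode, "error code:", respErr.ErrorCode)
    }
}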

File tree

6 files changed: +135, -17 lines


sdk/cognitiveservices/azopenai/assets.json

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "go",
   "TagPrefix": "go/cognitiveservices/azopenai",
-  "Tag": "go/cognitiveservices/azopenai_25f5951837"
+  "Tag": "go/cognitiveservices/azopenai_2b6f93a94d"
 }

sdk/cognitiveservices/azopenai/client_chat_completions_test.go

Lines changed: 24 additions & 1 deletion
@@ -31,7 +31,7 @@ var chatCompletionsRequest = azopenai.ChatCompletionsOptions{
     },
     MaxTokens:   to.Ptr(int32(1024)),
     Temperature: to.Ptr(float32(0.0)),
-    Model:       &openAIChatCompletionsModelDeployment,
+    Model:       &openAIChatCompletionsModel,
 }

 var expectedContent = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10."
@@ -192,3 +192,26 @@ func TestClient_GetChatCompletions_InvalidModel(t *testing.T) {
     require.ErrorAs(t, err, &respErr)
     require.Equal(t, "DeploymentNotFound", respErr.ErrorCode)
 }
+
+func TestClient_GetChatCompletionsStream_Error(t *testing.T) {
+    if recording.GetRecordMode() == recording.PlaybackMode {
+        t.Skip()
+    }
+
+    doTest := func(t *testing.T, client *azopenai.Client) {
+        t.Helper()
+        streamResp, err := client.GetChatCompletionsStream(context.Background(), chatCompletionsRequest, nil)
+        require.Empty(t, streamResp)
+        assertResponseIsError(t, err)
+    }
+
+    t.Run("AzureOpenAI", func(t *testing.T) {
+        client := newBogusAzureOpenAIClient(t, chatCompletionsModelDeployment)
+        doTest(t, client)
+    })
+
+    t.Run("OpenAI", func(t *testing.T) {
+        client := newBogusOpenAIClient(t)
+        doTest(t, client)
+    })
+}

sdk/cognitiveservices/azopenai/client_shared_test.go

Lines changed: 44 additions & 8 deletions
@@ -12,6 +12,7 @@ import (
     "strings"
     "testing"

+    "github.com/Azure/azure-sdk-for-go/sdk/azcore"
     "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
     "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
     "github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
@@ -25,10 +26,10 @@ var (
     completionsModelDeployment     string // env: AOAI_COMPLETIONS_MODEL_DEPLOYMENT
     chatCompletionsModelDeployment string // env: AOAI_CHAT_COMPLETIONS_MODEL_DEPLOYMENT

-    openAIKey                            string // env: OPENAI_API_KEY
-    openAIEndpoint                       string // env: OPENAI_ENDPOINT
-    openAICompletionsModelDeployment     string // env: OPENAI_CHAT_COMPLETIONS_MODEL
-    openAIChatCompletionsModelDeployment string // env: OPENAI_COMPLETIONS_MODEL
+    openAIKey                  string // env: OPENAI_API_KEY
+    openAIEndpoint             string // env: OPENAI_ENDPOINT
+    openAICompletionsModel     string // env: OPENAI_CHAT_COMPLETIONS_MODEL
+    openAIChatCompletionsModel string // env: OPENAI_COMPLETIONS_MODEL
 )

 const fakeEndpoint = "https://recordedhost/"
@@ -42,10 +43,10 @@ func init() {
         openAIEndpoint = fakeEndpoint

         completionsModelDeployment = "text-davinci-003"
-        openAICompletionsModelDeployment = "text-davinci-003"
+        openAICompletionsModel = "text-davinci-003"

         chatCompletionsModelDeployment = "gpt-4"
-        openAIChatCompletionsModelDeployment = "gpt-4"
+        openAIChatCompletionsModel = "gpt-4"
     } else {
         if err := godotenv.Load(); err != nil {
             fmt.Printf("Failed to load .env file: %s\n", err)
@@ -67,8 +68,8 @@ func init() {

         openAIKey = os.Getenv("OPENAI_API_KEY")
         openAIEndpoint = os.Getenv("OPENAI_ENDPOINT")
-        openAICompletionsModelDeployment = os.Getenv("OPENAI_COMPLETIONS_MODEL")
-        openAIChatCompletionsModelDeployment = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")
+        openAICompletionsModel = os.Getenv("OPENAI_COMPLETIONS_MODEL")
+        openAIChatCompletionsModel = os.Getenv("OPENAI_CHAT_COMPLETIONS_MODEL")

         if openAIEndpoint != "" && !strings.HasSuffix(openAIEndpoint, "/") {
             // (this just makes recording replacement easier)
@@ -88,6 +89,9 @@ func newRecordingTransporter(t *testing.T) policy.Transporter {
     err = recording.AddHeaderRegexSanitizer("Api-Key", fakeAPIKey, "", nil)
     require.NoError(t, err)

+    err = recording.AddHeaderRegexSanitizer("User-Agent", "fake-user-agent", ".*", nil)
+    require.NoError(t, err)
+
     // "RequestUri": "https://openai-shared.openai.azure.com/openai/deployments/text-davinci-003/completions?api-version=2023-03-15-preview",
     err = recording.AddURISanitizer(fakeEndpoint, regexp.QuoteMeta(endpoint), nil)
     require.NoError(t, err)
@@ -138,3 +142,35 @@ func newClientOptionsForTest(t *testing.T) *azopenai.ClientOptions {

     return co
 }
+
+// newBogusAzureOpenAIClient creates a client that uses an invalid key, which will cause Azure OpenAI to return
+// a failure.
+func newBogusAzureOpenAIClient(t *testing.T, modelDeploymentID string) *azopenai.Client {
+    cred, err := azopenai.NewKeyCredential("bogus-api-key")
+    require.NoError(t, err)
+
+    client, err := azopenai.NewClientWithKeyCredential(endpoint, cred, modelDeploymentID, newClientOptionsForTest(t))
+    require.NoError(t, err)
+    return client
+}
+
+// newBogusOpenAIClient creates a client that uses an invalid key, which will cause OpenAI to return
+// a failure.
+func newBogusOpenAIClient(t *testing.T) *azopenai.Client {
+    cred, err := azopenai.NewKeyCredential("bogus-api-key")
+    require.NoError(t, err)
+
+    client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
+    require.NoError(t, err)
+    return client
+}
+
+func assertResponseIsError(t *testing.T, err error) {
+    t.Helper()
+
+    var respErr *azcore.ResponseError
+    require.ErrorAs(t, err, &respErr)
+
+    // we sometimes get rate limited but (for this kind of test) it's actually okay
+    require.Truef(t, respErr.StatusCode == http.StatusUnauthorized || respErr.StatusCode == http.StatusTooManyRequests, "An acceptable error comes back (actual: %d)", respErr.StatusCode)
+}

sdk/cognitiveservices/azopenai/custom_client_image.go

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@ func generateImageWithOpenAI(ctx context.Context, client *Client, body ImageGene
         return CreateImageResponse{}, err
     }

+    if !runtime.HasStatusCode(resp, http.StatusOK) {
+        return CreateImageResponse{}, runtime.NewResponseError(resp)
+    }
+
     var gens *ImageGenerations

     if err := runtime.UnmarshalAsJSON(resp, &gens); err != nil {
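
The four added lines mirror the status-code guard that the generated operations apply before unmarshalling. As a rough sketch of the same pattern for any hand-written request path (the helper name and pipeline parameter here are hypothetical, not part of this commit):

package example

import (
    "net/http"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
)

// doHandWrittenRequest is a hypothetical helper showing the guard hand-written
// paths need: without it, a 4xx/5xx error body falls through to the success-path
// JSON unmarshalling instead of being returned as an *azcore.ResponseError.
func doHandWrittenRequest(pl runtime.Pipeline, req *policy.Request) (*http.Response, error) {
    resp, err := pl.Do(req)
    if err != nil {
        return nil, err
    }

    if !runtime.HasStatusCode(resp, http.StatusOK) {
        // Wrap non-200 responses so callers can errors.As into *azcore.ResponseError.
        return nil, runtime.NewResponseError(resp)
    }

    return resp, nil
}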

sdk/cognitiveservices/azopenai/custom_client_image_test.go

Lines changed: 32 additions & 0 deletions
@@ -40,6 +40,24 @@ func TestImageGeneration_OpenAI(t *testing.T) {
     testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatURL)
 }

+func TestImageGeneration_AzureOpenAI_WithError(t *testing.T) {
+    if recording.GetRecordMode() == recording.PlaybackMode {
+        t.Skip()
+    }
+
+    client := newBogusAzureOpenAIClient(t, "")
+    testImageGenerationFailure(t, client)
+}
+
+func TestImageGeneration_OpenAI_WithError(t *testing.T) {
+    if recording.GetRecordMode() == recording.PlaybackMode {
+        t.Skip()
+    }
+
+    client := newBogusOpenAIClient(t)
+    testImageGenerationFailure(t, client)
+}
+
 func TestImageGeneration_OpenAI_Base64(t *testing.T) {
     client := newOpenAIClientForTest(t)
     testImageGeneration(t, client, azopenai.ImageGenerationResponseFormatB64JSON)
@@ -76,3 +94,17 @@ func testImageGeneration(t *testing.T, client *azopenai.Client, responseFormat a
         }
     }
 }
+
+func testImageGenerationFailure(t *testing.T, bogusClient *azopenai.Client) {
+    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+    defer cancel()
+
+    resp, err := bogusClient.CreateImage(ctx, azopenai.ImageGenerationOptions{
+        Prompt:         to.Ptr("a cat"),
+        Size:           to.Ptr(azopenai.ImageSize256x256),
+        ResponseFormat: to.Ptr(azopenai.ImageGenerationResponseFormatURL),
+    }, nil)
+    require.Empty(t, resp)
+
+    assertResponseIsError(t, err)
+}

sdk/cognitiveservices/azopenai/custom_client_test.go

Lines changed: 30 additions & 7 deletions
@@ -16,6 +16,7 @@ import (
     "github.com/Azure/azure-sdk-for-go/sdk/azcore"
     "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
     "github.com/Azure/azure-sdk-for-go/sdk/cognitiveservices/azopenai"
+    "github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
     "github.com/stretchr/testify/require"
 )

@@ -88,12 +89,7 @@ func TestGetCompletionsStream_AzureOpenAI(t *testing.T) {
 }

 func TestGetCompletionsStream_OpenAI(t *testing.T) {
-    cred, err := azopenai.NewKeyCredential(openAIKey)
-    require.NoError(t, err)
-
-    client, err := azopenai.NewClientForOpenAI(openAIEndpoint, cred, newClientOptionsForTest(t))
-    require.NoError(t, err)
-
+    client := newOpenAIClientForTest(t)
     testGetCompletionsStream(t, client, false)
 }

@@ -102,7 +98,7 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
     Prompt:      []string{"What is Azure OpenAI?"},
     MaxTokens:   to.Ptr(int32(2048)),
     Temperature: to.Ptr(float32(0.0)),
-    Model:       to.Ptr(openAICompletionsModelDeployment),
+    Model:       to.Ptr(openAICompletionsModel),
 }

 response, err := client.GetCompletionsStream(context.TODO(), body, nil)
@@ -142,3 +138,30 @@ func testGetCompletionsStream(t *testing.T, client *azopenai.Client, isAzure boo
     require.Equal(t, want, got)
     require.Equal(t, 86, eventCount)
 }
+
+func TestClient_GetCompletions_Error(t *testing.T) {
+    if recording.GetRecordMode() == recording.PlaybackMode {
+        t.Skip()
+    }
+
+    doTest := func(t *testing.T, client *azopenai.Client) {
+        streamResp, err := client.GetCompletionsStream(context.Background(), azopenai.CompletionsOptions{
+            Prompt:      []string{"What is Azure OpenAI?"},
+            MaxTokens:   to.Ptr(int32(2048 - 127)),
+            Temperature: to.Ptr(float32(0.0)),
+            Model:       &openAICompletionsModel,
+        }, nil)
+        require.Empty(t, streamResp)
+        assertResponseIsError(t, err)
+    }
+
+    t.Run("AzureOpenAI", func(t *testing.T) {
+        client := newBogusAzureOpenAIClient(t, completionsModelDeployment)
+        doTest(t, client)
+    })
+
+    t.Run("OpenAI", func(t *testing.T) {
+        client := newBogusOpenAIClient(t)
+        doTest(t, client)
+    })
+}
