
Commit 35f2fac

[OpenAI] Deserialize Chat Message Request (Azure#27138)
1 parent 382389b

6 files changed (+126 lines, -14 lines)

sdk/openai/openai/CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 - Returns `usage` information when available.
 - Fixes a bug where errors weren't properly being thrown from the streaming methods.
 - Returns `error` information in `ContentFilterResults` when available.
+- Fixes parsing of `functionCall` in `ChatMessage` objects.

 ## 1.0.0-beta.5 (2023-08-25)
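In practical terms, the new changelog entry means that a `ChatMessage` whose camelCase `functionCall` property is sent back to the service is now serialized to the snake_case `function_call` field the REST API expects. A minimal sketch of the two shapes involved (the weather function is illustrative only):

```ts
// Public SDK shape (camelCase), e.g. an assistant turn returned by a prior call:
const sdkMessage = {
  role: "assistant",
  content: null,
  functionCall: { name: "get_current_weather", arguments: '{"location":"Boston, MA"}' },
};

// Wire shape the request body must carry after this fix (snake_case):
const wireMessage = {
  role: "assistant",
  content: null,
  function_call: { name: "get_current_weather", arguments: '{"location":"Boston, MA"}' },
};
```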

sdk/openai/openai/assets.json
Lines changed: 1 addition & 1 deletion

@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "js",
   "TagPrefix": "js/openai/openai",
-  "Tag": "js/openai/openai_85d9317957"
+  "Tag": "js/openai/openai_d3a9528d70"
 }

sdk/openai/openai/sources/customizations/api/operations.ts
Lines changed: 80 additions & 4 deletions

@@ -7,11 +7,13 @@ import {
   ImageGenerations,
   ImageLocation,
 } from "../../generated/src/models/models.js";
-import { GetCompletionsOptions } from "../../generated/src/models/options.js";
+import {
+  GetChatCompletionsWithAzureExtensionsOptions,
+  GetCompletionsOptions,
+  GetChatCompletionsOptions as GeneratedGetChatCompletionsOptions,
+} from "../../generated/src/models/options.js";
 import {
   _getCompletionsSend,
-  _getChatCompletionsSend,
-  _getChatCompletionsWithAzureExtensionsSend,
   _beginAzureBatchImageGenerationSend,
 } from "../../generated/src/api/operations.js";
 import { getOaiSSEs } from "./oaiSse.js";
@@ -20,14 +22,17 @@ import {
   BeginAzureBatchImageGenerationDefaultResponse,
   BeginAzureBatchImageGenerationLogicalResponse,
   OpenAIContext as Client,
+  GetChatCompletions200Response,
+  GetChatCompletionsDefaultResponse,
   GetChatCompletionsWithAzureExtensions200Response,
   GetChatCompletionsWithAzureExtensionsDefaultResponse,
   ImageGenerationsOutput,
   ImagePayloadOutput,
   getLongRunningPoller,
   isUnexpected,
+  ChatMessage as GeneratedChatMessage,
 } from "../../generated/src/rest/index.js";
-import { StreamableMethod } from "@azure-rest/core-client";
+import { StreamableMethod, operationOptionsToRequestParameters } from "@azure-rest/core-client";
 import { ChatCompletions } from "../models/models.js";
 import { getChatCompletionsResult, getCompletionsResult } from "./deserializers.js";
 import { GetChatCompletionsOptions } from "./models.js";
@@ -287,3 +292,74 @@ export async function getAudioTranscription<Format extends AudioResultFormat>(
     ? body
     : (renameKeysToCamelCase(body) as AudioResult<Format>);
 }
+
+export function _getChatCompletionsWithAzureExtensionsSend(
+  context: Client,
+  messages: ChatMessage[],
+  deploymentId: string,
+  options: GetChatCompletionsWithAzureExtensionsOptions = { requestOptions: {} }
+): StreamableMethod<
+  | GetChatCompletionsWithAzureExtensions200Response
+  | GetChatCompletionsWithAzureExtensionsDefaultResponse
+> {
+  return context
+    .path("/deployments/{deploymentId}/extensions/chat/completions", deploymentId)
+    .post({
+      ...operationOptionsToRequestParameters(options),
+      body: {
+        messages: parseChatMessage(messages),
+        functions: options?.functions,
+        function_call: options?.functionCall,
+        max_tokens: options?.maxTokens,
+        temperature: options?.temperature,
+        top_p: options?.topP,
+        logit_bias: options?.logitBias,
+        user: options?.user,
+        n: options?.n,
+        stop: options?.stop,
+        presence_penalty: options?.presencePenalty,
+        frequency_penalty: options?.frequencyPenalty,
+        stream: options?.stream,
+        model: options?.model,
+        dataSources: options?.dataSources,
+      },
+    });
+}
+
+export function _getChatCompletionsSend(
+  context: Client,
+  messages: ChatMessage[],
+  deploymentId: string,
+  options: GeneratedGetChatCompletionsOptions = { requestOptions: {} }
+): StreamableMethod<GetChatCompletions200Response | GetChatCompletionsDefaultResponse> {
+  return context.path("/deployments/{deploymentId}/chat/completions", deploymentId).post({
+    ...operationOptionsToRequestParameters(options),
+    body: {
+      messages: parseChatMessage(messages),
+      functions: options?.functions,
+      function_call: options?.functionCall,
+      max_tokens: options?.maxTokens,
+      temperature: options?.temperature,
+      top_p: options?.topP,
+      logit_bias: options?.logitBias,
+      user: options?.user,
+      n: options?.n,
+      stop: options?.stop,
+      presence_penalty: options?.presencePenalty,
+      frequency_penalty: options?.frequencyPenalty,
+      stream: options?.stream,
+      model: options?.model,
+      dataSources: options?.dataSources,
+    },
+  });
+}
+
+function parseChatMessage(messages: ChatMessage[]): GeneratedChatMessage[] {
+  return messages.map((p: ChatMessage) => ({
+    role: p.role,
+    content: p.content ?? null,
+    name: p.name,
+    function_call: p.functionCall,
+    context: p.context,
+  }));
+}
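A side note on the `content ?? null` in the new `parseChatMessage` helper: an assistant message that carries only a function call has no `content` on the client model, but the service requires the field to be present, so an absent value is coerced to an explicit `null`. Since the helper is module-private, here is a standalone sketch of the same mapping with a simplified message type (the type itself is an assumption for illustration):

```ts
interface ChatMessageLike {
  role: string;
  content?: string | null;
  name?: string;
  functionCall?: { name: string; arguments: string };
  context?: unknown;
}

// Same mapping as parseChatMessage above: camelCase client fields become
// snake_case wire fields, and missing content becomes an explicit null.
function toWireMessage(m: ChatMessageLike) {
  return {
    role: m.role,
    content: m.content ?? null,
    name: m.name,
    function_call: m.functionCall,
    context: m.context,
  };
}

// An assistant turn that only contains a function call still serializes with content: null.
const wire = toWireMessage({
  role: "assistant",
  functionCall: { name: "get_current_weather", arguments: '{"location":"Boston, MA"}' },
});
```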

sdk/openai/openai/src/api/operations.ts
Lines changed: 13 additions & 2 deletions

@@ -29,6 +29,7 @@ import {
   BeginAzureBatchImageGeneration202Response,
   BeginAzureBatchImageGenerationDefaultResponse,
   BeginAzureBatchImageGenerationLogicalResponse,
+  ChatMessage as GeneratedChatMessage,
   OpenAIContext as Client,
   GetChatCompletions200Response,
   GetChatCompletionsDefaultResponse,
@@ -241,7 +242,7 @@ export function _getChatCompletionsSend(
   return context.path("/deployments/{deploymentId}/chat/completions", deploymentId).post({
     ...operationOptionsToRequestParameters(options),
     body: {
-      messages: messages,
+      messages: parseChatMessage(messages),
       functions: options?.functions,
       function_call: options?.functionCall,
       max_tokens: options?.maxTokens,
@@ -376,7 +377,7 @@ export function _getChatCompletionsWithAzureExtensionsSend(
     .post({
       ...operationOptionsToRequestParameters(options),
       body: {
-        messages: messages,
+        messages: parseChatMessage(messages),
        functions: options?.functions,
        function_call: options?.functionCall,
        max_tokens: options?.maxTokens,
@@ -786,3 +787,13 @@ export async function getAudioTranscription<Format extends AudioResultFormat>(
     ? body
     : (renameKeysToCamelCase(body) as AudioResult<Format>);
 }
+
+function parseChatMessage(messages: ChatMessage[]): GeneratedChatMessage[] {
+  return messages.map((p: ChatMessage) => ({
+    role: p.role,
+    content: p.content ?? null,
+    name: p.name,
+    function_call: p.functionCall,
+    context: p.context,
+  }));
+}

sdk/openai/openai/test/public/node/whisper.spec.ts
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ function getModel(authMethod: AuthMethod): string {
 }

 describe("OpenAI", function () {
-  matrix([["AzureAPIKey", "OpenAIKey"]] as const, async function (authMethod: AuthMethod) {
+  matrix([["AzureAPIKey", "OpenAIKey", "AAD"]] as const, async function (authMethod: AuthMethod) {
     describe(`[${authMethod}] Client`, () => {
       let recorder: Recorder;
       let client: OpenAIClient;
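The whisper test matrix now also covers an `AAD` auth method. For reference, a token-credential client is typically constructed as in the sketch below; the endpoint is a placeholder and `DefaultAzureCredential` is just one possible `TokenCredential` implementation (the tests resolve these from environment configuration):

```ts
import { OpenAIClient } from "@azure/openai";
import { DefaultAzureCredential } from "@azure/identity";

// Placeholder endpoint; the AAD matrix entry exercises the TokenCredential constructor overload.
const endpoint = "https://<your-resource>.openai.azure.com";
const client = new OpenAIClient(endpoint, new DefaultAzureCredential());
```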

sdk/openai/openai/test/public/openai.spec.ts
Lines changed: 30 additions & 6 deletions

@@ -22,7 +22,7 @@ import {
 } from "./utils/utils.js";
 import { createHttpHeaders, createPipelineRequest } from "@azure/core-rest-pipeline";
 import { getImageDimensions } from "./utils/getImageDimensions.js";
-import { OpenAIClient, ImageLocation } from "../../src/index.js";
+import { OpenAIClient, ImageLocation, ChatMessage } from "../../src/index.js";

 describe("OpenAI", function () {
   let recorder: Recorder;
@@ -158,7 +158,6 @@ describe("OpenAI", function () {
         "What's the most common feedback we received from our customers about the product?",
     },
   ];
-  const weatherMessages = [{ role: "user", content: "What's the weather like in Boston?" }];
   const getCurrentWeather = {
     name: "get_current_weather",
     description: "Get the current weather in a given location",
@@ -210,11 +209,33 @@ describe("OpenAI", function () {
           chatCompletionDeployments,
           chatCompletionModels
         ),
-        async (deploymentName) =>
-          client.getChatCompletions(deploymentName, weatherMessages, {
+        async (deploymentName) => {
+          const weatherMessages: ChatMessage[] = [
+            { role: "user", content: "What's the weather like in Boston?" },
+          ];
+          const result = await client.getChatCompletions(deploymentName, weatherMessages, {
             functions: [getCurrentWeather],
-          }),
-        (c) => assertChatCompletions(c, { functions: true })
+          });
+          assertChatCompletions(result, { functions: true });
+          const responseMessage = result.choices[0].message;
+          if (!responseMessage?.functionCall) {
+            assert.fail("Undefined function call");
+          }
+          const functionArgs = JSON.parse(responseMessage.functionCall.arguments);
+          weatherMessages.push(responseMessage);
+          weatherMessages.push({
+            role: "function",
+            name: responseMessage.functionCall.name,
+            content: JSON.stringify({
+              location: functionArgs.location,
+              temperature: "72",
+              unit: functionArgs.unit,
+              forecast: ["sunny", "windy"],
+            }),
+          });
+          return client.getChatCompletions(deploymentName, weatherMessages);
+        },
+        (result) => assertChatCompletions(result, { functions: true })
       ),
       chatCompletionDeployments,
       chatCompletionModels,
@@ -289,6 +310,9 @@ describe("OpenAI", function () {
     });

     it("calls functions", async function () {
+      const weatherMessages: ChatMessage[] = [
+        { role: "user", content: "What's the weather like in Boston?" },
+      ];
       updateWithSucceeded(
         await withDeployments(
           getSucceeded(
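The reworked "calls functions" test above is essentially the standard function-calling round trip: ask a question, get a `functionCall` back, append the assistant turn plus a `function` message with the tool result, and ask again. A condensed, non-test sketch of the same flow, assuming the public `@azure/openai` entry point exports the same `ChatMessage` type the test imports from the local sources; the function's `parameters` schema is an assumption, since the diff only shows its name and description:

```ts
import { OpenAIClient, AzureKeyCredential, ChatMessage } from "@azure/openai";

async function weatherRoundTrip(endpoint: string, apiKey: string, deployment: string) {
  const client = new OpenAIClient(endpoint, new AzureKeyCredential(apiKey));

  // Only name/description appear in the diff; this parameters schema is illustrative.
  const getCurrentWeather = {
    name: "get_current_weather",
    description: "Get the current weather in a given location",
    parameters: {
      type: "object",
      properties: {
        location: { type: "string", description: "The city and state, e.g. Boston, MA" },
        unit: { type: "string", enum: ["celsius", "fahrenheit"] },
      },
      required: ["location"],
    },
  };

  const messages: ChatMessage[] = [
    { role: "user", content: "What's the weather like in Boston?" },
  ];

  // First call: the model is expected to answer with a functionCall rather than text.
  const first = await client.getChatCompletions(deployment, messages, {
    functions: [getCurrentWeather],
  });
  const reply = first.choices[0].message;
  if (!reply?.functionCall) {
    throw new Error("Expected a function call in the response");
  }

  // Echo the assistant turn back (the message whose serialization this commit fixes),
  // then append a "function" message carrying a fake tool result.
  const args = JSON.parse(reply.functionCall.arguments);
  messages.push(reply);
  messages.push({
    role: "function",
    name: reply.functionCall.name,
    content: JSON.stringify({ location: args.location, temperature: "72", unit: args.unit }),
  });

  // Second call: the model turns the tool result into a natural-language answer.
  return client.getChatCompletions(deployment, messages);
}
```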
