Commit a84e331

deyaaeldeen and xirzec authored
[OpenAI] Re-implement SSEs (Azure#26704)
### Packages impacted by this PR

@azure/openai

### Issues associated with this PR

Azure#26376

### Describe the problem that is addressed by this PR

The old implementation of SSEs didn't do the right thing with regard to parsing events spanning multiple stream chunks: it converted every chunk to a string first, which is not valid.

### What are the possible designs available to address the problem? If there are more than one possible design, why was the one in this PR chosen?

This implementation follows closely the spec in https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation and is largely based on [`@microsoft/fetch-event-source`](https://www.npmjs.com/package/@microsoft/fetch-event-source)'s. I think we should consider moving the implementation to core so it can be used to interpret responses with a content type of `text/event-stream`; I can file an issue after merging this PR.

### Are there test cases added in this PR? _(If not, why?)_

Yes!

### Provide a list of related PRs _(if any)_

N/A

### Command used to generate this PR: _(Applicable only to SDK release request PRs)_

### Checklists

- [x] Added impacted package name to the issue description
- [ ] Does this PR need any fixes in the SDK Generator? _(If so, create an Issue in the [Autorest/typescript](https://github.com/Azure/autorest.typescript) repository and link it here)_
- [x] Added a changelog (if necessary)

---------

Co-authored-by: Jeff Fisher <jeffish@microsoft.com>
1 parent adf8bd0 commit a84e331

30 files changed (+1,008 −183 lines)
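To make the fix concrete: the description notes that the old code converted each chunk to a string independently, which breaks both multi-byte UTF-8 characters and events that straddle chunk boundaries. Here is a minimal sketch of the spec's buffering approach, for orientation only; the names `parseSSEStream` and `SSEMessage` are illustrative rather than taken from this PR, and a CR/LF pair split across two chunks is not handled:

```ts
interface SSEMessage {
  event: string;
  data: string;
}

async function* parseSSEStream(
  chunks: AsyncIterable<Uint8Array>
): AsyncIterable<SSEMessage> {
  // `stream: true` below buffers partial multi-byte characters between
  // chunks -- the step a naive per-chunk string conversion gets wrong.
  const decoder = new TextDecoder();
  let buffer = "";
  let current: SSEMessage = { event: "", data: "" };
  for await (const chunk of chunks) {
    buffer += decoder.decode(chunk, { stream: true });
    const lines = buffer.split(/\r\n|\r|\n/);
    buffer = lines.pop() ?? ""; // keep the trailing partial line for the next chunk
    for (const line of lines) {
      if (line === "") {
        // a blank line dispatches the accumulated event
        if (current.data !== "") yield current;
        current = { event: "", data: "" };
      } else if (line.startsWith("data:")) {
        const value = line.slice(5).replace(/^ /, "");
        current.data += current.data === "" ? value : "\n" + value;
      } else if (line.startsWith("event:")) {
        current.event = line.slice(6).replace(/^ /, "");
      }
      // "id:", "retry:", and comment lines omitted for brevity
    }
  }
}
```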

sdk/openai/openai/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 
 ### Bugs Fixed
 
+- Fix a bug where server-sent events were not being parsed correctly.
+
 ### Other Changes
 
 ## 1.0.0-beta.3 (2023-07-13)

sdk/openai/openai/package.json

Lines changed: 3 additions & 5 deletions
@@ -6,7 +6,7 @@
   "main": "dist/index.cjs",
   "module": "dist-esm/src/index.js",
   "browser": {
-    "./dist-esm/src/api/getStream.js": "./dist-esm/src/api/getStream.browser.js"
+    "./dist-esm/src/api/getSSEs.js": "./dist-esm/src/api/getSSEs.browser.js"
   },
   "type": "module",
   "exports": {
@@ -46,7 +46,7 @@
     "extract-api": "tsc -p . && api-extractor run --local",
     "format": "prettier --write --config ../../../.prettierrc.json --ignore-path ../../../.prettierignore \"sources/customizations/**/*.ts\" \"src/**/*.ts\" \"test/**/*.ts\" \"samples-dev/**/*.ts\" \"*.{js,json}\"",
     "integration-test:browser": "npm run unit-test:browser",
-    "integration-test:node": "dev-tool run test:node-js-input -- --timeout 5000000 \"dist-esm/test/**/*.spec.js\"",
+    "integration-test:node": "dev-tool run test:node-js-input -- --timeout 5000000 \"dist-esm/test/public/{,!(browser)/**/}/*.spec.js\"",
     "integration-test": "npm run integration-test:node && npm run integration-test:browser",
     "lint:fix": "eslint README.md package.json api-extractor.json src test --ext .ts,.javascript,.js --fix --fix-type [problem,suggestion]",
     "lint": "eslint README.md package.json api-extractor.json src test --ext .ts,.javascript,.js",
@@ -55,7 +55,7 @@
     "test:node": "npm run clean && tsc -p . && npm run integration-test:node",
     "test": "npm run clean && tsc -p . && npm run unit-test:node && dev-tool run bundle && npm run unit-test:browser && npm run integration-test",
     "unit-test:browser": "dev-tool run test:browser -- karma.conf.cjs",
-    "unit-test:node": "dev-tool run test:node-ts-input -- \"test/internal/unit/{,!(browser)/**/}*.spec.ts\" \"test/public/{,!(browser)/**/}*.spec.ts\"",
+    "unit-test:node": "dev-tool run test:node-ts-input -- \"test/internal/{,!(browser)/**/}*.spec.ts\" \"test/public/{,!(browser)/**/}*.spec.ts\"",
     "unit-test": "npm run unit-test:node && npm run unit-test:browser"
   },
   "files": [
@@ -94,14 +94,12 @@
     "@azure-tools/test-credential": "^1.0.0",
     "@azure/test-utils": "^1.0.0",
     "@microsoft/api-extractor": "^7.31.1",
-    "@types/fs-extra": "^9.0.0",
     "@types/mocha": "^7.0.2",
     "@types/node": "^14.0.0",
     "cross-env": "^7.0.3",
     "dotenv": "^16.0.0",
     "eslint": "^8.16.0",
     "esm": "^3.2.25",
-    "fs-extra": "^10.0.0",
     "karma": "^6.4.0",
     "karma-chrome-launcher": "^3.1.1",
     "karma-coverage": "^2.2.0",

sdk/openai/openai/review/openai.api.md

Lines changed: 2 additions & 2 deletions
@@ -201,8 +201,8 @@ export class OpenAIClient {
   getCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): Promise<Completions>;
   getEmbeddings(deploymentName: string, input: string[], options?: GetEmbeddingsOptions): Promise<Embeddings>;
   getImages(prompt: string, options?: ImageGenerationOptions): Promise<ImageGenerationResponse>;
-  listChatCompletions(deploymentName: string, messages: ChatMessage[], options?: GetChatCompletionsOptions): Promise<AsyncIterable<Omit<ChatCompletions, "usage">>>;
-  listCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): Promise<AsyncIterable<Omit<Completions, "usage">>>;
+  listChatCompletions(deploymentName: string, messages: ChatMessage[], options?: GetChatCompletionsOptions): AsyncIterable<Omit<ChatCompletions, "usage">>;
+  listCompletions(deploymentName: string, prompt: string[], options?: GetCompletionsOptions): AsyncIterable<Omit<Completions, "usage">>;
 }
 
 // @public (undocumented)
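Since `listChatCompletions` and `listCompletions` now return an `AsyncIterable` directly rather than a `Promise` of one, callers drop the extra `await` on the method call. A minimal consumer sketch, assuming placeholder endpoint, key, and deployment values:

```ts
import { OpenAIClient, AzureKeyCredential } from "@azure/openai";

const client = new OpenAIClient(
  "https://<resource>.openai.azure.com/", // placeholder
  new AzureKeyCredential("<api-key>") // placeholder
);

// No `await client.listChatCompletions(...)` -- the method itself
// now returns the async iterable.
for await (const event of client.listChatCompletions("<deployment>", [
  { role: "user", content: "Tell me a joke." },
])) {
  for (const choice of event.choices) {
    // `delta` carries the incremental message content while streaming
    process.stdout.write(choice.delta?.content ?? "");
  }
}
```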

sdk/openai/openai/sources/customizations/OpenAIClient.ts

Lines changed: 25 additions & 25 deletions
@@ -16,7 +16,7 @@ import {
   OpenAIClientOptions,
 } from "../generated/api/index.js";
 import { getChatCompletionsResult, getCompletionsResult } from "./api/operations.js";
-import { getSSEs } from "./api/sse.js";
+import { getOaiSSEs } from "./api/oaiSse.js";
 import { ChatCompletions, Completions, Embeddings } from "../generated/api/models.js";
 import { _getChatCompletionsSend, _getCompletionsSend } from "../generated/api/operations.js";
 import { ImageGenerationOptions } from "./api/operations.js";
@@ -154,90 +154,90 @@ export class OpenAIClient {
 
   /**
    * Returns textual completions as configured for a given prompt.
-   * @param deploymentOrModelName - Specifies either the model deployment name (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+   * @param deploymentName - Specifies either the model deployment name (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
    * @param prompt - The prompt to use for this request.
    * @param options - The options for this completions request.
    * @returns The completions for the given prompt.
    */
   getCompletions(
-    deploymentOrModelName: string,
+    deploymentName: string,
     prompt: string[],
     options: GetCompletionsOptions = { requestOptions: {} }
   ): Promise<Completions> {
-    this.setModel(deploymentOrModelName, options);
-    return getCompletions(this._client, prompt, deploymentOrModelName, options);
+    this.setModel(deploymentName, options);
+    return getCompletions(this._client, prompt, deploymentName, options);
   }
 
   /**
    * Lists the completions tokens as they become available for a given prompt.
-   * @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+   * @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
    * @param prompt - The prompt to use for this request.
    * @param options - The completions options for this completions request.
    * @returns An asynchronous iterable of completions tokens.
    */
   listCompletions(
-    deploymentOrModelName: string,
+    deploymentName: string,
     prompt: string[],
     options: GetCompletionsOptions = {}
-  ): Promise<AsyncIterable<Omit<Completions, "usage">>> {
-    this.setModel(deploymentOrModelName, options);
-    const response = _getCompletionsSend(this._client, prompt, deploymentOrModelName, {
+  ): AsyncIterable<Omit<Completions, "usage">> {
+    this.setModel(deploymentName, options);
+    const response = _getCompletionsSend(this._client, prompt, deploymentName, {
       ...options,
       stream: true,
     });
-    return getSSEs(response, getCompletionsResult);
+    return getOaiSSEs(response, getCompletionsResult);
   }
 
   /**
    * Return the computed embeddings for a given prompt.
-   * @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+   * @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
    * @param input - The prompt to use for this request.
    * @param options - The embeddings options for this embeddings request.
    * @returns The embeddings for the given prompt.
    */
   getEmbeddings(
-    deploymentOrModelName: string,
+    deploymentName: string,
     input: string[],
     options: GetEmbeddingsOptions = { requestOptions: {} }
   ): Promise<Embeddings> {
-    this.setModel(deploymentOrModelName, options);
-    return getEmbeddings(this._client, input, deploymentOrModelName, options);
+    this.setModel(deploymentName, options);
+    return getEmbeddings(this._client, input, deploymentName, options);
   }
 
   /**
    * Get chat completions for provided chat context messages.
-   * @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+   * @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
    * @param messages - The chat context messages to use for this request.
    * @param options - The chat completions options for this completions request.
    * @returns The chat completions for the given chat context messages.
    */
   getChatCompletions(
-    deploymentOrModelName: string,
+    deploymentName: string,
     messages: ChatMessage[],
     options: GetChatCompletionsOptions = { requestOptions: {} }
   ): Promise<ChatCompletions> {
-    this.setModel(deploymentOrModelName, options);
-    return getChatCompletions(this._client, messages, deploymentOrModelName, options);
+    this.setModel(deploymentName, options);
+    return getChatCompletions(this._client, messages, deploymentName, options);
   }
 
   /**
    * Lists the chat completions tokens as they become available for a chat context.
-   * @param deploymentOrModelName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
+   * @param deploymentName - The name of the model deployment (when using Azure OpenAI) or model name (when using non-Azure OpenAI) to use for this request.
    * @param messages - The chat context messages to use for this request.
    * @param options - The chat completions options for this chat completions request.
    * @returns An asynchronous iterable of chat completions tokens.
    */
   listChatCompletions(
-    deploymentOrModelName: string,
+    deploymentName: string,
     messages: ChatMessage[],
     options: GetChatCompletionsOptions = { requestOptions: {} }
-  ): Promise<AsyncIterable<Omit<ChatCompletions, "usage">>> {
-    this.setModel(deploymentOrModelName, options);
-    const response = _getChatCompletionsSend(this._client, messages, deploymentOrModelName, {
+  ): AsyncIterable<Omit<ChatCompletions, "usage">> {
+    this.setModel(deploymentName, options);
+    const response = _getChatCompletionsSend(this._client, messages, deploymentName, {
       ...options,
       stream: true,
     });
-    return getSSEs(response, getChatCompletionsResult);
+    return getOaiSSEs(response, getChatCompletionsResult);
  }
 
   /**
sdk/openai/openai/sources/customizations/api/getSSEs.browser.ts

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+import { StreamableMethod } from "@azure-rest/core-client";
+import { EventMessage, toSSE } from "./sse.js";
+
+async function* toAsyncIterable<T>(stream: ReadableStream<T>): AsyncIterable<T> {
+  const reader = stream.getReader();
+  try {
+    while (true) {
+      const { value, done } = await reader.read();
+      if (done) {
+        return;
+      }
+      yield value;
+    }
+  } finally {
+    reader.releaseLock();
+  }
+}
+
+async function getStream<TResponse>(
+  response: StreamableMethod<TResponse>
+): Promise<AsyncIterable<Uint8Array>> {
+  const stream = (await response.asBrowserStream()).body;
+  if (!stream) throw new Error("No stream found in response. Did you enable the stream option?");
+  return toAsyncIterable(stream);
+}
+
+export async function getSSEs(
+  response: StreamableMethod<unknown>
+): Promise<AsyncIterable<EventMessage>> {
+  const iter = await getStream(response);
+  return toSSE(iter);
+}
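A note on the browser variant above: the manual reader loop in `toAsyncIterable` is presumably there because browser `ReadableStream`s did not reliably expose `Symbol.asyncIterator` at the time, whereas the Node.js body stream in the next file can be consumed as an async iterable directly.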
sdk/openai/openai/sources/customizations/api/getSSEs.ts

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+import { StreamableMethod } from "@azure-rest/core-client";
+import { EventMessage, toSSE } from "./sse.js";
+
+async function getStream<TResponse>(
+  response: StreamableMethod<TResponse>
+): Promise<AsyncIterable<Uint8Array>> {
+  const stream = (await response.asNodeStream()).body;
+  if (!stream) throw new Error("No stream found in response. Did you enable the stream option?");
+  return stream as AsyncIterable<Uint8Array>;
+}
+
+export async function getSSEs(
+  response: StreamableMethod<unknown>
+): Promise<AsyncIterable<EventMessage>> {
+  const chunkIterator = await getStream(response);
+  return toSSE(chunkIterator);
+}
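Both variants hand the byte iterator to `toSSE` from `./sse.js`, which is not part of this diff. Going by the WHATWG fields the PR description cites and the `@microsoft/fetch-event-source` package it credits, `EventMessage` plausibly looks like the following (an assumption, not the file's actual contents):

```ts
// Assumed shape only; the real definition lives in
// sources/customizations/api/sse.ts, which this diff does not show.
export interface EventMessage {
  data: string; // concatenated "data:" lines of one event
  event: string; // "event:" field; empty string for the default event type
  id: string; // "id:" field, tracked as the last event ID
  retry?: number; // "retry:" field, reconnection time in milliseconds
}
```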

sdk/openai/openai/sources/customizations/api/getStream.browser.ts

Lines changed: 0 additions & 22 deletions
This file was deleted.

sdk/openai/openai/sources/customizations/api/getStream.ts

Lines changed: 0 additions & 14 deletions
This file was deleted.
sdk/openai/openai/sources/customizations/api/oaiSse.ts

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+import { StreamableMethod } from "@azure-rest/core-client";
+import { getSSEs } from "./getSSEs.js";
+import { wrapError } from "./util.js";
+
+export async function* getOaiSSEs<TEvent>(
+  response: StreamableMethod<unknown>,
+  toEvent: (obj: Record<string, any>) => TEvent
+): AsyncIterable<TEvent> {
+  const stream = await getSSEs(response);
+  let isDone = false;
+  for await (const event of stream) {
+    if (isDone) {
+      // handle a case where the service sends excess stream
+      // data after the [DONE] event
+      continue;
+    } else if (event.data === "[DONE]") {
+      isDone = true;
+    } else {
+      yield toEvent(
+        wrapError(
+          () => JSON.parse(event.data),
+          "Error parsing an event. See 'cause' for more details"
+        )
+      );
+    }
+  }
+}
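`wrapError` is imported from `./util.js`, which is also outside this diff. A plausible minimal implementation, assuming it just runs a thunk and rethrows with the given message and the original error attached as `cause` (hypothetical, not the actual file):

```ts
// Hypothetical sketch of the wrapError helper used above; the actual
// implementation lives in sources/customizations/api/util.ts (not shown).
export function wrapError<T>(fn: () => T, message: string): T {
  try {
    return fn();
  } catch (cause) {
    throw new Error(message, { cause });
  }
}
```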
