diff --git a/client.go b/client.go index c57ba17c7..e9e061902 100644 --- a/client.go +++ b/client.go @@ -156,6 +156,26 @@ func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err er return } +func sendRequestStreamV2(client *Client, req *http.Request) (stream *StreamerV2, err error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) + if err != nil { + return + } + + // TODO: how to handle error? + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return NewStreamerV2(resp.Body), nil +} + func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") diff --git a/run.go b/run.go index 094b0a4db..728fe853d 100644 --- a/run.go +++ b/run.go @@ -82,12 +82,13 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. @@ -124,6 +125,11 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +type RunRequestStreaming struct { + RunRequest + Stream bool `json:"stream"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } @@ -337,6 +343,36 @@ func (c *Client) SubmitToolOutputs( return } +type SubmitToolOutputsStreamRequest struct { + SubmitToolOutputsRequest + Stream bool `json:"stream"` +} + +func (c *Client) SubmitToolOutputsStream( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest, +) (stream *StreamerV2, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + r := SubmitToolOutputsStreamRequest{ + SubmitToolOutputsRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + return sendRequestStreamV2(c, req) +} + // CancelRun cancels a run. func (c *Client) CancelRun( ctx context.Context, @@ -375,6 +411,106 @@ func (c *Client) CreateThreadAndRun( return } +type StreamMessageDelta struct { + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIDs []string `json:"file_ids"` +} + +type AssistantStreamEvent struct { + ID string `json:"id"` + Object string `json:"object"` + Delta StreamMessageDelta `json:"delta,omitempty"` + + // Run + CreatedAt int64 `json:"created_at,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Usage Usage `json:"usage,omitempty"` + + // ThreadMessage.Completed + Role string `json:"role,omitempty"` + Content []MessageContent `json:"content,omitempty"` + // IncompleteDetails + // IncompleteAt + + // Run steps + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + StepDetails StepDetails `json:"step_details"` + ExpiredAt *int64 `json:"expired_at,omitempty"` +} + +type AssistantStream struct { + *streamReader[AssistantStreamEvent] +} + +func (c *Client) CreateThreadAndRunStream( + ctx context.Context, + request CreateThreadAndRunRequest) (stream *StreamerV2, err error) { + type createThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` + } + + urlSuffix := "/threads/runs" + sr := createThreadAndStreamRequest{ + CreateThreadAndRunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(sr), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + return sendRequestStreamV2(c, req) +} + +func (c *Client) CreateRunStream( + ctx context.Context, + threadID string, + request RunRequest) (stream *StreamerV2, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + + r := RunRequestStreaming{ + RunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + return sendRequestStreamV2(c, req) +} + // RetrieveRunStep retrieves a run step. func (c *Client) RetrieveRunStep( ctx context.Context, diff --git a/run_test.go b/run_test.go index cdf99db05..606ac426a 100644 --- a/run_test.go +++ b/run_test.go @@ -219,6 +219,31 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndRun error") + _, err = client.CreateThreadAndRunStream(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndStream error") + + _, err = client.CreateRunStream(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRunStreaming error") + + _, err = client.SubmitToolOutputsStream(ctx, threadID, runID, openai.SubmitToolOutputsRequest{ + ToolOutputs: nil, + }) + checks.NoError(t, err, "SubmitToolOutputsStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") diff --git a/sse.go b/sse.go new file mode 100644 index 000000000..fe5a5c5f3 --- /dev/null +++ b/sse.go @@ -0,0 +1,165 @@ +package openai + +import ( + "bufio" + "io" + "strconv" + "strings" +) + +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance. +func NewEOLSplitterFunc() bufio.SplitFunc { + splitter := NewEOLSplitter() + return splitter.Split +} + +// EOLSplitter is the custom split function to handle CR LF, CR, and LF as end-of-line. +type EOLSplitter struct { + prevCR bool +} + +// NewEOLSplitter creates a new EOLSplitter instance. +func NewEOLSplitter() *EOLSplitter { + return &EOLSplitter{prevCR: false} +} + +const crlfLen = 2 + +// Split function to handle CR LF, CR, and LF as end-of-line. +func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Check if the previous data ended with a CR + if s.prevCR { + s.prevCR = false + if len(data) > 0 && data[0] == '\n' { + return 1, nil, nil // Skip the LF following the previous CR + } + } + + // Search for the first occurrence of CR LF, CR, or LF + for i := 0; i < len(data); i++ { + if data[i] == '\r' { + if i+1 < len(data) && data[i+1] == '\n' { + // Found CR LF + return i + crlfLen, data[:i], nil + } + // Found CR + if !atEOF && i == len(data)-1 { + // If CR is the last byte, and not EOF, then need to check if + // the next byte is LF. + // + // save the state and request more data + s.prevCR = true + return 0, nil, nil + } + return i + 1, data[:i], nil + } + if data[i] == '\n' { + // Found LF + return i + 1, data[:i], nil + } + } + + // If at EOF, we have a final, non-terminated line. Return it. + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + // Request more data. + return 0, nil, nil +} + +type ServerSentEvent struct { + ID string // ID of the event + Data string // Data of the event + Event string // Type of the event + Retry int // Retry time in milliseconds + Comment string // Comment +} + +type SSEScanner struct { + scanner *bufio.Scanner + next ServerSentEvent + err error + readComment bool +} + +func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { + scanner := bufio.NewScanner(r) + + // N.B. The bufio.ScanLines handles `\r?\n``, but not `\r` itself as EOL, as + // the SSE spec requires + // + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream + // + // scanner.Split(bufio.ScanLines) + scanner.Split(NewEOLSplitterFunc()) + + return &SSEScanner{ + scanner: scanner, + readComment: readComment, + } +} + +func (s *SSEScanner) Next() bool { + // Zero the next event before scanning a new one + var event ServerSentEvent + s.next = event + + var dataLines []string + + var seenNonEmptyLine bool + + for s.scanner.Scan() { + line := strings.TrimSpace(s.scanner.Text()) + + if line == "" { + if seenNonEmptyLine { + break + } + + continue + } + + seenNonEmptyLine = true + switch { + case strings.HasPrefix(line, "id: "): + event.ID = strings.TrimPrefix(line, "id: ") + case strings.HasPrefix(line, "data: "): + dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) + case strings.HasPrefix(line, "event: "): + event.Event = strings.TrimPrefix(line, "event: ") + case strings.HasPrefix(line, "retry: "): + retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) + if err == nil { + event.Retry = retry + } + // ignore invalid retry values + case strings.HasPrefix(line, ":"): + if s.readComment { + event.Comment = strings.TrimPrefix(line, ":") + } + // ignore comment line + default: + // ignore unknown lines + } + } + + s.err = s.scanner.Err() + + if !seenNonEmptyLine { + return false + } + + event.Data = strings.Join(dataLines, "\n") + s.next = event + + return true +} + +func (s *SSEScanner) Scan() ServerSentEvent { + return s.next +} + +func (s *SSEScanner) Err() error { + return s.err +} diff --git a/sse_test.go b/sse_test.go new file mode 100644 index 000000000..73c458d43 --- /dev/null +++ b/sse_test.go @@ -0,0 +1,274 @@ +package openai_test + +import ( + "bufio" + "io" + "reflect" + "strings" + "testing" + + "github.com/sashabaranov/go-openai" +) + +// ChunksReader simulates a reader that splits the input across multiple reads. +type ChunksReader struct { + chunks []string + index int +} + +func NewChunksReader(chunks []string) *ChunksReader { + return &ChunksReader{ + chunks: chunks, + } +} + +func (r *ChunksReader) Read(p []byte) (n int, err error) { + if r.index >= len(r.chunks) { + return 0, io.EOF + } + n = copy(p, r.chunks[r.index]) + r.index++ + return n, nil +} + +// TestEolSplitter tests the custom EOL splitter function. +func TestEolSplitter(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"CRLF", "Line1\r\nLine2\r\nLine3\r\n", []string{"Line1", "Line2", "Line3"}}, + {"CR", "Line1\rLine2\rLine3\r", []string{"Line1", "Line2", "Line3"}}, + {"LF", "Line1\nLine2\nLine3\n", []string{"Line1", "Line2", "Line3"}}, + {"Mixed", "Line1\r\nLine2\rLine3\nLine4\r\nLine5", []string{"Line1", "Line2", "Line3", "Line4", "Line5"}}, + {"SingleLineNoEOL", "Line1", []string{"Line1"}}, + {"SingleLineLF", "Line1\n", []string{"Line1"}}, + {"SingleLineCR", "Line1\r", []string{"Line1"}}, + {"SingleLineCRLF", "Line1\r\n", []string{"Line1"}}, + {"DoubleNewLines", "Line1\n\nLine2", []string{"Line1", "", "Line2"}}, + {"lflf", "\n\n", []string{"", ""}}, + {"crlfcrlf", "\r\n\r\n", []string{"", ""}}, + {"crcr", "\r\r", []string{"", ""}}, + {"mixed eol: crlf cr lf", "A\r\nB\rC\nD", []string{"A", "B", "C", "D"}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + reader := strings.NewReader(test.input) + scanner := bufio.NewScanner(reader) + scanner.Split(openai.NewEOLSplitterFunc()) + + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(lines) != len(test.expected) { + t.Errorf("Expected %d lines, got %d", len(test.expected), len(lines)) + t.Errorf("Expected: %v, got: %v", test.expected, lines) + } + + for i := range lines { + if lines[i] != test.expected[i] { + t.Errorf("Expected line %d to be %q, got %q", i, test.expected[i], lines[i]) + } + } + }) + } +} + +// TestEolSplitterBoundaryCondition tests the boundary condition where CR LF is split across two slices. +func TestEolSplitterBoundaryCondition(t *testing.T) { + // Additional cases + cases := []struct { + input []string + expected []string + }{ + {[]string{"Line1\r", "\nLine2"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r\n"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r", "Line3"}, []string{"Line1", "Line2", "Line3"}}, + {[]string{"Line1\r", "\nLine2\r", "\nLine3\r\n"}, []string{"Line1", "Line2", "Line3"}}, + } + for _, c := range cases { + // Custom reader to simulate the boundary condition + reader := NewChunksReader(c.input) + scanner := bufio.NewScanner(reader) + scanner.Split(openai.NewEOLSplitterFunc()) + + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(lines) != len(c.expected) { + t.Errorf("Expected %d lines, got %d", len(c.expected), len(lines)) + continue + } + + for i := range lines { + if lines[i] != c.expected[i] { + t.Errorf("Expected line %d to be %q, got %q", i, c.expected[i], lines[i]) + } + } + } +} + +func TestSSEScanner(t *testing.T) { + tests := []struct { + raw string + want []openai.ServerSentEvent + }{ + { + raw: `data: hello world`, + want: []openai.ServerSentEvent{ + { + Data: "hello world", + }, + }, + }, + { + raw: `event: hello +data: hello world`, + want: []openai.ServerSentEvent{ + { + Event: "hello", + Data: "hello world", + }, + }, + }, + { + raw: `event: hello-json +data: { +data: "msg": "hello world", +data: "id": 12345 +data: }`, + want: []openai.ServerSentEvent{ + { + Event: "hello-json", + Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", + }, + }, + }, + { + raw: `data: hello world + +data: hello again`, + want: []openai.ServerSentEvent{ + { + Data: "hello world", + }, + { + Data: "hello again", + }, + }, + }, + { + raw: `retry: 10000 + data: hello world`, + want: []openai.ServerSentEvent{ + { + Retry: 10000, + Data: "hello world", + }, + }, + }, + { + raw: `retry: 10000 + +retry: 20000`, + want: []openai.ServerSentEvent{ + { + Retry: 10000, + }, + { + Retry: 20000, + }, + }, + }, + { + raw: `: comment 1 +: comment 2 +id: message-id +retry: 20000 +event: hello-event +data: hello`, + want: []openai.ServerSentEvent{ + { + ID: "message-id", + Retry: 20000, + Event: "hello-event", + Data: "hello", + }, + }, + }, + { + raw: `: comment 1 +id: message 1 +data: hello 1 +retry: 10000 +event: hello-event 1 + +: comment 2 +data: hello 2 +id: message 2 +retry: 20000 +event: hello-event 2 +`, + want: []openai.ServerSentEvent{ + { + ID: "message 1", + Retry: 10000, + Event: "hello-event 1", + Data: "hello 1", + }, + { + ID: "message 2", + Retry: 20000, + Event: "hello-event 2", + Data: "hello 2", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.raw, func(t *testing.T) { + rawWithCRLF := strings.ReplaceAll(tt.raw, "\n", "\r\n") + runSSEScanTest(t, rawWithCRLF, tt.want) + + // Test with "\r" EOL + rawWithCR := strings.ReplaceAll(tt.raw, "\n", "\r") + runSSEScanTest(t, rawWithCR, tt.want) + + // Test with "\n" EOL (original) + runSSEScanTest(t, tt.raw, tt.want) + }) + } +} + +func runSSEScanTest(t *testing.T, raw string, want []openai.ServerSentEvent) { + sseScanner := openai.NewSSEScanner(strings.NewReader(raw), false) + + var got []openai.ServerSentEvent + for sseScanner.Next() { + got = append(got, sseScanner.Scan()) + } + + if err := sseScanner.Err(); err != nil { + t.Errorf("SSEScanner error: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("SSEScanner() = %v, want %v", got, want) + } +} diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..433548794 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,7 +16,7 @@ var ( ) type streamable interface { - ChatCompletionStreamResponse | CompletionResponse + ChatCompletionStreamResponse | CompletionResponse | AssistantStreamEvent } type streamReader[T streamable] struct { diff --git a/stream_v2.go b/stream_v2.go new file mode 100644 index 000000000..6c495c2c3 --- /dev/null +++ b/stream_v2.go @@ -0,0 +1,293 @@ +package openai + +import ( + "encoding/json" + "io" +) + +type StreamRawEvent struct { + streamEvent + Data json.RawMessage +} + +type StreamDone struct { + streamEvent +} + +type StreamThreadMessageCompleted struct { + Message + streamEvent +} + +type StreamThreadMessageDelta struct { + ID string `json:"id"` + Object string `json:"object"` + Delta Delta `json:"delta"` + + streamEvent +} + +type Delta struct { + // DeltaText | DeltaImageFile + Content []DeltaContent `json:"content"` +} + +type DeltaContent struct { + Index int `json:"index"` + Type string `json:"type"` + + Text *DeltaText `json:"text"` + ImageFile *DeltaImageFile `json:"image_file"` + ImageURL *DeltaImageURL `json:"image_url"` +} + +type DeltaText struct { + Value string `json:"value"` + // Annotations []any `json:"annotations"` +} + +type DeltaImageFile struct { + FileID string `json:"file_id"` + Detail string `json:"detail"` +} + +type DeltaImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +func NewStreamerV2(r io.Reader) *StreamerV2 { + var rc io.ReadCloser + + if closer, ok := r.(io.ReadCloser); ok { + rc = closer + } else { + rc = io.NopCloser(r) + } + + return &StreamerV2{ + readCloser: rc, + scanner: NewSSEScanner(r, false), + } +} + +type StreamerV2 struct { + // readCloser is only used for closing the stream + readCloser io.ReadCloser + + scanner *SSEScanner + next StreamEvent + + // buffer for implementing io.Reader + buffer []byte +} + +// TeeSSE tees the stream data with a io.TeeReader +func (s *StreamerV2) TeeSSE(w io.Writer) { + // readCloser is a helper struct that implements io.ReadCloser by combining an io.Reader and an io.Closer + type readCloser struct { + io.Reader + io.Closer + } + + s.readCloser = &readCloser{ + Reader: io.TeeReader(s.readCloser, w), + Closer: s.readCloser, + } + + s.scanner = NewSSEScanner(s.readCloser, false) +} + +// Close closes the underlying io.ReadCloser. +func (s *StreamerV2) Close() error { + return s.readCloser.Close() +} + +type StreamThreadCreated struct { + Thread + streamEvent +} + +type StreamThreadRunCreated struct { + Run + streamEvent +} + +type StreamThreadRunRequiresAction struct { + Run + streamEvent +} + +type StreamThreadRunCompleted struct { + Run + streamEvent +} + +type StreamRunStepCompleted struct { + RunStep + streamEvent +} + +type StreamEvent interface { + Event() string + JSON() json.RawMessage +} + +type streamEvent struct { + event string + data json.RawMessage +} + +// Event returns the event name +func (s *streamEvent) Event() string { + return s.event +} + +// JSON returns the raw JSON data +func (s *streamEvent) JSON() json.RawMessage { + return s.data +} + +func (s *StreamerV2) Next() bool { + if !s.scanner.Next() { + return false + } + + event := s.scanner.Scan() + + streamEvent := streamEvent{ + event: event.Event, + data: json.RawMessage(event.Data), + } + + switch event.Event { + case "thread.created": + var thread Thread + if err := json.Unmarshal([]byte(event.Data), &thread); err == nil { + s.next = &StreamThreadCreated{ + Thread: thread, + streamEvent: streamEvent, + } + } + case "thread.run.created": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCreated{ + Run: run, + streamEvent: streamEvent, + } + } + + case "thread.run.requires_action": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunRequiresAction{ + Run: run, + streamEvent: streamEvent, + } + } + case "thread.run.completed": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCompleted{ + Run: run, + streamEvent: streamEvent, + } + } + case "thread.message.delta": + var delta StreamThreadMessageDelta + if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { + delta.streamEvent = streamEvent + s.next = &delta + } + case "thread.run.step.completed": + var runStep RunStep + if err := json.Unmarshal([]byte(event.Data), &runStep); err == nil { + s.next = &StreamRunStepCompleted{ + RunStep: runStep, + streamEvent: streamEvent, + } + } + case "thread.message.completed": + var msg Message + if err := json.Unmarshal([]byte(event.Data), &msg); err == nil { + s.next = &StreamThreadMessageCompleted{ + Message: msg, + streamEvent: streamEvent, + } + } + case "done": + streamEvent.data = nil + s.next = &StreamDone{ + streamEvent: streamEvent, + } + default: + s.next = &StreamRawEvent{ + streamEvent: streamEvent, + } + } + + return true +} + +// Read implements io.Reader of the text deltas of thread.message.delta events. +func (s *StreamerV2) Read(p []byte) (int, error) { + // If we have data in the buffer, copy it to p first. + if len(s.buffer) > 0 { + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] + return n, nil + } + + for s.Next() { + // Read only text deltas + text, ok := s.MessageDeltaText() + if !ok { + continue + } + + s.buffer = []byte(text) + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] + return n, nil + } + + // Check for streamer error + if err := s.Err(); err != nil { + return 0, err + } + + return 0, io.EOF +} + +func (s *StreamerV2) Event() StreamEvent { + return s.next +} + +// Text returns text delta if the current event is a "thread.message.delta". Alias of MessageDeltaText. +func (s *StreamerV2) Text() (string, bool) { + return s.MessageDeltaText() +} + +// MessageDeltaText returns text delta if the current event is a "thread.message.delta". +func (s *StreamerV2) MessageDeltaText() (string, bool) { + event, ok := s.next.(*StreamThreadMessageDelta) + if !ok { + return "", false + } + + var text string + for _, content := range event.Delta.Content { + if content.Text != nil { + // Can we return the first text we find? Does OpenAI stream ever + // return multiple text contents in a delta? + text += content.Text.Value + } + } + + return text, true +} + +func (s *StreamerV2) Err() error { + return s.scanner.Err() +} diff --git a/stream_v2_test.go b/stream_v2_test.go new file mode 100644 index 000000000..0a5d2b9f6 --- /dev/null +++ b/stream_v2_test.go @@ -0,0 +1,237 @@ +//nolint:lll +package openai_test + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "github.com/sashabaranov/go-openai" +) + +func TestNewStreamTextReader(t *testing.T) { + raw := ` +event: thread.message.delta +data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}} + +event: thread.message.delta +data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"world"}}]}} + +event: done +data: [DONE] +` + reader := openai.NewStreamerV2(strings.NewReader(raw)) + + expected := "helloworld" + buffer := make([]byte, len(expected)) + n, err := reader.Read(buffer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != len("hello") { + t.Fatalf("expected to read %d bytes, read %d bytes", len("hello"), n) + } + if string(buffer[:n]) != "hello" { + t.Fatalf("expected %q, got %q", "hello", string(buffer[:n])) + } + + n, err = reader.Read(buffer[n:]) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != len("world") { + t.Fatalf("expected to read %d bytes, read %d bytes", len("world"), n) + } + if string(buffer[:len(expected)]) != expected { + t.Fatalf("expected %q, got %q", expected, string(buffer[:len(expected)])) + } + + n, err = reader.Read(buffer) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got %v", err) + } + if n != 0 { + t.Fatalf("expected to read 0 bytes, read %d bytes", n) + } +} + +type TestCase struct { + Event string + Data string +} + +func constructStreamInput(testCases []TestCase) io.Reader { + var sb bytes.Buffer + for _, tc := range testCases { + sb.WriteString("event: ") + sb.WriteString(tc.Event) + sb.WriteString("\n") + sb.WriteString("data: ") + sb.WriteString(tc.Data) + sb.WriteString("\n\n") + } + return &sb +} + +func jsonEqual[T any](t *testing.T, data []byte, expected T) error { + var obj T + if err := json.Unmarshal(data, &obj); err != nil { + t.Fatalf("Error unmarshalling JSON: %v", err) + } + + if !reflect.DeepEqual(obj, expected) { + t.Fatalf("Expected %v, but got %v", expected, obj) + } + + return nil +} + +func TestStreamerV2(t *testing.T) { + testCases := []TestCase{ + { + Event: "thread.created", + Data: `{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`, + }, + { + Event: "thread.run.created", + Data: `{"id":"run_ojU7pVxtTIaa4l1GgRmHVSbK","object":"thread.run","created_at":1715864046,"assistant_id":"asst_7xUrZ16RBU2BpaUOzLnc9HsD","thread_id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","status":"queued","started_at":null,"expires_at":1715864646,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto"}`, + }, + { + Event: "thread.message.delta", + Data: `{"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}}`, + }, + { + Event: "thread.run.requires_action", + Data: `{"id":"run_oNjmoH9jHSQBSPkuVqfHSaLs","object":"thread.run","created_at":1716281751,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_4yCKEOWSRQRofNuzl7Ny3uNs","status":"requires_action","started_at":1716281751,"expires_at":1716282351,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_q7J5q7taE0K0x83HRuJxJJjR","type":"function","function":{"name":"lookupDefinition","arguments":"{\"entry\":\"square root of pi\",\"language\":\"en\"}"}}]}},"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"lookupDefinition","description":"Lookup the definition of an entry. e.g. word, short phrase, person, place, or term","parameters":{"properties":{"entry":{"description":"The entry to lookup","type":"string"},"language":{"description":"ISO 639-1 language code, e.g., 'en' for English, 'zh' for Chinese","type":"string"}},"type":"object"}}}],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto"}`, + }, + { + Event: "thread.run.completed", + Data: `{"id":"run_o14scUSKGFFRrwhsfGkh2pMJ","object":"thread.run","created_at":1716281844,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","status":"completed","started_at":1716281844,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1716281845,"required_action":null,"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"lookupDefinition","description":"Lookup the definition of an entry. e.g. word, short phrase, person, place, or term","parameters":{"properties":{"entry":{"description":"The entry to lookup","type":"string"},"language":{"description":"ISO 639-1 language code, e.g., 'en' for English, 'zh' for Chinese","type":"string"}},"type":"object"}}}],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":300,"completion_tokens":24,"total_tokens":324},"response_format":"auto","tool_choice":"auto"}`, + }, + { + Event: "thread.run.step.completed", + Data: `{"id":"step_9UKPyHGdL6VczTfigS5bdGQb","object":"thread.run.step","created_at":1716281845,"run_id":"run_o14scUSKGFFRrwhsfGkh2pMJ","assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1716281845,"expires_at":1716282444,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_Hb14QXWwPWEiMJ12L8Spa3T9"}},"usage":{"prompt_tokens":300,"completion_tokens":24,"total_tokens":324}}`, + }, + { + Event: "thread.message.completed", + Data: `{"id":"msg_Hb14QXWwPWEiMJ12L8Spa3T9","object":"thread.message","created_at":1716281845,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","run_id":"run_o14scUSKGFFRrwhsfGkh2pMJ","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1716281845,"role":"assistant","content":[{"type":"text","text":{"value":"Sure! Here you go:\n\nWhy couldn't the leopard play hide and seek?\n\nBecause he was always spotted!","annotations":[]}}],"attachments":[],"metadata":{}}`, + }, + { + Event: "done", + Data: "[DONE]", + }, + } + + streamer := openai.NewStreamerV2(constructStreamInput(testCases)) + + for _, tc := range testCases { + if !streamer.Next() { + t.Fatal("Expected Next() to return true, but got false") + } + + event := streamer.Event() + + if event.Event() != tc.Event { + t.Fatalf("Expected event type to be %s, but got %s", tc.Event, event.Event()) + } + + if tc.Event != "done" { + // compare the json data + jsondata := event.JSON() + if string(jsondata) != tc.Data { + t.Fatalf("Expected JSON data to be %s, but got %s", tc.Data, string(jsondata)) + } + } + + switch event := event.(type) { + case *openai.StreamThreadCreated: + jsonEqual(t, []byte(tc.Data), event.Thread) + case *openai.StreamThreadRunCreated: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamThreadMessageDelta: + fmt.Println(event) + + // reinitialize the delta object to avoid comparing the hidden streamEvent fields + delta := openai.StreamThreadMessageDelta{ + ID: event.ID, + Object: event.Object, + Delta: event.Delta, + } + + jsonEqual(t, []byte(tc.Data), delta) + case *openai.StreamThreadRunRequiresAction: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamThreadRunCompleted: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamRunStepCompleted: + jsonEqual(t, []byte(tc.Data), event.RunStep) + case *openai.StreamDone: + if event.JSON() != nil { + t.Fatalf("Expected JSON data to be nil, but got %s", string(event.JSON())) + } + } + } +} + +func TestStreamThreadMessageDeltaJSON(t *testing.T) { + tests := []struct { + name string + jsonData string + expectType string + expectValue interface{} + }{ + { + name: "DeltaContent with Text", + jsonData: `{"index":0,"type":"text","text":{"value":"hello"}}`, + expectType: "text", + expectValue: &openai.DeltaText{Value: "hello"}, + }, + { + name: "DeltaContent with ImageFile", + jsonData: `{"index":1,"type":"image_file","image_file":{"file_id":"file123","detail":"An image"}}`, + expectType: "image_file", + expectValue: &openai.DeltaImageFile{FileID: "file123", Detail: "An image"}, + }, + { + name: "DeltaContent with ImageURL", + jsonData: `{"index":2,"type":"image_url","image_url":{"url":"https://example.com/image.jpg","detail":"low"}}`, + expectType: "image_url", + expectValue: &openai.DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var content openai.DeltaContent + err := json.Unmarshal([]byte(tt.jsonData), &content) + if err != nil { + t.Fatalf("Error unmarshalling JSON: %v", err) + } + + if content.Type != tt.expectType { + t.Errorf("Expected Type to be '%s', got %s", tt.expectType, content.Type) + } + + var actualValue interface{} + switch tt.expectType { + case "text": + actualValue = content.Text + case "image_file": + actualValue = content.ImageFile + case "image_url": + actualValue = content.ImageURL + default: + t.Fatalf("Unexpected type: %s", tt.expectType) + } + + if !reflect.DeepEqual(actualValue, tt.expectValue) { + t.Errorf("Expected value to be '%v', got %v", tt.expectValue, actualValue) + } + }) + } +}