Skip to content

Commit 6aa9d4d

Browse files
committed
refactor: refactor completion methods to use context and consistent naming
- Added `context.Context` as a parameter to `CreateCompletion` calls across multiple files. - Renamed `EngineExecOutput` to `CompletionOutput` and updated related methods and fields. - Updated field names in `StreamCompletionOutput` to use PascalCase for consistency. - Removed unused `defaultChatID` constant and `SummaryContentOutput` type. - Added a new `setupChatContext` method call in the `CreateCompletion` function. Signed-off-by: codiing-hui <wecoding@yeah.net>
1 parent 1e65b0a commit 6aa9d4d

File tree

4 files changed

+49
-56
lines changed

4 files changed

+49
-56
lines changed

internal/cli/commit/commit.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package commit
22

33
import (
4+
"context"
45
"fmt"
56
"html"
67
"os"
@@ -246,7 +247,7 @@ func (o *Options) codeReview(engine *llm.Engine, vars map[string]any) error {
246247
return err
247248
}
248249

249-
resp, err := engine.CreateCompletion(p.Messages())
250+
resp, err := engine.CreateCompletion(context.Background(), p.Messages())
250251
if err != nil {
251252
return err
252253
}
@@ -265,7 +266,7 @@ func (o *Options) summarizeTitle(engine *llm.Engine, vars map[string]any) error
265266
return err
266267
}
267268

268-
resp, err := engine.CreateCompletion(p.Messages())
269+
resp, err := engine.CreateCompletion(context.Background(), p.Messages())
269270
if err != nil {
270271
return err
271272
}
@@ -285,7 +286,7 @@ func (o *Options) summarizePrefix(engine *llm.Engine, vars map[string]any) error
285286
return err
286287
}
287288

288-
resp, err := engine.CreateCompletion(p.Messages())
289+
resp, err := engine.CreateCompletion(context.Background(), p.Messages())
289290
if err != nil {
290291
return err
291292
}
@@ -330,7 +331,7 @@ func (o *Options) generateCommitMsg(engine *llm.Engine, vars map[string]any) (st
330331
return "", err
331332
}
332333

333-
resp, err := engine.CreateCompletion(translationPrompt.Messages())
334+
resp, err := engine.CreateCompletion(context.Background(), translationPrompt.Messages())
334335
if err != nil {
335336
return "", err
336337
}

internal/cli/review/review.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package review
22

33
import (
4+
"context"
45
"errors"
56
"strings"
67

@@ -74,7 +75,7 @@ func (o *Options) reviewCode(cmd *cobra.Command, args []string) error {
7475

7576
// Get summarize comment from diff datas
7677
color.Cyan("We are trying to review code changes")
77-
reviewResp, err := llmEngine.CreateCompletion(reviewPrompt.Messages())
78+
reviewResp, err := llmEngine.CreateCompletion(context.Background(), reviewPrompt.Messages())
7879
if err != nil {
7980
return err
8081
}
@@ -92,7 +93,7 @@ func (o *Options) reviewCode(cmd *cobra.Command, args []string) error {
9293
}
9394

9495
color.Cyan("we are trying to translate code review to " + o.commitLang + " language")
95-
translationResp, err := llmEngine.CreateCompletion(translationPrompt.Messages())
96+
translationResp, err := llmEngine.CreateCompletion(context.Background(), translationPrompt.Messages())
9697
if err != nil {
9798
return err
9899
}

internal/llm/engine.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ import (
2020
)
2121

2222
const (
23-
noExec = "[noexec]"
24-
defaultChatID = "temp_session"
23+
noExec = "[noexec]"
2524
)
2625

2726
type Engine struct {
@@ -137,32 +136,35 @@ func (e *Engine) GetConvoStore() convo.Store {
137136

138137
func (e *Engine) Interrupt() {
139138
e.channel <- StreamCompletionOutput{
140-
content: "[Interrupt]",
141-
last: true,
142-
interrupt: true,
143-
executable: false,
139+
Content: "[Interrupt]",
140+
Last: true,
141+
Interrupt: true,
142+
Executable: false,
144143
}
145144

146145
e.running = false
147146
}
148147

149-
func (e *Engine) CreateCompletion(messages []llms.ChatMessage) (*EngineExecOutput, error) {
150-
ctx := context.Background()
151-
148+
func (e *Engine) CreateCompletion(ctx context.Context, messages []llms.ChatMessage) (*CompletionOutput, error) {
152149
e.running = true
153150

151+
if err := e.setupChatContext(ctx, &messages); err != nil {
152+
return nil, err
153+
}
154+
154155
rsp, err := e.Model.GenerateContent(ctx, slices.Map(messages, convert), e.callOptions()...)
155156
if err != nil {
156157
return nil, errbook.Wrap("Failed to create completion.", err)
157158
}
158159

159160
content := rsp.Choices[0].Content
161+
160162
e.appendAssistantMessage(content)
161163

162-
var output EngineExecOutput
164+
var output CompletionOutput
163165
err = json.Unmarshal([]byte(content), &output)
164166
if err != nil {
165-
output = EngineExecOutput{
167+
output = CompletionOutput{
166168
Command: "",
167169
Explanation: content,
168170
Executable: false,
@@ -178,8 +180,8 @@ func (e *Engine) CreateStreamCompletion(ctx context.Context, messages []llms.Cha
178180
streamingFunc := func(ctx context.Context, chunk []byte) error {
179181
if !e.config.Quiet {
180182
e.channel <- StreamCompletionOutput{
181-
content: string(chunk),
182-
last: false,
183+
Content: string(chunk),
184+
Last: false,
183185
}
184186
}
185187
return nil
@@ -214,18 +216,19 @@ func (e *Engine) CreateStreamCompletion(ctx context.Context, messages []llms.Cha
214216

215217
if !e.config.Quiet {
216218
e.channel <- StreamCompletionOutput{
217-
content: "",
218-
last: true,
219-
executable: executable,
219+
Content: "",
220+
Last: true,
221+
Executable: executable,
220222
}
221223
}
222224
e.running = false
225+
223226
e.appendAssistantMessage(output)
224227

225228
return &StreamCompletionOutput{
226-
content: output,
227-
last: true,
228-
executable: executable,
229+
Content: output,
230+
Last: true,
231+
Executable: executable,
229232
}, nil
230233
}
231234

internal/llm/types.go

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,22 @@ func (m EngineMode) String() string {
1717
}
1818
}
1919

20-
type SummaryContentOutput struct {
21-
Content string `json:"content"`
22-
}
23-
24-
type EngineExecOutput struct {
20+
type CompletionOutput struct {
2521
Command string `json:"cmd"`
2622
Explanation string `json:"exp"`
2723
Executable bool `json:"exec"`
2824
}
2925

30-
func (eo EngineExecOutput) GetCommand() string {
31-
return eo.Command
26+
func (c CompletionOutput) GetCommand() string {
27+
return c.Command
3228
}
3329

34-
func (eo EngineExecOutput) GetExplanation() string {
35-
return eo.Explanation
30+
func (c CompletionOutput) GetExplanation() string {
31+
return c.Explanation
3632
}
3733

38-
func (eo EngineExecOutput) IsExecutable() bool {
39-
return eo.Executable
34+
func (c CompletionOutput) IsExecutable() bool {
35+
return c.Executable
4036
}
4137

4238
// CompletionInput is a tea.Msg that wraps the content read from stdin.
@@ -46,32 +42,24 @@ type CompletionInput struct {
4642

4743
// StreamCompletionOutput a tea.Msg that wraps the content returned from llm.
4844
type StreamCompletionOutput struct {
49-
content string
50-
last bool
51-
interrupt bool
52-
executable bool
53-
}
54-
55-
func (co StreamCompletionOutput) GetContent() string {
56-
return co.content
57-
}
58-
59-
func (co StreamCompletionOutput) IsLast() bool {
60-
return co.last
45+
Content string
46+
Last bool
47+
Interrupt bool
48+
Executable bool
6149
}
6250

63-
func (co StreamCompletionOutput) IsInterrupt() bool {
64-
return co.interrupt
51+
func (c StreamCompletionOutput) GetContent() string {
52+
return c.Content
6553
}
6654

67-
func (co StreamCompletionOutput) IsExecutable() bool {
68-
return co.executable
55+
func (c StreamCompletionOutput) IsLast() bool {
56+
return c.Last
6957
}
7058

71-
type EngineLoggingCallbackOutput struct {
72-
content string
59+
func (c StreamCompletionOutput) IsInterrupt() bool {
60+
return c.Interrupt
7361
}
7462

75-
func (ec EngineLoggingCallbackOutput) GetContent() string {
76-
return ec.content
63+
func (c StreamCompletionOutput) IsExecutable() bool {
64+
return c.Executable
7765
}

0 commit comments

Comments
 (0)