Skip to content

Commit dc2adb6

Browse files
committed
feat: refactor chat history storage and enhance conversation handling
- Added a new dependency `github.com/charmbracelet/x/exp/ordered v0.1.0` to `go.mod`. - Removed unused imports (`context` and `convo`) from `cli.go`. - Simplified the `postRunHook` function by removing unnecessary conversation storage logic. - Renamed methods in `chat_history_store.go` from `Persistent` and `Invalidate` to `PersistentMessages` and `InvalidateMessages` respectively. - Added directory creation logic before saving messages in `chat_history_store.go`. - Updated test cases in `chat_history_store_test.go` to use the renamed methods. - Added a new test case for reading messages in `chat_history_store_test.go`. - Updated the `ChatMessageHistory` interface in `conversation.go` to reflect method name changes. - Exported the `Sha1reg` variable and added a `MatchSha1` function in `sha.go`. - Updated the SQLite store path to include a `conversations` subdirectory in `convo_store.go`. - Added a `GetConvoStore` method to the `Engine` struct in `engine.go`. - Modified `CreateStreamCompletion` to accept a `context.Context` parameter and added chat context setup logic. - Added new flags (`show`, `show-last`, `continue`, `continue-last`, `title`) in `basic_flags.go`. - Added new configuration fields (`ContinueLast`, `Continue`, `Title`, `Show`, `ShowLast`, `CacheReadFromID`, `CacheWriteToID`, `CacheWriteToTitle`) in `config.go`. - Added logic to handle conversation saving and retrieval in `chat.go`. - Added a `saveConversation` function to persist conversations in `chat.go`. - Added utility functions (`lastPrompt`, `firstLine`) to extract the first line of the last human message in `chat.go`. Signed-off-by: codiing-hui <wecoding@yeah.net>
1 parent dcc787f commit dc2adb6

File tree

12 files changed

+229
-48
lines changed

12 files changed

+229
-48
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ require (
1414
github.com/charmbracelet/bubbletea v1.2.5-0.20241205214244-9306010a31ee
1515
github.com/charmbracelet/glamour v0.8.0
1616
github.com/charmbracelet/lipgloss v1.0.0
17+
github.com/charmbracelet/x/exp/ordered v0.1.0
1718
github.com/coding-hui/common v0.8.7
1819
github.com/coding-hui/go-prompt v0.2.8
1920
github.com/coding-hui/wecoding-sdk-go v0.8.7

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ github.com/charmbracelet/x/ansi v0.4.5 h1:LqK4vwBNaXw2AyGIICa5/29Sbdq58GbGdFngSe
5858
github.com/charmbracelet/x/ansi v0.4.5/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
5959
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
6060
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
61+
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
62+
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
6163
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
6264
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
6365
github.com/coding-hui/common v0.8.7 h1:f9iHZcdQLgRFW/nIJHVfBINuzzCKsUcwVIhfboxXe2s=

internal/cli/cli.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package cli
77

88
import (
9-
"context"
109
"flag"
1110
"io"
1211
"os"
@@ -28,7 +27,6 @@ import (
2827
"github.com/coding-hui/ai-terminal/internal/cli/manpage"
2928
"github.com/coding-hui/ai-terminal/internal/cli/review"
3029
"github.com/coding-hui/ai-terminal/internal/cli/version"
31-
"github.com/coding-hui/ai-terminal/internal/convo"
3230
"github.com/coding-hui/ai-terminal/internal/errbook"
3331
"github.com/coding-hui/ai-terminal/internal/options"
3432
"github.com/coding-hui/ai-terminal/internal/util/debug"
@@ -142,14 +140,5 @@ func postRunHook(cfg *options.Config) error {
142140
if err := flushProfiling(); err != nil {
143141
return err
144142
}
145-
146-
convoStore, _ := convo.GetConversationStore(cfg)
147-
if !cfg.NoCache && convoStore != nil {
148-
err := convoStore.SaveConversation(context.Background(), cfg.ConversationID, "convo", cfg.Model)
149-
if err != nil {
150-
return err
151-
}
152-
}
153-
154143
return nil
155144
}

internal/convo/chat_history_store.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ func (h *SimpleChatHistoryStore) AddMessage(_ context.Context, convoID string, m
5151

5252
func (h *SimpleChatHistoryStore) SetMessages(ctx context.Context, convoID string, messages []llms.ChatMessage) error {
5353
h.messages[convoID] = messages
54-
if err := h.Invalidate(ctx, convoID); err != nil && !errors.Is(err, os.ErrNotExist) {
54+
if err := h.InvalidateMessages(ctx, convoID); err != nil && !errors.Is(err, os.ErrNotExist) {
5555
return err
5656
}
57-
return h.Persistent(ctx, convoID, h.messages[convoID])
57+
return h.PersistentMessages(ctx, convoID)
5858
}
5959

6060
func (h *SimpleChatHistoryStore) Messages(_ context.Context, convoID string) ([]llms.ChatMessage, error) {
6161
if !h.loaded[convoID] {
62-
if err := h.load(convoID); err != nil {
62+
if err := h.load(convoID); err != nil && !errors.Is(err, os.ErrNotExist) {
6363
return nil, err
6464
}
6565
h.loaded[convoID] = true
@@ -91,19 +91,24 @@ func (h *SimpleChatHistoryStore) load(convoID string) error {
9191
return nil
9292
}
9393

94-
func (h *SimpleChatHistoryStore) Persistent(_ context.Context, convoID string, messages []llms.ChatMessage) error {
94+
func (h *SimpleChatHistoryStore) PersistentMessages(_ context.Context, convoID string) error {
9595
if convoID == "" {
9696
return fmt.Errorf("write: %w", errInvalidID)
9797
}
9898

99+
// Ensure directory exists
100+
if err := os.MkdirAll(h.dir, 0755); err != nil {
101+
return fmt.Errorf("create directory: %w", err)
102+
}
103+
99104
file, err := os.Create(filepath.Join(h.dir, convoID+cacheExt))
100105
if err != nil {
101106
return fmt.Errorf("write: %w", err)
102107
}
103108
defer file.Close() //nolint:errcheck
104109

105110
var rawMessages []llms.ChatMessageModel
106-
for _, v := range messages {
111+
for _, v := range h.messages[convoID] {
107112
if v != nil {
108113
rawMessages = append(rawMessages, llms.ConvertChatMessageToModel(v))
109114
}
@@ -115,7 +120,7 @@ func (h *SimpleChatHistoryStore) Persistent(_ context.Context, convoID string, m
115120
return nil
116121
}
117122

118-
func (h *SimpleChatHistoryStore) Invalidate(_ context.Context, convoID string) error {
123+
func (h *SimpleChatHistoryStore) InvalidateMessages(_ context.Context, convoID string) error {
119124
if convoID == "" {
120125
return fmt.Errorf("delete: %w", errInvalidID)
121126
}

internal/convo/chat_history_store_test.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,28 @@ import (
1313
func TestStore(t *testing.T) {
1414
convoID := NewConversationID()
1515

16+
t.Run("read", func(t *testing.T) {
17+
store := NewSimpleChatHistoryStore("/Users/bytedance/Codes/ai-terminal/bin/conversations/")
18+
err := store.load("3b7e9fb1ae86660c32cef53346a3230f4bfc8797")
19+
require.NoError(t, err)
20+
messages, err := store.Messages(context.Background(), "3b7e9fb1ae86660c32cef53346a3230f4bfc8797")
21+
require.NoError(t, err)
22+
require.Len(t, messages, 1)
23+
})
24+
1625
t.Run("read non-existent", func(t *testing.T) {
1726
store := NewSimpleChatHistoryStore(t.TempDir())
1827
err := store.load("super-fake")
1928
require.ErrorIs(t, err, os.ErrNotExist)
2029
defer func() {
21-
_ = store.Invalidate(context.Background(), convoID)
30+
_ = store.InvalidateMessages(context.Background(), convoID)
2231
}()
2332
})
2433

2534
t.Run("set messages", func(t *testing.T) {
2635
ctx := context.Background()
2736
store := NewSimpleChatHistoryStore(t.TempDir())
28-
_ = store.Invalidate(context.Background(), convoID)
37+
_ = store.InvalidateMessages(context.Background(), convoID)
2938
require.NoError(t, store.AddUserMessage(ctx, convoID, "hello"))
3039
require.NoError(t, store.AddAIMessage(ctx, convoID, "hi"))
3140
require.NoError(t, store.AddAIMessage(ctx, convoID, "bye"))
@@ -36,13 +45,13 @@ func TestStore(t *testing.T) {
3645
require.Equal(t, 3, len(messages))
3746

3847
// After persist, messages should be saved
39-
require.NoError(t, store.Persistent(ctx, convoID, messages))
48+
require.NoError(t, store.PersistentMessages(ctx, convoID))
4049
persistedMessages, err := store.Messages(ctx, convoID)
4150
require.NoError(t, err)
4251
require.Equal(t, messages, persistedMessages)
4352

4453
defer func() {
45-
_ = store.Invalidate(context.Background(), convoID)
54+
_ = store.InvalidateMessages(context.Background(), convoID)
4655
}()
4756
})
4857

@@ -53,21 +62,21 @@ func TestStore(t *testing.T) {
5362
llms.HumanChatMessage{Content: "bar"},
5463
llms.AIChatMessage{Content: "zoo"},
5564
}
56-
require.NoError(t, store.Persistent(context.Background(), convoID, messages))
65+
require.NoError(t, store.SetMessages(context.Background(), convoID, messages))
5766
require.NoError(t, store.load(convoID))
5867
require.ElementsMatch(t, messages, store.messages[convoID])
5968
defer func() {
60-
_ = store.Invalidate(context.Background(), convoID)
69+
_ = store.InvalidateMessages(context.Background(), convoID)
6170
}()
6271
})
6372

6473
t.Run("delete", func(t *testing.T) {
6574
store := NewSimpleChatHistoryStore(t.TempDir())
66-
require.NoError(t, store.Persistent(context.Background(), convoID, []llms.ChatMessage{}))
67-
require.NoError(t, store.Invalidate(context.Background(), convoID))
75+
require.NoError(t, store.PersistentMessages(context.Background(), convoID))
76+
require.NoError(t, store.InvalidateMessages(context.Background(), convoID))
6877
require.ErrorIs(t, store.load(convoID), os.ErrNotExist)
6978
defer func() {
70-
_ = store.Invalidate(context.Background(), convoID)
79+
_ = store.InvalidateMessages(context.Background(), convoID)
7180
}()
7281
})
7382

internal/convo/conversation.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ type ChatMessageHistory interface {
8080
SetMessages(ctx context.Context, convoID string, messages []llms.ChatMessage) error
8181
// Messages retrieves all messages from the store
8282
Messages(ctx context.Context, convoID string) ([]llms.ChatMessage, error)
83-
// Persistent saves messages to persistent storage
84-
Persistent(ctx context.Context, convoID string, messages []llms.ChatMessage) error
85-
// Invalidate removes messages from persistent storage
86-
Invalidate(ctx context.Context, convoID string) error
83+
// PersistentMessages saves messages to persistent storage
84+
PersistentMessages(ctx context.Context, convoID string) error
85+
// InvalidateMessages removes messages from persistent storage
86+
InvalidateMessages(ctx context.Context, convoID string) error
8787
}
8888

8989
// Store is the interface for chat history convo store.

internal/convo/sha.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ const (
1313
Sha1ReadBlockSize = 4096
1414
)
1515

16-
var sha1reg = regexp.MustCompile(`\b[0-9a-f]{40}\b`)
16+
var Sha1reg = regexp.MustCompile(`\b[0-9a-f]{40}\b`)
1717

1818
func NewConversationID() string {
1919
b := make([]byte, Sha1ReadBlockSize)
2020
_, _ = rand.Read(b)
2121
return fmt.Sprintf("%x", sha1.Sum(b)) //nolint: gosec
2222
}
23+
24+
func MatchSha1(s string) bool {
25+
return Sha1reg.MatchString(s)
26+
}

internal/convo/sqlite3/convo_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func (s *sqliteStoreFactor) Type() string {
3333

3434
func (s *sqliteStoreFactor) Create(options *options.Config) (convo.Store, error) {
3535
return NewSqliteStore(
36-
WithDataPath(options.DataStore.CachePath),
36+
WithDataPath(filepath.Join(options.DataStore.CachePath, "conversations")),
3737
WithConversation(options.ConversationID),
3838
WithDBAddress(filepath.Join(options.DataStore.CachePath, "convo.db")),
3939
), nil

internal/llm/engine.go

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ func (e *Engine) GetChannel() chan StreamCompletionOutput {
132132
return e.channel
133133
}
134134

135+
func (e *Engine) GetConvoStore() convo.Store {
136+
return e.convoStore
137+
}
138+
135139
func (e *Engine) Interrupt() {
136140
e.channel <- StreamCompletionOutput{
137141
content: "[Interrupt]",
@@ -169,9 +173,7 @@ func (e *Engine) CreateCompletion(messages []llms.ChatMessage) (*EngineExecOutpu
169173
return &output, nil
170174
}
171175

172-
func (e *Engine) CreateStreamCompletion(messages []llms.ChatMessage) tea.Msg {
173-
ctx := context.Background()
174-
176+
func (e *Engine) CreateStreamCompletion(ctx context.Context, messages []llms.ChatMessage) tea.Msg {
175177
e.running = true
176178

177179
streamingFunc := func(ctx context.Context, chunk []byte) error {
@@ -184,6 +186,17 @@ func (e *Engine) CreateStreamCompletion(messages []llms.ChatMessage) tea.Msg {
184186
return nil
185187
}
186188

189+
if err := e.setupChatContext(ctx, &messages); err != nil {
190+
return err
191+
}
192+
193+
for _, v := range messages {
194+
err := e.convoStore.AddMessage(ctx, e.config.CacheWriteToID, v)
195+
if err != nil {
196+
errbook.HandleError(errbook.Wrap("Failed to add user chat input message to history", err))
197+
}
198+
}
199+
187200
messageParts := slices.Map(messages, convert)
188201
rsp, err := e.Model.GenerateContent(ctx, messageParts, e.callOptions(streamingFunc)...)
189202
if err != nil {
@@ -235,16 +248,25 @@ func (e *Engine) callOptions(streamingFunc ...func(ctx context.Context, chunk []
235248
return opts
236249
}
237250

238-
func (e *Engine) appendUserMessage(content string) {
239-
if len(strings.TrimSpace(content)) == 0 {
240-
errbook.HandleError(errbook.New("empty input is not allowed."))
241-
return
251+
func (e *Engine) setupChatContext(ctx context.Context, messages *[]llms.ChatMessage) error {
252+
store := e.convoStore
253+
if store == nil {
254+
return errbook.New("no chat history store found")
242255
}
243-
if e.convoStore != nil {
244-
if err := e.convoStore.AddUserMessage(context.Background(), e.config.ConversationID, content); err != nil {
245-
errbook.HandleError(errbook.Wrap("failed to add user chat input message to history", err))
256+
257+
if !e.config.NoCache && e.config.CacheReadFromID != "" {
258+
history, err := store.Messages(ctx, e.config.CacheReadFromID)
259+
if err != nil {
260+
return errbook.Wrap(fmt.Sprintf(
261+
"There was a problem reading the cache. Use %s / %s to disable it.",
262+
console.StderrStyles().InlineCode.Render("--no-cache"),
263+
console.StderrStyles().InlineCode.Render("NO_CACHE"),
264+
), err)
246265
}
266+
*messages = append(*messages, history...)
247267
}
268+
269+
return nil
248270
}
249271

250272
func (e *Engine) appendAssistantMessage(content string) {

internal/options/basic_flags.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ func AddBasicFlags(flags *pflag.FlagSet, cfg *Config) {
3030
flags.UintVar(&cfg.Fanciness, "fanciness", cfg.Fanciness, console.StdoutStyles().FlagDesc.Render(help["fanciness"]))
3131
flags.StringVar(&cfg.LoadingText, "loading-text", cfg.LoadingText, console.StdoutStyles().FlagDesc.Render(help["status-text"]))
3232
flags.BoolVar(&cfg.NoCache, "no-cache", cfg.NoCache, console.StdoutStyles().FlagDesc.Render(help["no-cache"]))
33-
//flags.BoolVar(&cfg.Dirs, "dirs", false, console.StdoutStyles().FlagDesc.Render(help["dirs"]))
33+
flags.StringVarP(&cfg.Show, "show", "s", cfg.Show, console.StdoutStyles().FlagDesc.Render(help["show"]))
34+
flags.BoolVarP(&cfg.ShowLast, "show-last", "S", false, console.StdoutStyles().FlagDesc.Render(help["show-last"]))
35+
flags.StringVarP(&cfg.Continue, "continue", "c", "", console.StdoutStyles().FlagDesc.Render(help["continue"]))
36+
flags.BoolVarP(&cfg.ContinueLast, "continue-last", "C", false, console.StdoutStyles().FlagDesc.Render(help["continue-last"]))
37+
flags.StringVarP(&cfg.Title, "title", "T", cfg.Title, console.StdoutStyles().FlagDesc.Render(help["title"]))
3438
//flags.StringVarP(&cfg.Role, "role", "R", cfg.Role, console.StdoutStyles().FlagDesc.Render(help["role"]))
3539
//flags.BoolVar(&cfg.ListRoles, "list-roles", cfg.ListRoles, console.StdoutStyles().FlagDesc.Render(help["list-roles"]))
3640
//flags.StringVar(&cfg.Theme, "theme", "charm", console.StdoutStyles().FlagDesc.Render(help["theme"]))

0 commit comments

Comments
 (0)