diff --git a/.vscode/settings.json b/.vscode/settings.json index 25c1ea49f3..85f4c06cd6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -55,7 +55,6 @@ "files.associations": { "*.css": "tailwindcss" }, - "go.lintTool": "staticcheck", "gopls": { "analyses": { "QF1003": false diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index 678ea77cc5..0726ea066e 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -191,14 +191,10 @@ func serverRunRouter(jwtToken string) error { func checkForUpdate() error { remoteInfo := wshutil.GetInfo() - needsRestartRaw, err := RpcClient.SendRpcRequest(wshrpc.Command_ConnUpdateWsh, remoteInfo, &wshrpc.RpcOpts{Timeout: 60000}) + needsRestart, err := wshclient.ConnUpdateWshCommand(RpcClient, remoteInfo, &wshrpc.RpcOpts{Timeout: 60000}) if err != nil { return fmt.Errorf("could not update: %w", err) } - needsRestart, ok := needsRestartRaw.(bool) - if !ok { - return fmt.Errorf("wrong return type from update") - } if needsRestart { // run the restart command here // how to get the correct path? diff --git a/cmd/wsh/cmd/wshcmd-deleteblock.go b/cmd/wsh/cmd/wshcmd-deleteblock.go index 6ff817dfcf..76518e721c 100644 --- a/cmd/wsh/cmd/wshcmd-deleteblock.go +++ b/cmd/wsh/cmd/wshcmd-deleteblock.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" ) var deleteBlockCmd = &cobra.Command{ @@ -35,7 +36,7 @@ func deleteBlockRun(cmd *cobra.Command, args []string) (rtnErr error) { deleteBlockData := &wshrpc.CommandDeleteBlockData{ BlockId: fullORef.OID, } - _, err = RpcClient.SendRpcRequest(wshrpc.Command_DeleteBlock, deleteBlockData, &wshrpc.RpcOpts{Timeout: 2000}) + err = wshclient.DeleteBlockCommand(RpcClient, *deleteBlockData, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return fmt.Errorf("delete block failed: %v", err) } diff --git a/cmd/wsh/cmd/wshcmd-editconfig.go b/cmd/wsh/cmd/wshcmd-editconfig.go index 5f2153dd77..6dc9c13f6a 100644 --- a/cmd/wsh/cmd/wshcmd-editconfig.go +++ b/cmd/wsh/cmd/wshcmd-editconfig.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" ) var editConfigMagnified bool @@ -48,7 +49,7 @@ func editConfigRun(cmd *cobra.Command, args []string) (rtnErr error) { Focused: true, } - _, err := RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) + _, err := wshclient.CreateBlockCommand(RpcClient, *wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return fmt.Errorf("opening config file: %w", err) } diff --git a/cmd/wsh/cmd/wshcmd-notify.go b/cmd/wsh/cmd/wshcmd-notify.go index 826e38ba6b..de2086e1f7 100644 --- a/cmd/wsh/cmd/wshcmd-notify.go +++ b/cmd/wsh/cmd/wshcmd-notify.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -38,7 +39,7 @@ func notifyRun(cmd *cobra.Command, args []string) (rtnErr error) { Body: message, Silent: notifySilent, } - _, err := RpcClient.SendRpcRequest(wshrpc.Command_Notify, notificationOptions, &wshrpc.RpcOpts{Timeout: 2000, Route: wshutil.ElectronRoute}) + err := wshclient.NotifyCommand(RpcClient, *notificationOptions, &wshrpc.RpcOpts{Timeout: 2000, Route: wshutil.ElectronRoute}) if err != nil { return fmt.Errorf("sending notification: %w", err) } diff --git a/cmd/wsh/cmd/wshcmd-secret.go b/cmd/wsh/cmd/wshcmd-secret.go index f2c287579a..7d555c0dec 100644 --- a/cmd/wsh/cmd/wshcmd-secret.go +++ b/cmd/wsh/cmd/wshcmd-secret.go @@ -187,7 +187,7 @@ func secretUiRun(cmd *cobra.Command, args []string) (rtnErr error) { Focused: true, } - _, err := RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) + _, err := wshclient.CreateBlockCommand(RpcClient, *wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return fmt.Errorf("opening secrets UI: %w", err) } diff --git a/cmd/wsh/cmd/wshcmd-setmeta.go b/cmd/wsh/cmd/wshcmd-setmeta.go index 13e3a352b7..79faa7e78c 100644 --- a/cmd/wsh/cmd/wshcmd-setmeta.go +++ b/cmd/wsh/cmd/wshcmd-setmeta.go @@ -13,6 +13,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" ) var setMetaCmd = &cobra.Command{ @@ -192,7 +193,7 @@ func setMetaRun(cmd *cobra.Command, args []string) (rtnErr error) { ORef: *fullORef, Meta: fullMeta, } - _, err = RpcClient.SendRpcRequest(wshrpc.Command_SetMeta, setMetaWshCmd, &wshrpc.RpcOpts{Timeout: 2000}) + err = wshclient.SetMetaCommand(RpcClient, *setMetaWshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return fmt.Errorf("setting metadata: %v", err) } diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index ccba3a3d9c..b0aafe148f 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -13,6 +13,7 @@ import ( "github.com/spf13/cobra" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" ) var viewMagnified bool @@ -99,7 +100,7 @@ func viewRun(cmd *cobra.Command, args []string) (rtnErr error) { wshCmd.BlockDef.Meta[waveobj.MetaKey_Connection] = conn } } - _, err := RpcClient.SendRpcRequest(wshrpc.Command_CreateBlock, wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) + _, err := wshclient.CreateBlockCommand(RpcClient, *wshCmd, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return fmt.Errorf("running view command: %w", err) } diff --git a/frontend/app/aipanel/aitooluse.tsx b/frontend/app/aipanel/aitooluse.tsx index 3406e0a5ff..7868c188e9 100644 --- a/frontend/app/aipanel/aitooluse.tsx +++ b/frontend/app/aipanel/aitooluse.tsx @@ -146,26 +146,11 @@ interface AIToolUseBatchProps { const AIToolUseBatch = memo(({ parts, isStreaming }: AIToolUseBatchProps) => { const [userApprovalOverride, setUserApprovalOverride] = useState(null); - const partsRef = useRef(parts); - partsRef.current = parts; - // All parts in a batch have the same approval status (enforced by grouping logic in AIToolUseGroup) const firstTool = parts[0].data; const baseApproval = userApprovalOverride || firstTool.approval; const effectiveApproval = getEffectiveApprovalStatus(baseApproval, isStreaming); - useEffect(() => { - if (!isStreaming || effectiveApproval !== "needs-approval") return; - - const interval = setInterval(() => { - partsRef.current.forEach((part) => { - WaveAIModel.getInstance().toolUseKeepalive(part.data.toolcallid); - }); - }, 4000); - - return () => clearInterval(interval); - }, [isStreaming, effectiveApproval]); - const handleApprove = () => { setUserApprovalOverride("user-approved"); parts.forEach((part) => { @@ -212,8 +197,6 @@ const AIToolUse = memo(({ part, isStreaming }: AIToolUseProps) => { const showRestoreModal = restoreModalToolCallId === toolData.toolcallid; const highlightTimeoutRef = useRef(null); const highlightedBlockIdRef = useRef(null); - const toolCallIdRef = useRef(toolData.toolcallid); - toolCallIdRef.current = toolData.toolcallid; const statusIcon = toolData.status === "completed" ? "✓" : toolData.status === "error" ? "✗" : "•"; const statusColor = @@ -224,16 +207,6 @@ const AIToolUse = memo(({ part, isStreaming }: AIToolUseProps) => { const isFileWriteTool = toolData.toolname === "write_text_file" || toolData.toolname === "edit_text_file"; - useEffect(() => { - if (!isStreaming || effectiveApproval !== "needs-approval") return; - - const interval = setInterval(() => { - WaveAIModel.getInstance().toolUseKeepalive(toolCallIdRef.current); - }, 4000); - - return () => clearInterval(interval); - }, [isStreaming, effectiveApproval]); - useEffect(() => { return () => { if (highlightTimeoutRef.current) { diff --git a/frontend/app/store/global.ts b/frontend/app/store/global.ts index 35eb9f1585..42e2c635fe 100644 --- a/frontend/app/store/global.ts +++ b/frontend/app/store/global.ts @@ -458,6 +458,16 @@ function useBlockDataLoaded(blockId: string): boolean { return useAtomValue(loadedAtom); } +/** + * Safely read an atom value, returning null if the atom is null. + */ +function readAtom(atom: Atom): T { + if (atom == null) { + return null; + } + return globalStore.get(atom); +} + /** * Get the preload api. */ @@ -863,6 +873,7 @@ export { getUserName, globalPrimaryTabStartup, globalStore, + readAtom, initGlobal, initGlobalWaveEventSubs, isDev, diff --git a/frontend/app/view/term/term-model.ts b/frontend/app/view/term/term-model.ts index 4ea57faf8d..bc4d15899c 100644 --- a/frontend/app/view/term/term-model.ts +++ b/frontend/app/view/term/term-model.ts @@ -21,6 +21,8 @@ import { getOverrideConfigAtom, getSettingsKeyAtom, globalStore, + readAtom, + recordTEvent, useBlockAtom, WOS, } from "@/store/global"; @@ -478,6 +480,14 @@ export class TermViewModel implements ViewModel { } keyDownHandler(waveEvent: WaveKeyboardEvent): boolean { + if (keyutil.checkKeyPressed(waveEvent, "Ctrl:r")) { + const shellIntegrationStatus = readAtom(this.termRef?.current?.shellIntegrationStatusAtom); + if (shellIntegrationStatus === "ready") { + recordTEvent("action:term", { "action:type": "term:ctrlr" }); + } + // just for telemetry, we allow this keybinding through, back to the terminal + return false; + } if (keyutil.checkKeyPressed(waveEvent, "Cmd:Escape")) { const blockAtom = WOS.getWaveObjectAtom(`block:${this.blockId}`); const blockData = globalStore.get(blockAtom); diff --git a/frontend/app/view/term/termwrap.ts b/frontend/app/view/term/termwrap.ts index 0ad7fa7495..f34ba68a0a 100644 --- a/frontend/app/view/term/termwrap.ts +++ b/frontend/app/view/term/termwrap.ts @@ -209,6 +209,29 @@ function addTestMarkerDecoration(terminal: Terminal, marker: TermTypes.IMarker, }); } +function checkCommandForTelemetry(decodedCmd: string) { + if (!decodedCmd) { + return; + } + + if (decodedCmd.startsWith("ssh ")) { + recordTEvent("conn:connect", { "conn:conntype": "ssh-manual" }); + return; + } + + const editorsRegex = /^(vim|vi|nano|nvim)\b/; + if (editorsRegex.test(decodedCmd)) { + recordTEvent("action:term", { "action:type": "cli-edit" }); + return; + } + + const tailFollowRegex = /(^|\|\s*)tail\s+-[fF]\b/; + if (tailFollowRegex.test(decodedCmd)) { + recordTEvent("action:term", { "action:type": "cli-tailf" }); + return; + } +} + // OSC 16162 - Shell Integration Commands // See aiprompts/wave-osc-16162.md for full documentation type ShellIntegrationStatus = "ready" | "running-command"; @@ -274,9 +297,7 @@ function handleOsc16162Command(data: string, blockId: string, loaded: boolean, t const decodedCmd = base64ToString(cmd.data.cmd64); rtInfo["shell:lastcmd"] = decodedCmd; globalStore.set(termWrap.lastCommandAtom, decodedCmd); - if (decodedCmd?.startsWith("ssh ")) { - recordTEvent("conn:connect", { "conn:conntype": "ssh-manual" }); - } + checkCommandForTelemetry(decodedCmd); } catch (e) { console.error("Error decoding cmd64:", e); rtInfo["shell:lastcmd"] = null; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 56b9253b39..8a91d4d571 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -529,7 +529,6 @@ declare global { // wshrpc.CommandWaveAIToolApproveData type CommandWaveAIToolApproveData = { toolcallid: string; - keepalive?: boolean; approval?: string; }; @@ -1235,6 +1234,7 @@ declare global { "action:type"?: string; "debug:panictype"?: string; "block:view"?: string; + "block:controller"?: string; "ai:backendtype"?: string; "ai:local"?: boolean; "wsh:cmd"?: string; diff --git a/package-lock.json b/package-lock.json index 9a81ab4bf5..8503c1a213 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "waveterm", - "version": "0.13.1-beta.0", + "version": "0.13.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "waveterm", - "version": "0.13.1-beta.0", + "version": "0.13.1", "hasInstallScript": true, "license": "Apache-2.0", "workspaces": [ diff --git a/pkg/aiusechat/toolapproval.go b/pkg/aiusechat/toolapproval.go index 7c374a15b6..4009c6dd71 100644 --- a/pkg/aiusechat/toolapproval.go +++ b/pkg/aiusechat/toolapproval.go @@ -4,23 +4,37 @@ package aiusechat import ( + "context" "sync" - "time" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" -) - -const ( - InitialApprovalTimeout = 10 * time.Second - KeepAliveExtension = 10 * time.Second + "github.com/wavetermdev/waveterm/pkg/web/sse" ) type ApprovalRequest struct { - approval string - done bool - doneChan chan struct{} - timer *time.Timer - mu sync.Mutex + approval string + done bool + doneChan chan struct{} + mu sync.Mutex + onCloseUnregFn func() +} + +func (req *ApprovalRequest) updateApproval(approval string) { + req.mu.Lock() + defer req.mu.Unlock() + + if req.done { + return + } + + req.approval = approval + req.done = true + + if req.onCloseUnregFn != nil { + req.onCloseUnregFn() + } + + close(req.doneChan) } type ApprovalRegistry struct { @@ -38,6 +52,16 @@ func registerToolApprovalRequest(toolCallId string, req *ApprovalRequest) { globalApprovalRegistry.requests[toolCallId] = req } +func UnregisterToolApproval(toolCallId string) { + globalApprovalRegistry.mu.Lock() + defer globalApprovalRegistry.mu.Unlock() + req := globalApprovalRegistry.requests[toolCallId] + delete(globalApprovalRegistry.requests, toolCallId) + if req != nil { + req.updateApproval("") + } +} + func getToolApprovalRequest(toolCallId string) (*ApprovalRequest, bool) { globalApprovalRegistry.mu.Lock() defer globalApprovalRegistry.mu.Unlock() @@ -45,64 +69,43 @@ func getToolApprovalRequest(toolCallId string) (*ApprovalRequest, bool) { return req, exists } -func RegisterToolApproval(toolCallId string) { +func RegisterToolApproval(toolCallId string, sseHandler *sse.SSEHandlerCh) { req := &ApprovalRequest{ doneChan: make(chan struct{}), } - req.timer = time.AfterFunc(InitialApprovalTimeout, func() { - UpdateToolApproval(toolCallId, uctypes.ApprovalTimeout, false) + onCloseId := sseHandler.RegisterOnClose(func() { + UpdateToolApproval(toolCallId, uctypes.ApprovalCanceled) }) + req.onCloseUnregFn = func() { + sseHandler.UnregisterOnClose(onCloseId) + } + registerToolApprovalRequest(toolCallId, req) } -func UpdateToolApproval(toolCallId string, approval string, keepAlive bool) error { +func UpdateToolApproval(toolCallId string, approval string) error { req, exists := getToolApprovalRequest(toolCallId) if !exists { return nil } - req.mu.Lock() - defer req.mu.Unlock() - - if req.done { - return nil - } - - if keepAlive && approval == "" { - req.timer.Reset(KeepAliveExtension) - return nil - } - - req.approval = approval - req.done = true - - if req.timer != nil { - req.timer.Stop() - } - - close(req.doneChan) + req.updateApproval(approval) return nil } -func CurrentToolApprovalStatus(toolCallId string) string { - req, exists := getToolApprovalRequest(toolCallId) - if !exists { - return "" - } - - req.mu.Lock() - defer req.mu.Unlock() - return req.approval -} -func WaitForToolApproval(toolCallId string) string { +func WaitForToolApproval(ctx context.Context, toolCallId string) (string, error) { req, exists := getToolApprovalRequest(toolCallId) if !exists { - return "" + return "", nil } - <-req.doneChan + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-req.doneChan: + } req.mu.Lock() approval := req.approval @@ -112,5 +115,5 @@ func WaitForToolApproval(toolCallId string) string { delete(globalApprovalRegistry.requests, toolCallId) globalApprovalRegistry.mu.Unlock() - return approval + return approval, nil } diff --git a/pkg/aiusechat/uctypes/uctypes.go b/pkg/aiusechat/uctypes/uctypes.go index f8bdc21691..b857f141bd 100644 --- a/pkg/aiusechat/uctypes/uctypes.go +++ b/pkg/aiusechat/uctypes/uctypes.go @@ -180,6 +180,7 @@ const ( ApprovalUserDenied = "user-denied" ApprovalTimeout = "timeout" ApprovalAutoApproved = "auto-approved" + ApprovalCanceled = "canceled" ) type AIModeConfig struct { @@ -520,7 +521,6 @@ type WaveChatOpts struct { TabStateGenerator func() (string, []ToolDefinition, string, error) BuilderAppGenerator func() (string, string, string, error) WidgetAccess bool - RegisterToolApproval func(string) AllowNativeWebSearch bool BuilderId string BuilderAppId string diff --git a/pkg/aiusechat/usechat.go b/pkg/aiusechat/usechat.go index 9ccf847c8d..08e675c43f 100644 --- a/pkg/aiusechat/usechat.go +++ b/pkg/aiusechat/usechat.go @@ -252,11 +252,12 @@ func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCa if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval { log.Printf(" waiting for approval...\n") - approval := WaitForToolApproval(toolCall.ID) - log.Printf(" approval result: %q\n", approval) - if approval != "" { - toolCall.ToolUseData.Approval = approval + approval, err := WaitForToolApproval(context.Background(), toolCall.ID) + if err != nil || approval == "" { + approval = uctypes.ApprovalCanceled } + log.Printf(" approval result: %q\n", approval) + toolCall.ToolUseData.Approval = approval if !toolCall.ToolUseData.IsApproved() { errorMsg := "Tool use denied or timed out" @@ -264,6 +265,8 @@ func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCa errorMsg = "Tool use denied by user" } else if approval == uctypes.ApprovalTimeout { errorMsg = "Tool approval timed out" + } else if approval == uctypes.ApprovalCanceled { + errorMsg = "Tool approval canceled" } toolCall.ToolUseData.Status = uctypes.ToolUseStatusError toolCall.ToolUseData.ErrorMessage = errorMsg @@ -340,8 +343,8 @@ func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason log.Printf("AI data-tooluse %s\n", toolCall.ID) _ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, toolUseData) updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolUseData) - if toolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil { - chatOpts.RegisterToolApproval(toolCall.ID) + if toolUseData.Approval == uctypes.ApprovalNeedsApproval { + RegisterToolApproval(toolCall.ID, sseHandler) } } // At this point, all ToolCalls are guaranteed to have non-nil ToolUseData @@ -350,6 +353,7 @@ func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason for _, toolCall := range stopReason.ToolCalls { result := processToolCall(backend, toolCall, chatOpts, sseHandler, metrics) toolResults = append(toolResults, result) + UnregisterToolApproval(toolCall.ID) } toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(toolResults) @@ -666,7 +670,6 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) { ClientId: client.OID, Config: *aiOpts, WidgetAccess: req.WidgetAccess, - RegisterToolApproval: RegisterToolApproval, AllowNativeWebSearch: true, BuilderId: req.BuilderId, BuilderAppId: req.BuilderAppId, diff --git a/pkg/telemetry/telemetrydata/telemetrydata.go b/pkg/telemetry/telemetrydata/telemetrydata.go index 6bc1e6ee91..d2bf7deeab 100644 --- a/pkg/telemetry/telemetrydata/telemetrydata.go +++ b/pkg/telemetry/telemetrydata/telemetrydata.go @@ -27,6 +27,7 @@ var ValidEventNames = map[string]bool{ "action:createblock": true, "action:openwaveai": true, "action:other": true, + "action:term": true, "wsh:run": true, diff --git a/pkg/utilds/idlist.go b/pkg/utilds/idlist.go new file mode 100644 index 0000000000..08dec2aac3 --- /dev/null +++ b/pkg/utilds/idlist.go @@ -0,0 +1,64 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package utilds + +import ( + "sync" + + "github.com/google/uuid" +) + +type idListEntry[T any] struct { + id string + val T +} + +type IdList[T any] struct { + lock sync.Mutex + entries []idListEntry[T] +} + +func (il *IdList[T]) Register(val T) string { + il.lock.Lock() + defer il.lock.Unlock() + + id := uuid.New().String() + il.entries = append(il.entries, idListEntry[T]{id: id, val: val}) + return id +} + +func (il *IdList[T]) RegisterWithId(id string, val T) { + il.lock.Lock() + defer il.lock.Unlock() + + il.unregister_nolock(id) + il.entries = append(il.entries, idListEntry[T]{id: id, val: val}) +} + +func (il *IdList[T]) Unregister(id string) { + il.lock.Lock() + defer il.lock.Unlock() + + il.unregister_nolock(id) +} + +func (il *IdList[T]) unregister_nolock(id string) { + for i, entry := range il.entries { + if entry.id == id { + il.entries = append(il.entries[:i], il.entries[i+1:]...) + return + } + } +} + +func (il *IdList[T]) GetList() []T { + il.lock.Lock() + defer il.lock.Unlock() + + result := make([]T, len(il.entries)) + for i, entry := range il.entries { + result[i] = entry.val + } + return result +} \ No newline at end of file diff --git a/pkg/web/sse/ssehandler.go b/pkg/web/sse/ssehandler.go index 70c4706d8e..cdd055fbd7 100644 --- a/pkg/web/sse/ssehandler.go +++ b/pkg/web/sse/ssehandler.go @@ -11,6 +11,8 @@ import ( "strings" "sync" "time" + + "github.com/wavetermdev/waveterm/pkg/utilds" ) // see /aiprompts/usechat-streamingproto.md for protocol @@ -64,16 +66,17 @@ type SSEMessage struct { type SSEHandlerCh struct { w http.ResponseWriter rc *http.ResponseController - ctx context.Context + ctx context.Context // the r.Context() writeCh chan SSEMessage - errCh chan error - mu sync.RWMutex + lock sync.Mutex closed bool initialized bool err error - wg sync.WaitGroup + wg sync.WaitGroup + onCloseHandlers utilds.IdList[func()] + handlersRun bool } // MakeSSEHandlerCh creates a new channel-based SSE handler @@ -83,14 +86,13 @@ func MakeSSEHandlerCh(w http.ResponseWriter, ctx context.Context) *SSEHandlerCh rc: http.NewResponseController(w), ctx: ctx, writeCh: make(chan SSEMessage, 10), // Buffered to prevent blocking - errCh: make(chan error, 1), // Buffered for single error } } // SetupSSE configures the response headers and starts the writer goroutine func (h *SSEHandlerCh) SetupSSE() error { - h.mu.Lock() - defer h.mu.Unlock() + h.lock.Lock() + defer h.lock.Unlock() if h.closed { return fmt.Errorf("SSE handler is closed") @@ -127,6 +129,7 @@ func (h *SSEHandlerCh) SetupSSE() error { // writerLoop handles all writes and keepalives in a single goroutine func (h *SSEHandlerCh) writerLoop() { defer h.wg.Done() + defer h.runOnCloseHandlers() keepaliveTicker := time.NewTicker(SSEKeepaliveInterval) defer keepaliveTicker.Stop() @@ -152,6 +155,7 @@ func (h *SSEHandlerCh) writerLoop() { } case <-h.ctx.Done(): + h.setError(h.ctx.Err()) return } } @@ -159,6 +163,9 @@ func (h *SSEHandlerCh) writerLoop() { // writeMessage writes a message to the SSE stream func (h *SSEHandlerCh) writeMessage(msg SSEMessage) error { + if h.ctx.Err() != nil { + return h.ctx.Err() + } switch msg.Type { case SSEMsgData: return h.writeDirectly(msg.Data, SSEMsgData) @@ -175,8 +182,8 @@ func (h *SSEHandlerCh) writeMessage(msg SSEMessage) error { // isInitialized returns whether SetupSSE has been called func (h *SSEHandlerCh) isInitialized() bool { - h.mu.RLock() - defer h.mu.RUnlock() + h.lock.Lock() + defer h.lock.Unlock() return h.initialized } @@ -225,31 +232,30 @@ func (h *SSEHandlerCh) flush() error { // setError sets the error state thread-safely func (h *SSEHandlerCh) setError(err error) { - h.mu.Lock() - defer h.mu.Unlock() + h.lock.Lock() + defer h.lock.Unlock() if h.err == nil { h.err = err - // Send error to error channel if there's space - select { - case h.errCh <- err: - default: - } } } -// WriteData queues data to be written in SSE format -func (h *SSEHandlerCh) WriteData(data string) error { - h.mu.RLock() +// queueMessage queues an SSEMessage to be written +func (h *SSEHandlerCh) queueMessage(msg SSEMessage) error { + h.lock.Lock() closed := h.closed - h.mu.RUnlock() + h.lock.Unlock() if closed { return fmt.Errorf("SSE handler is closed") } + if err := h.Err(); err != nil { + return err + } + select { - case h.writeCh <- SSEMessage{Type: SSEMsgData, Data: data}: + case h.writeCh <- msg: return nil case <-h.ctx.Done(): return h.ctx.Err() @@ -258,6 +264,11 @@ func (h *SSEHandlerCh) WriteData(data string) error { } } +// WriteData queues data to be written in SSE format +func (h *SSEHandlerCh) WriteData(data string) error { + return h.queueMessage(SSEMessage{Type: SSEMsgData, Data: data}) +} + // WriteJsonData marshals data to JSON and queues it for writing func (h *SSEHandlerCh) WriteJsonData(data interface{}) error { jsonData, err := json.Marshal(data) @@ -282,63 +293,67 @@ func (h *SSEHandlerCh) WriteError(errorMsg string) error { // WriteEvent queues an SSE event with optional event type func (h *SSEHandlerCh) WriteEvent(eventType, data string) error { - h.mu.RLock() - closed := h.closed - h.mu.RUnlock() - - if closed { - return fmt.Errorf("SSE handler is closed") - } - - select { - case h.writeCh <- SSEMessage{Type: SSEMsgEvent, Data: data, EventType: eventType}: - return nil - case <-h.ctx.Done(): - return h.ctx.Err() - default: - return fmt.Errorf("write channel is full") - } + return h.queueMessage(SSEMessage{Type: SSEMsgEvent, Data: data, EventType: eventType}) } // WriteComment queues an SSE comment func (h *SSEHandlerCh) WriteComment(comment string) error { - h.mu.RLock() - closed := h.closed - h.mu.RUnlock() - - if closed { - return fmt.Errorf("SSE handler is closed") - } - - select { - case h.writeCh <- SSEMessage{Type: SSEMsgComment, Data: comment}: - return nil - case <-h.ctx.Done(): - return h.ctx.Err() - default: - return fmt.Errorf("write channel is full") - } + return h.queueMessage(SSEMessage{Type: SSEMsgComment, Data: comment}) } // Err returns any error that occurred during writing func (h *SSEHandlerCh) Err() error { - h.mu.RLock() - defer h.mu.RUnlock() + h.lock.Lock() + defer h.lock.Unlock() + if h.err == nil && h.ctx.Err() != nil { + h.err = h.ctx.Err() + } return h.err } +// RegisterOnClose registers a handler function to be called when the connection closes +// Returns an ID that can be used to unregister the handler +func (h *SSEHandlerCh) RegisterOnClose(fn func()) string { + h.lock.Lock() + defer h.lock.Unlock() + return h.onCloseHandlers.Register(fn) +} + +// UnregisterOnClose removes a previously registered onClose handler by ID +func (h *SSEHandlerCh) UnregisterOnClose(id string) { + h.lock.Lock() + defer h.lock.Unlock() + h.onCloseHandlers.Unregister(id) +} + +// runOnCloseHandlers runs all registered onClose handlers exactly once +func (h *SSEHandlerCh) runOnCloseHandlers() { + h.lock.Lock() + if h.handlersRun { + h.lock.Unlock() + return + } + h.handlersRun = true + h.lock.Unlock() + + handlers := h.onCloseHandlers.GetList() + for _, fn := range handlers { + fn() + } +} + // Close closes the write channel, sends [DONE], and cleans up resources func (h *SSEHandlerCh) Close() { - h.mu.Lock() + h.lock.Lock() if h.closed || !h.initialized { - h.mu.Unlock() + h.lock.Unlock() return } h.closed = true // Close the write channel, which will trigger [DONE] in writerLoop close(h.writeCh) - h.mu.Unlock() + h.lock.Unlock() // Wait for writer goroutine to finish (without holding the lock) h.wg.Wait() @@ -461,7 +476,6 @@ func (h *SSEHandlerCh) AiMsgError(errText string) error { return h.WriteJsonData(resp) } - func (h *SSEHandlerCh) AiMsgData(dataType string, id string, data interface{}) error { if !strings.HasPrefix(dataType, "data-") { panic(fmt.Sprintf("AiMsgData type must start with 'data-', got: %s", dataType)) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 8c1bc0ddbf..14ffb14779 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -58,6 +58,7 @@ const ( Command_RouteAnnounce = "routeannounce" // special (for routing) Command_RouteUnannounce = "routeunannounce" // special (for routing) Command_Message = "message" + Command_GetMeta = "getmeta" Command_SetMeta = "setmeta" Command_SetView = "setview" @@ -817,7 +818,6 @@ type CommandGetWaveAIChatData struct { type CommandWaveAIToolApproveData struct { ToolCallId string `json:"toolcallid"` - KeepAlive bool `json:"keepalive,omitempty"` Approval string `json:"approval,omitempty"` } diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 80a6b21fe3..cc4fc1d7ea 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -1257,7 +1257,7 @@ func (ws *WshServer) GetWaveAIRateLimitCommand(ctx context.Context) (*uctypes.Ra } func (ws *WshServer) WaveAIToolApproveCommand(ctx context.Context, data wshrpc.CommandWaveAIToolApproveData) error { - return aiusechat.UpdateToolApproval(data.ToolCallId, data.Approval, data.KeepAlive) + return aiusechat.UpdateToolApproval(data.ToolCallId, data.Approval) } func (ws *WshServer) WaveAIGetToolDiffCommand(ctx context.Context, data wshrpc.CommandWaveAIGetToolDiffData) (*wshrpc.CommandWaveAIGetToolDiffRtnData, error) {