diff --git a/pkg/web/web.go b/pkg/web/web.go index c5571c694b..28bcccd5a3 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -5,6 +5,7 @@ package web import ( "bytes" + "context" "encoding/base64" "encoding/json" "fmt" @@ -254,7 +255,7 @@ func handleRemoteStreamFile(w http.ResponseWriter, req *http.Request, conn strin return handleRemoteStreamFileFromCh(w, req, path, rtnCh, rpcOpts.StreamCancelFn, no404) } -func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path string, rtnCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], streamCancelFn func(), no404 bool) error { +func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path string, rtnCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], streamCancelFn func(context.Context) error, no404 bool) error { firstPk := true var fileInfo *wshrpc.FileInfo loopDone := false @@ -270,7 +271,9 @@ func handleRemoteStreamFileFromCh(w http.ResponseWriter, req *http.Request, path select { case <-ctx.Done(): if streamCancelFn != nil { - streamCancelFn() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + streamCancelFn(ctx) } return ctx.Err() case respUnion, ok := <-rtnCh: diff --git a/pkg/wshrpc/wshclient/wshclientutil.go b/pkg/wshrpc/wshclient/wshclientutil.go index 327466b9ff..52d311c0a9 100644 --- a/pkg/wshrpc/wshclient/wshclientutil.go +++ b/pkg/wshrpc/wshclient/wshclientutil.go @@ -4,6 +4,7 @@ package wshclient import ( + "context" "errors" "github.com/wavetermdev/waveterm/pkg/panichandler" @@ -62,9 +63,8 @@ func sendRpcRequestResponseStreamHelper[T any](w *wshutil.WshRpc, command string rtnErr(respChan, err) return respChan } - opts.StreamCancelFn = func() { - // TODO coordinate the cancel with the for loop below - reqHandler.SendCancel() + opts.StreamCancelFn = func(ctx context.Context) error { + return reqHandler.SendCancel(ctx) } go func() { defer func() { diff --git a/pkg/wshutil/wshadapter.go b/pkg/wshutil/wshadapter.go index f1760381fc..22667dbfe2 100644 --- a/pkg/wshutil/wshadapter.go +++ b/pkg/wshutil/wshadapter.go @@ -96,7 +96,7 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool { } rmethod := findCmdMethod(impl, cmd) if rmethod == nil { - if !handler.NeedsResponse() { + if !handler.NeedsResponse() && cmd != wshrpc.Command_Message { // we also send an out of band message here since this is likely unexpected and will require debugging handler.SendMessage(fmt.Sprintf("command %q method %q not found", handler.GetCommand(), methodDecl.MethodName)) } diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index acf5632c3b..80c6a038a6 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -169,26 +169,26 @@ func (router *WshRouter) getRouteInfo(rpcId string) *routeInfo { } func (router *WshRouter) handleAnnounceMessage(msg RpcMessage, input msgAndRoute) { - // if we have an upstream, send it there - // if we don't (we are the terminal router), then add it to our announced route map + if msg.Source != input.fromRouteId { + router.Lock.Lock() + router.AnnouncedRoutes[msg.Source] = input.fromRouteId + router.Lock.Unlock() + } upstream := router.GetUpstreamClient() if upstream != nil { upstream.SendRpcMessage(input.msgBytes, "announce-upstream") - return } - if msg.Source == input.fromRouteId { - // not necessary to save the id mapping - return - } - router.Lock.Lock() - defer router.Lock.Unlock() - router.AnnouncedRoutes[msg.Source] = input.fromRouteId } -func (router *WshRouter) handleUnannounceMessage(msg RpcMessage) { +func (router *WshRouter) handleUnannounceMessage(msg RpcMessage, input msgAndRoute) { router.Lock.Lock() - defer router.Lock.Unlock() delete(router.AnnouncedRoutes, msg.Source) + router.Lock.Unlock() + + upstream := router.GetUpstreamClient() + if upstream != nil { + upstream.SendRpcMessage(input.msgBytes, "unannounce-upstream") + } } func (router *WshRouter) getAnnouncedRoute(routeId string) string { @@ -204,21 +204,21 @@ func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string) bool rpc.SendRpcMessage(msgBytes, "route") return true } + localRouteId := router.getAnnouncedRoute(routeId) + if localRouteId != "" { + rpc := router.GetRpc(localRouteId) + if rpc != nil { + rpc.SendRpcMessage(msgBytes, "route-local") + return true + } + } upstream := router.GetUpstreamClient() if upstream != nil { upstream.SendRpcMessage(msgBytes, "route-upstream") return true - } else { - // we are the upstream, so consult our announced routes map - localRouteId := router.getAnnouncedRoute(routeId) - rpc := router.GetRpc(localRouteId) - if rpc == nil { - log.Printf("[router] no rpc for route id %q\n", routeId) - return false - } - rpc.SendRpcMessage(msgBytes, "route-local") - return true } + log.Printf("[router] no rpc for route id %q\n", routeId) + return false } func (router *WshRouter) runServer() { @@ -236,7 +236,7 @@ func (router *WshRouter) runServer() { continue } if msg.Command == wshrpc.Command_RouteUnannounce { - router.handleUnannounceMessage(msg) + router.handleUnannounceMessage(msg, input) continue } if msg.Command != "" { @@ -353,14 +353,22 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, sh func (router *WshRouter) UnregisterRoute(routeId string) { log.Printf("[router] unregistering wsh route %q\n", routeId) router.Lock.Lock() - defer router.Lock.Unlock() delete(router.RouteMap, routeId) // clear out announced routes - for routeId, localRouteId := range router.AnnouncedRoutes { + for announcedRouteId, localRouteId := range router.AnnouncedRoutes { if localRouteId == routeId { - delete(router.AnnouncedRoutes, routeId) + delete(router.AnnouncedRoutes, announcedRouteId) } } + upstream := router.UpstreamClient + router.Lock.Unlock() + + if upstream != nil { + unannounceMsg := RpcMessage{Command: wshrpc.Command_RouteUnannounce, Source: routeId} + unannounceBytes, _ := json.Marshal(unannounceMsg) + upstream.SendRpcMessage(unannounceBytes, "route-unannounce") + } + go func() { defer func() { panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover()) diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 8a90a0790e..ebfca5cc9f 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -313,7 +313,9 @@ func (w *WshRpc) handleRequestInternal(req *RpcMessage, pprofCtx context.Context } respHandler.contextCancelFn.Store(&cancelFn) respHandler.ctx = withRespHandler(ctx, respHandler) - w.registerResponseHandler(req.ReqId, respHandler) + if req.ReqId != "" { + w.registerResponseHandler(req.ReqId, respHandler) + } isAsync := false defer func() { panicErr := panichandler.PanicHandler("handleRequest", recover()) @@ -502,7 +504,7 @@ func (handler *RpcRequestHandler) Context() context.Context { return handler.ctx } -func (handler *RpcRequestHandler) SendCancel() { +func (handler *RpcRequestHandler) SendCancel(ctx context.Context) error { defer func() { panichandler.PanicHandler("SendCancel", recover()) }() @@ -512,8 +514,14 @@ func (handler *RpcRequestHandler) SendCancel() { AuthToken: handler.w.GetAuthToken(), } barr, _ := json.Marshal(msg) // will never fail - handler.w.OutputCh <- barr - handler.finalize() + select { + case handler.w.OutputCh <- barr: + handler.finalize() + return nil + case <-ctx.Done(): + handler.finalize() + return fmt.Errorf("timeout sending cancel") + } } func (handler *RpcRequestHandler) ResponseDone() bool { @@ -607,24 +615,28 @@ func (handler *RpcResponseHandler) SendMessage(msg string) { Message: msg, }, AuthToken: handler.w.GetAuthToken(), + Route: handler.source, // send back to source } msgBytes, _ := json.Marshal(rpcMsg) // will never fail - handler.w.OutputCh <- msgBytes + select { + case handler.w.OutputCh <- msgBytes: + case <-handler.ctx.Done(): + } } func (handler *RpcResponseHandler) SendResponse(data any, done bool) error { defer func() { panichandler.PanicHandler("SendResponse", recover()) }() - if handler.reqId == "" { - return nil // no response expected - } if handler.done.Load() { return fmt.Errorf("request already done, cannot send additional response") } if done { defer handler.close() } + if handler.reqId == "" { + return nil + } msg := &RpcMessage{ ResId: handler.reqId, Data: data, @@ -635,25 +647,35 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error { if err != nil { return err } - handler.w.OutputCh <- barr - return nil + select { + case handler.w.OutputCh <- barr: + return nil + case <-handler.ctx.Done(): + return fmt.Errorf("timeout sending response") + } } func (handler *RpcResponseHandler) SendResponseError(err error) { defer func() { panichandler.PanicHandler("SendResponseError", recover()) }() - if handler.reqId == "" || handler.done.Load() { + if handler.done.Load() { return } defer handler.close() + if handler.reqId == "" { + return + } msg := &RpcMessage{ ResId: handler.reqId, Error: err.Error(), AuthToken: handler.w.GetAuthToken(), } barr, _ := json.Marshal(msg) // will never fail - handler.w.OutputCh <- barr + select { + case handler.w.OutputCh <- barr: + case <-handler.ctx.Done(): + } } func (handler *RpcResponseHandler) IsCanceled() bool { @@ -675,11 +697,11 @@ func (handler *RpcResponseHandler) Finalize() { if handler.reqId != "" { handler.w.unregisterResponseHandler(handler.reqId) } - if handler.reqId == "" || handler.done.Load() { + if handler.done.Load() { return } + // SendResponse with done=true will call close() via defer, even when reqId is empty handler.SendResponse(nil, true) - handler.close() } func (handler *RpcResponseHandler) IsDone() bool { @@ -726,8 +748,13 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp return nil, err } handler.respCh = w.registerRpc(handler, command, opts.Route, handler.reqId) - w.OutputCh <- barr - return handler, nil + select { + case w.OutputCh <- barr: + return handler, nil + case <-handler.ctx.Done(): + handler.finalize() + return nil, fmt.Errorf("timeout sending request") + } } func (w *WshRpc) IsServerDone() bool {