diff --git a/main.go b/main.go index 6c91776..987c3ca 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/list" @@ -83,37 +84,26 @@ var sseCmd = &cobra.Command{ Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { url := args[0] - if verbose { - log.Printf("URL: %s", url) - } - headerStrings, _ := cmd.Flags().GetStringSlice("header") - var httpClient *http.Client - if len(headerStrings) > 0 { - headers := parseHeaders(headerStrings) - httpClient = &http.Client{ - Transport: &headerTransport{ - base: http.DefaultTransport, - headers: headers, - }, - } - } - ctx := context.Background() - client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil) - transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient} - session, err := client.Connect(ctx, transport, nil) - if err != nil { - log.Fatalf("Failed to connect to SSE server: %v", err) - } - defer session.Close() - - if verbose { - log.Println("Connected to SSE server") + connect := func() (*mcp.ClientSession, error) { + var httpClient *http.Client + if len(headerStrings) > 0 { + headers := parseHeaders(headerStrings) + httpClient = &http.Client{ + Transport: &headerTransport{ + base: http.DefaultTransport, + headers: headers, + }, + } + } + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil) + transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient} + return client.Connect(ctx, transport, nil) } - handleSession(ctx, session) + runSessionWithReconnect(ctx, connect) }, } @@ -123,37 +113,26 @@ var httpCmd = &cobra.Command{ Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { url := args[0] - if verbose { - log.Printf("URL: %s", url) - } - headerStrings, _ := cmd.Flags().GetStringSlice("header") - var httpClient *http.Client - if len(headerStrings) > 0 { - headers := parseHeaders(headerStrings) - httpClient = &http.Client{ - Transport: &headerTransport{ - base: http.DefaultTransport, - headers: headers, - }, - } - } - ctx := context.Background() - client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil) - transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient} - session, err := client.Connect(ctx, transport, nil) - if err != nil { - log.Fatalf("Failed to connect to streamable HTTP server: %v", err) - } - defer session.Close() - - if verbose { - log.Println("Connected to streamable HTTP server") + connect := func() (*mcp.ClientSession, error) { + var httpClient *http.Client + if len(headerStrings) > 0 { + headers := parseHeaders(headerStrings) + httpClient = &http.Client{ + Transport: &headerTransport{ + base: http.DefaultTransport, + headers: headers, + }, + } + } + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil) + transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient} + return client.Connect(ctx, transport, nil) } - handleSession(ctx, session) + runSessionWithReconnect(ctx, connect) }, } @@ -184,6 +163,31 @@ func parseHeaders(headerStrings []string) http.Header { return headers } +type connectFn func() (*mcp.ClientSession, error) + +func runSessionWithReconnect(ctx context.Context, connect connectFn) { + for { + log.Println("Attempting to connect to server...") + session, err := connect() + if err != nil { + log.Printf("Failed to connect: %v. Retrying in 5 seconds...", err) + time.Sleep(5 * time.Second) + continue + } + + log.Println("Connected to server.") + err = handleSession(ctx, session) + session.Close() + + if err != nil { + log.Printf("Session ended with error: %v. Reconnecting...", err) + } else { + log.Println("Session closed cleanly. Exiting.") + break + } + } +} + // -- Bubble Tea TUI ----------------------------------------------------------- type viewState int @@ -384,7 +388,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case toolResult: if msg.err != nil { m.err = msg.err - return m, nil + return m, tea.Quit } if verbose { m.logf("Tool result received") @@ -396,7 +400,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case resourceResult: if msg.err != nil { m.err = msg.err - return m, nil + return m, tea.Quit } if verbose { m.logf("Resource result received") @@ -778,7 +782,7 @@ func (m *AppModel) readResourceCmd() tea.Cmd { } } -func handleSession(ctx context.Context, session *mcp.ClientSession) { +func handleSession(ctx context.Context, session *mcp.ClientSession) error { if verbose { f, err := tea.LogToFile("debug.log", "debug") if err != nil { @@ -787,10 +791,18 @@ func handleSession(ctx context.Context, session *mcp.ClientSession) { } defer f.Close() } - p := tea.NewProgram(initialModel(ctx, session), tea.WithAltScreen(), tea.WithMouseCellMotion()) - if _, err := p.Run(); err != nil { - log.Fatalf("Error running program: %v", err) + model := initialModel(ctx, session) + p := tea.NewProgram(model, tea.WithAltScreen(), tea.WithMouseCellMotion()) + finalModel, err := p.Run() + if err != nil { + return fmt.Errorf("error running program: %w", err) } + + appModel, ok := finalModel.(*AppModel) + if !ok { + return fmt.Errorf("unexpected model type: %T", finalModel) + + return appModel.err } func main() {