Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 70 additions & 58 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sort"
"strconv"
"strings"
"time"

"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
Expand Down Expand Up @@ -83,37 +84,26 @@ var sseCmd = &cobra.Command{
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
url := args[0]
if verbose {
log.Printf("URL: %s", url)
}

headerStrings, _ := cmd.Flags().GetStringSlice("header")
var httpClient *http.Client
if len(headerStrings) > 0 {
headers := parseHeaders(headerStrings)
httpClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: headers,
},
}
}

ctx := context.Background()
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)

transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
log.Fatalf("Failed to connect to SSE server: %v", err)
}
defer session.Close()

if verbose {
log.Println("Connected to SSE server")
connect := func() (*mcp.ClientSession, error) {
var httpClient *http.Client
if len(headerStrings) > 0 {
headers := parseHeaders(headerStrings)
httpClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: headers,
},
}
}
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
transport := &mcp.SSEClientTransport{Endpoint: url, HTTPClient: httpClient}
return client.Connect(ctx, transport, nil)
}

handleSession(ctx, session)
runSessionWithReconnect(ctx, connect)
},
}

Expand All @@ -123,37 +113,26 @@ var httpCmd = &cobra.Command{
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
url := args[0]
if verbose {
log.Printf("URL: %s", url)
}

headerStrings, _ := cmd.Flags().GetStringSlice("header")
var httpClient *http.Client
if len(headerStrings) > 0 {
headers := parseHeaders(headerStrings)
httpClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: headers,
},
}
}

ctx := context.Background()
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)

transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient}
session, err := client.Connect(ctx, transport, nil)
if err != nil {
log.Fatalf("Failed to connect to streamable HTTP server: %v", err)
}
defer session.Close()

if verbose {
log.Println("Connected to streamable HTTP server")
connect := func() (*mcp.ClientSession, error) {
var httpClient *http.Client
if len(headerStrings) > 0 {
headers := parseHeaders(headerStrings)
httpClient = &http.Client{
Transport: &headerTransport{
base: http.DefaultTransport,
headers: headers,
},
}
}
client := mcp.NewClient(&mcp.Implementation{Name: "mcp-cli", Version: "v0.1.0"}, nil)
transport := &mcp.StreamableClientTransport{Endpoint: url, HTTPClient: httpClient}
return client.Connect(ctx, transport, nil)
}

handleSession(ctx, session)
runSessionWithReconnect(ctx, connect)
},
}

Expand Down Expand Up @@ -184,6 +163,31 @@ func parseHeaders(headerStrings []string) http.Header {
return headers
}

type connectFn func() (*mcp.ClientSession, error)

func runSessionWithReconnect(ctx context.Context, connect connectFn) {
for {
log.Println("Attempting to connect to server...")
session, err := connect()
if err != nil {
log.Printf("Failed to connect: %v. Retrying in 5 seconds...", err)
time.Sleep(5 * time.Second)
continue
}

log.Println("Connected to server.")
err = handleSession(ctx, session)
session.Close()

if err != nil {
log.Printf("Session ended with error: %v. Reconnecting...", err)
} else {
log.Println("Session closed cleanly. Exiting.")
break
}
}
}

// -- Bubble Tea TUI -----------------------------------------------------------

type viewState int
Expand Down Expand Up @@ -384,7 +388,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case toolResult:
if msg.err != nil {
m.err = msg.err
return m, nil
return m, tea.Quit
}
if verbose {
m.logf("Tool result received")
Expand All @@ -396,7 +400,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case resourceResult:
if msg.err != nil {
m.err = msg.err
return m, nil
return m, tea.Quit
}
if verbose {
m.logf("Resource result received")
Expand Down Expand Up @@ -778,7 +782,7 @@ func (m *AppModel) readResourceCmd() tea.Cmd {
}
}

func handleSession(ctx context.Context, session *mcp.ClientSession) {
func handleSession(ctx context.Context, session *mcp.ClientSession) error {
if verbose {
f, err := tea.LogToFile("debug.log", "debug")
if err != nil {
Expand All @@ -787,10 +791,18 @@ func handleSession(ctx context.Context, session *mcp.ClientSession) {
}
defer f.Close()
}
p := tea.NewProgram(initialModel(ctx, session), tea.WithAltScreen(), tea.WithMouseCellMotion())
if _, err := p.Run(); err != nil {
log.Fatalf("Error running program: %v", err)
model := initialModel(ctx, session)
p := tea.NewProgram(model, tea.WithAltScreen(), tea.WithMouseCellMotion())
finalModel, err := p.Run()
if err != nil {
return fmt.Errorf("error running program: %w", err)
}

appModel, ok := finalModel.(*AppModel)
if !ok {
return fmt.Errorf("unexpected model type: %T", finalModel)

return appModel.err
}

func main() {
Expand Down
Loading