diff --git a/experimental/apps-mcp/cmd/install.go b/experimental/apps-mcp/cmd/install.go index b75e69224b..8ff4528eb5 100644 --- a/experimental/apps-mcp/cmd/install.go +++ b/experimental/apps-mcp/cmd/install.go @@ -2,13 +2,24 @@ package mcp import ( "context" + "errors" "fmt" "os" + "slices" + "strings" "time" "github.com/databricks/cli/experimental/apps-mcp/lib/agents" + "github.com/databricks/cli/experimental/apps-mcp/lib/middlewares" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/service/sql" "github.com/fatih/color" + "github.com/manifoldco/promptui" "github.com/spf13/cobra" ) @@ -18,14 +29,20 @@ func newInstallCmd() *cobra.Command { Short: "Install the Apps MCP server in coding agents", Long: `Install the Databricks Apps MCP server in coding agents like Claude Code and Cursor.`, RunE: func(cmd *cobra.Command, args []string) error { - return runInstall(cmd.Context()) + return runInstall(cmd) }, } + cmd.Flags().StringP("profile", "p", "", "~/.databrickscfg profile") + cmd.RegisterFlagCompletionFunc("profile", profile.ProfileCompletion) + cmd.Flags().StringP("warehouse-id", "w", "", "Databricks SQL warehouse ID") + cmd.Flags().StringSliceP("agent", "a", []string{}, "Agents to install the MCP server for (valid values: claude, cursor)") + return cmd } -func runInstall(ctx context.Context) error { +func runInstall(cmd *cobra.Command) error { + ctx := cmd.Context() cmdio.LogString(ctx, "") green := color.New(color.FgGreen).SprintFunc() cmdio.LogString(ctx, " "+green("[")+"████████"+green("]")+" Databricks Experimental Apps MCP") @@ -39,18 +56,57 @@ func runInstall(ctx context.Context) error { cmdio.LogString(ctx, yellow("╚════════════════════════════════════════════════════════════════╝")) cmdio.LogString(ctx, "") - cmdio.LogString(ctx, "Which coding agents would you like to install the MCP server for?") + // Check for profile configuration + selectedProfile, err := selectProfile(cmd) + if err != nil { + return err + } + cmdio.LogString(ctx, "") + cmdio.LogString(ctx, fmt.Sprintf("Using profile: %s (%s)", color.CyanString(selectedProfile.Name), selectedProfile.Host)) - anySuccess := false + warehouse, err := selectAndValidateWarehouse(ctx, cmd.Flag("warehouse-id").Value.String(), selectedProfile) + if err != nil { + return err + } + cmdio.LogString(ctx, fmt.Sprintf("Using warehouse: %s (%s)", color.CyanString(warehouse.Name), warehouse.Id)) + cmdio.LogString(ctx, "") - ans, err := cmdio.AskSelect(ctx, "Install for Claude Code?", []string{"yes", "no"}) + // Check if --agent flag is set + requestedAgents, err := cmd.Flags().GetStringSlice("agent") if err != nil { return err } - if ans == "yes" { + + // Normalize and validate agent names + for i, agent := range requestedAgents { + agent = strings.TrimSpace(strings.ToLower(agent)) + requestedAgents[i] = agent + if agent != "" && agent != "claude" && agent != "cursor" { + return fmt.Errorf("invalid agent %q. Valid agents are: claude, cursor", agent) + } + } + + anySuccess := false + + // Install for Claude Code + installClaude := false + if len(requestedAgents) > 0 { + installClaude = slices.Contains(requestedAgents, "claude") + } else { + // Prompt the user + cmdio.LogString(ctx, "Which coding agents would you like to install the MCP server for?") + cmdio.LogString(ctx, "") + ans, err := cmdio.AskSelect(ctx, "Install for Claude Code?", []string{"yes", "no"}) + if err != nil { + return err + } + installClaude = ans == "yes" + } + + if installClaude { fmt.Fprint(os.Stderr, "Installing MCP server for Claude Code...") - if err := agents.InstallClaude(); err != nil { + if err := agents.InstallClaude(selectedProfile, warehouse.Id); err != nil { fmt.Fprint(os.Stderr, "\r"+color.YellowString("⊘ Skipped Claude Code: "+err.Error())+"\n") } else { fmt.Fprint(os.Stderr, "\r"+color.GreenString("✓ Installed for Claude Code")+" \n") @@ -59,13 +115,22 @@ func runInstall(ctx context.Context) error { cmdio.LogString(ctx, "") } - ans, err = cmdio.AskSelect(ctx, "Install for Cursor?", []string{"yes", "no"}) - if err != nil { - return err + // Install for Cursor + installCursor := false + if len(requestedAgents) > 0 { + installCursor = slices.Contains(requestedAgents, "cursor") + } else { + // Prompt the user + ans, err := cmdio.AskSelect(ctx, "Install for Cursor?", []string{"yes", "no"}) + if err != nil { + return err + } + installCursor = ans == "yes" } - if ans == "yes" { + + if installCursor { fmt.Fprint(os.Stderr, "Installing MCP server for Cursor...") - if err := agents.InstallCursor(); err != nil { + if err := agents.InstallCursor(selectedProfile, warehouse.Id); err != nil { fmt.Fprint(os.Stderr, "\r"+color.YellowString("⊘ Skipped Cursor: "+err.Error())+"\n") } else { // Brief delay so users see the "Installing..." message before it's replaced @@ -76,14 +141,17 @@ func runInstall(ctx context.Context) error { cmdio.LogString(ctx, "") } - ans, err = cmdio.AskSelect(ctx, "Show manual installation instructions for other agents?", []string{"yes", "no"}) - if err != nil { - return err - } - if ans == "yes" { - if err := agents.ShowCustomInstructions(ctx); err != nil { + // Only show custom instructions if no agents were specified or installed + if len(requestedAgents) == 0 { + ans, err := cmdio.AskSelect(ctx, "Show manual installation instructions for other agents?", []string{"yes", "no"}) + if err != nil { return err } + if ans == "yes" { + if err := agents.ShowCustomInstructions(ctx, selectedProfile, warehouse.Id); err != nil { + return err + } + } } if anySuccess { @@ -95,3 +163,167 @@ func runInstall(ctx context.Context) error { return nil } + +func selectAndValidateWarehouse(ctx context.Context, warehouseIdFlag string, selectedProfile *profile.Profile) (*sql.EndpointInfo, error) { + w, err := databricks.NewWorkspaceClient(&databricks.Config{ + Profile: selectedProfile.Name, + }) + if err != nil { + return nil, err + } + + var warehouse *sql.EndpointInfo + if warehouseIdFlag != "" { + warehouseResponse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{ + Id: warehouseIdFlag, + }) + if err != nil { + return nil, fmt.Errorf("get warehouse: %w", err) + } + warehouse = &sql.EndpointInfo{ + Id: warehouseResponse.Id, + Name: warehouseResponse.Name, + State: warehouseResponse.State, + } + } else { + // Auto-detect warehouse + + clientCfg, err := config.HTTPClientConfigFromConfig(w.Config) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client config: %w", err) + } + apiClient := httpclient.NewApiClient(clientCfg) + warehouse, err = middlewares.GetDefaultWarehouse(ctx, apiClient) + if err != nil { + return nil, err + } + } + + if warehouse == nil { + return nil, errors.New("no warehouse found") + } + + // Validate warehouse connection with a simple query + _, err = w.StatementExecution.ExecuteAndWait(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouse.Id, + Statement: "SELECT 1", + WaitTimeout: "30s", + }) + if err != nil { + return nil, fmt.Errorf("failed to validate warehouse connection: %w", err) + } + + return warehouse, nil +} + +// selectProfile checks if a profile is available and prompts the user to select one if needed. +func selectProfile(cmd *cobra.Command) (*profile.Profile, error) { + ctx := cmd.Context() + profiler := profile.GetProfiler(ctx) + + // Load all workspace profiles + profiles, err := profiler.LoadProfiles(ctx, profile.MatchWorkspaceProfiles) + if err != nil { + return nil, fmt.Errorf("failed to load profiles: %w", err) + } + + // If no profiles are available, ask the user to login + if len(profiles) == 0 { + cmdio.LogString(ctx, color.RedString("No Databricks profiles found.")) + cmdio.LogString(ctx, "") + cmdio.LogString(ctx, "To authenticate, please run:") + cmdio.LogString(ctx, " "+color.YellowString("databricks auth login --host ")) + cmdio.LogString(ctx, "") + cmdio.LogString(ctx, "Then run this command again.") + return nil, errors.New("no profiles configured") + } + + // Check if --profile flag is set + profileFlag := cmd.Flag("profile") + if profileFlag != nil && profileFlag.Value.String() != "" { + requestedProfile := profileFlag.Value.String() + + // Find the requested profile + var found *profile.Profile + for i := range profiles { + if profiles[i].Name == requestedProfile { + found = &profiles[i] + break + } + } + + if found == nil { + return nil, fmt.Errorf("profile %q not found in ~/.databrickscfg. Run `databricks auth login -p %s` to create this profile and then run this command again", requestedProfile, requestedProfile) + } + + return found, nil + } + + // Get the current profile name from environment variable + currentProfileName := env.Get(ctx, "DATABRICKS_CONFIG_PROFILE") + if currentProfileName == "" { + currentProfileName = "DEFAULT" + } + + // Find the current profile in the list + var currentProfile *profile.Profile + for i := range profiles { + if profiles[i].Name == currentProfileName { + currentProfile = &profiles[i] + break + } + } + + // If a profile is already selected, show it and ask if they want to use it + if currentProfile != nil { + cmdio.LogString(ctx, "Current Databricks profile:") + cmdio.LogString(ctx, " Name: "+color.CyanString(currentProfile.Name)) + cmdio.LogString(ctx, " Host: "+color.CyanString(currentProfile.Host)) + cmdio.LogString(ctx, "") + + ans, err := cmdio.AskSelect(ctx, "Use this profile?", []string{"yes", "no"}) + if err != nil { + return nil, err + } + + if ans == "yes" { + return currentProfile, nil + } + } + + // User wants to select a different profile, or no current profile set + // Show all available profiles for selection + if len(profiles) == 1 { + // Only one profile available, use it + selectedProfile := profiles[0] + cmdio.LogString(ctx, fmt.Sprintf("Using profile: %s (%s)", color.CyanString(selectedProfile.Name), selectedProfile.Host)) + cmdio.LogString(ctx, "") + cmdio.LogString(ctx, "Set this profile by running:") + cmdio.LogString(ctx, " "+color.YellowString("export DATABRICKS_CONFIG_PROFILE="+selectedProfile.Name)) + return &selectedProfile, nil + } + + cmdio.LogString(ctx, "Which Databricks profile would you like to use with the MCP server?") + cmdio.LogString(ctx, "(You can change the profile later by running this install command again)") + cmdio.LogString(ctx, "") + + // Multiple profiles available, let the user select + i, _, err := cmdio.RunSelect(ctx, &promptui.Select{ + Label: "Select a Databricks profile", + Items: profiles, + Searcher: profiles.SearchCaseInsensitive, + StartInSearchMode: true, + Templates: &promptui.SelectTemplates{ + Label: "{{ . | faint }}", + Active: `{{.Name | bold}} ({{.Host|faint}})`, + Inactive: `{{.Name}} ({{.Host}})`, + Selected: `{{ "Selected profile" | faint }}: {{ .Name | bold }}`, + }, + }) + if err != nil { + return nil, err + } + + selectedProfile := profiles[i] + return &selectedProfile, nil +} diff --git a/experimental/apps-mcp/lib/agents/claude.go b/experimental/apps-mcp/lib/agents/claude.go index 83492800e2..1138075947 100644 --- a/experimental/apps-mcp/lib/agents/claude.go +++ b/experimental/apps-mcp/lib/agents/claude.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "os/exec" + + "github.com/databricks/cli/libs/databrickscfg/profile" ) // DetectClaude checks if Claude Code CLI is installed and available on PATH. @@ -14,7 +16,7 @@ func DetectClaude() bool { } // InstallClaude installs the Databricks MCP server in Claude Code. -func InstallClaude() error { +func InstallClaude(profile *profile.Profile, warehouseID string) error { if !DetectClaude() { return errors.New("claude Code CLI is not installed or not on PATH\n\nPlease install Claude Code and ensure 'claude' is available on your system PATH.\nFor installation instructions, visit: https://docs.anthropic.com/en/docs/claude-code") } @@ -27,12 +29,19 @@ func InstallClaude() error { removeCmd := exec.Command("claude", "mcp", "remove", "--scope", "user", "databricks-mcp") _ = removeCmd.Run() - cmd := exec.Command("claude", "mcp", "add", + args := []string{ + "mcp", "add", "--scope", "user", "--transport", "stdio", "databricks-mcp", - "--", - databricksPath, "experimental", "apps-mcp") + "--env", "DATABRICKS_CONFIG_PROFILE=" + profile.Name, + "--env", "DATABRICKS_HOST=" + profile.Host, + "--env", "DATABRICKS_WAREHOUSE_ID=" + warehouseID, + } + + args = append(args, "--", databricksPath, "experimental", "apps-mcp") + + cmd := exec.Command("claude", args...) output, err := cmd.CombinedOutput() if err != nil { diff --git a/experimental/apps-mcp/lib/agents/cursor.go b/experimental/apps-mcp/lib/agents/cursor.go index 5bd005133b..9d880c87af 100644 --- a/experimental/apps-mcp/lib/agents/cursor.go +++ b/experimental/apps-mcp/lib/agents/cursor.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "runtime" + + "github.com/databricks/cli/libs/databrickscfg/profile" ) type cursorConfig struct { @@ -47,7 +49,7 @@ func DetectCursor() bool { } // InstallCursor installs the Databricks MCP server in Cursor. -func InstallCursor() error { +func InstallCursor(profile *profile.Profile, warehouseID string) error { configPath, err := getCursorConfigPath() if err != nil { return fmt.Errorf("failed to determine Cursor config path: %w", err) @@ -81,10 +83,18 @@ func InstallCursor() error { return fmt.Errorf("failed to determine Databricks path: %w", err) } + // Build environment variables + envVars := map[string]string{ + "DATABRICKS_CONFIG_PROFILE": profile.Name, + "DATABRICKS_HOST": profile.Host, + "DATABRICKS_WAREHOUSE_ID": warehouseID, + } + // Add or update the Databricks MCP server entry config.McpServers["databricks-mcp"] = mcpServer{ Command: databricksPath, Args: []string{"experimental", "apps-mcp"}, + Env: envVars, } // Write back to file with pretty printing diff --git a/experimental/apps-mcp/lib/agents/custom.go b/experimental/apps-mcp/lib/agents/custom.go index dfc450e595..1b5ea17e89 100644 --- a/experimental/apps-mcp/lib/agents/custom.go +++ b/experimental/apps-mcp/lib/agents/custom.go @@ -2,12 +2,14 @@ package agents import ( "context" + "fmt" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg/profile" ) // ShowCustomInstructions displays instructions for manually installing the MCP server. -func ShowCustomInstructions(ctx context.Context) error { +func ShowCustomInstructions(ctx context.Context, profile *profile.Profile, warehouseID string) error { instructions := ` To install the Databricks CLI MCP server in your coding agent: @@ -21,11 +23,16 @@ Example MCP server configuration: "databricks": { "command": "databricks", "args": ["experimental", "apps-mcp"] + "env": { + "DATABRICKS_CONFIG_PROFILE": "%s", + "DATABRICKS_HOST": "%s", + "DATABRICKS_WAREHOUSE_ID": "%s" + } } } } ` - cmdio.LogString(ctx, instructions) + cmdio.LogString(ctx, fmt.Sprintf(instructions, profile.Name, profile.Host, warehouseID)) _, err := cmdio.Ask(ctx, "Press Enter to continue", "") if err != nil { diff --git a/experimental/apps-mcp/lib/middlewares/databricks_client.go b/experimental/apps-mcp/lib/middlewares/databricks_client.go index b7a9bf7146..e952c68508 100644 --- a/experimental/apps-mcp/lib/middlewares/databricks_client.go +++ b/experimental/apps-mcp/lib/middlewares/databricks_client.go @@ -15,7 +15,8 @@ import ( ) const ( - DatabricksClientKey = "databricks_client" + DatabricksClientKey = "databricks_client" + DatabricksProfileKey = "databricks_profile" ) func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middleware { @@ -40,6 +41,18 @@ func NewDatabricksClientMiddleware(unauthorizedToolNames []string) mcp.Middlewar }) } +func MustGetDatabricksProfile(ctx context.Context) string { + sess, err := session.GetSession(ctx) + if err != nil { + panic(err) + } + profile, ok := sess.Get(DatabricksProfileKey) + if !ok { + return "" + } + return profile.(string) +} + func MustGetApiClient(ctx context.Context) (*httpclient.ApiClient, error) { w := MustGetDatabricksClient(ctx) clientCfg, err := config.HTTPClientConfigFromConfig(w.Config) diff --git a/experimental/apps-mcp/lib/middlewares/warehouse.go b/experimental/apps-mcp/lib/middlewares/warehouse.go index 680b26d3ef..4766ea7f2e 100644 --- a/experimental/apps-mcp/lib/middlewares/warehouse.go +++ b/experimental/apps-mcp/lib/middlewares/warehouse.go @@ -33,7 +33,12 @@ func loadWarehouseInBackground(ctx context.Context) { defer wg.Done() - warehouse, err := getDefaultWarehouse(ctx) + apiClient, err := MustGetApiClient(ctx) + if err != nil { + return + } + + warehouse, err := GetDefaultWarehouse(ctx, apiClient) if err != nil { sess.Set(warehouseErrorKey, err) return @@ -64,7 +69,12 @@ func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) { warehouse, ok := sess.Get("warehouse_endpoint") if !ok { // Fallback: synchronously load if background loading didn't happen - warehouse, err = getDefaultWarehouse(ctx) + apiClient, err := MustGetApiClient(ctx) + if err != nil { + return nil, err + } + + warehouse, err = GetDefaultWarehouse(ctx, apiClient) if err != nil { return nil, err } @@ -82,7 +92,7 @@ func GetWarehouseID(ctx context.Context) (string, error) { return warehouse.Id, nil } -func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { +func GetDefaultWarehouse(ctx context.Context, apiClient *httpclient.ApiClient) (*sql.EndpointInfo, error) { // first resolve DATABRICKS_WAREHOUSE_ID env variable warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID") if warehouseID != "" { @@ -100,18 +110,13 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { }, nil } - apiClient, err := MustGetApiClient(ctx) - if err != nil { - return nil, err - } - apiPath := "/api/2.0/sql/warehouses" params := url.Values{} params.Add("skip_cannot_use", "true") fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode()) var response sql.ListWarehousesResponse - err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response)) + err := apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response)) if err != nil { return nil, err } diff --git a/experimental/apps-mcp/lib/providers/clitools/configure_auth.go b/experimental/apps-mcp/lib/providers/clitools/configure_auth.go index 00a40713cf..a152de3e3f 100644 --- a/experimental/apps-mcp/lib/providers/clitools/configure_auth.go +++ b/experimental/apps-mcp/lib/providers/clitools/configure_auth.go @@ -55,6 +55,10 @@ func ConfigureAuth(ctx context.Context, sess *session.Session, host, profile *st // Store client in session data sess.Set(middlewares.DatabricksClientKey, client) + if profile == nil { + sess.Set(middlewares.DatabricksProfileKey, client.Config.Profile) + } + return client, nil } diff --git a/experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go b/experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go index 7906c43339..ac4a7ec81a 100644 --- a/experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go +++ b/experimental/apps-mcp/lib/providers/clitools/invoke_databricks_cli.go @@ -25,7 +25,14 @@ func InvokeDatabricksCLI(ctx context.Context, command []string, workingDirectory cmd := exec.CommandContext(ctx, cliPath, command...) cmd.Dir = workingDirectory env := os.Environ() - env = append(env, "DATABRICKS_HOST="+host) + + profile := middlewares.MustGetDatabricksProfile(ctx) + if profile != "" { + env = append(env, "DATABRICKS_CONFIG_PROFILE="+profile) + } + if host != "" { + env = append(env, "DATABRICKS_HOST="+host) + } cmd.Env = env output, err := cmd.CombinedOutput()