diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 1ef2babfd..20131c66b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,7 +11,7 @@ jobs: build: strategy: matrix: - os: [ubuntu-latest, macos-13, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} environment: CI steps: diff --git a/main/main.go b/main/main.go index afc174703..ebc21121d 100644 --- a/main/main.go +++ b/main/main.go @@ -26,6 +26,7 @@ import ( "github.com/aliyun/aliyun-cli/v3/config" go_migrate "github.com/aliyun/aliyun-cli/v3/go-migrate" "github.com/aliyun/aliyun-cli/v3/i18n" + "github.com/aliyun/aliyun-cli/v3/mcpproxy" "github.com/aliyun/aliyun-cli/v3/openapi" "github.com/aliyun/aliyun-cli/v3/oss/lib" "github.com/aliyun/aliyun-cli/v3/ossutil" @@ -76,6 +77,8 @@ func Main(args []string) { rootCmd.AddSubCommand(lib.NewOssCommand()) rootCmd.AddSubCommand(cli.NewVersionCommand()) rootCmd.AddSubCommand(cli.NewAutoCompleteCommand()) + // mcp proxy command + rootCmd.AddSubCommand(mcpproxy.NewMCPProxyCommand()) // go v1 to v2 migrate command rootCmd.AddSubCommand(go_migrate.NewGoMigrateCommand()) // new oss command diff --git a/mcpproxy/command.go b/mcpproxy/command.go new file mode 100644 index 000000000..a73d135e2 --- /dev/null +++ b/mcpproxy/command.go @@ -0,0 +1,260 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "encoding/json" + "fmt" + "net/url" + "os" + "os/signal" + "strconv" + "syscall" + + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/aliyun/aliyun-cli/v3/i18n" +) + +func NewMCPProxyCommand() *cli.Command { + cmd := &cli.Command{ + Name: "mcp-proxy", + Short: i18n.T("Start MCP server proxy", "启动 MCP 服务器代理"), + Long: i18n.T( + "Start a local proxy server for Aliyun API MCP Servers. "+ + "The proxy handles OAuth authentication automatically, "+ + "allowing MCP clients to connect without managing credentials.", + "启动阿里云 API MCP Server 的本地代理服务。"+ + "代理自动处理 OAuth 认证,"+ + "允许 MCP 客户端无需管理凭证即可连接。", + ), + Usage: "aliyun mcp-proxy [--port PORT] [--host HOST] [--region-type REGION_TYPE] [--upstream-url URL] [--oauth-app-name NAME]", + Sample: "aliyun mcp-proxy --region-type CN --port 8088", + Run: func(ctx *cli.Context, args []string) error { + return runMCPProxy(ctx) + }, + } + + cmd.Flags().Add(&cli.Flag{ + Name: "port", + DefaultValue: "8088", + Short: i18n.T( + "Proxy server port", + "代理服务器端口", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "host", + DefaultValue: "127.0.0.1", + Short: i18n.T( + "Proxy server host (use 0.0.0.0 to listen on all interfaces)", + "代理服务器地址 (使用 0.0.0.0 监听所有网络接口)", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "region-type", + DefaultValue: "CN", + Short: i18n.T( + "Region type: CN or INTL", + "地域类型: CN 或 INTL", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "no-browser", + Short: i18n.T( + "Disable automatic browser opening. Use manual code input mode instead", + "使用手动输入授权码模式,不自动打开浏览器", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "scope", + DefaultValue: "/acs/mcp-server", + Short: i18n.T( + "OAuth predefined scope (default: /acs/mcp-server)", + "OAuth 预定义权限范围(默认: /acs/mcp-server)", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "upstream-url", + Short: i18n.T( + "Custom upstream MCP server URL (overrides EndpointMap configuration)", + "自定义上游 MCP 服务器地址(覆盖 EndpointMap 配置)", + ), + }) + + cmd.Flags().Add(&cli.Flag{ + Name: "oauth-app-name", + Short: i18n.T( + "Use existing OAuth application by name (for users without create permission)", + "使用已存在的 OAuth 应用名称(适用于没有创建权限的用户)", + ), + }) + + return cmd +} + +// ProxyConfig 封装了启动 MCP Proxy 所需的所有配置参数 +type StartProxyConfig struct { + McpProfile *McpProfile + RegionType RegionType + Host string + Port int + NoBrowser bool + Scope string + UpstreamURL string +} + +func runMCPProxy(ctx *cli.Context) error { + portStr := ctx.Flags().Get("port").GetStringOrDefault("8088") + host := ctx.Flags().Get("host").GetStringOrDefault("127.0.0.1") + regionStr := ctx.Flags().Get("region-type").GetStringOrDefault("CN") + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + var regionType RegionType + switch regionStr { + case "CN": + regionType = RegionCN + case "INTL": + regionType = RegionINTL + default: + return fmt.Errorf("invalid region type: %s, must be CN or INTL", regionStr) + } + + noBrowser := ctx.Flags().Get("no-browser").IsAssigned() + scope := ctx.Flags().Get("scope").GetStringOrDefault("/acs/mcp-server") + upstreamURL := ctx.Flags().Get("upstream-url").GetStringOrDefault("") + oauthAppName := ctx.Flags().Get("oauth-app-name").GetStringOrDefault("") + + proxyConfig := ProxyConfig{ + Host: host, + Port: port, + RegionType: regionType, + Scope: scope, + AutoOpenBrowser: !noBrowser, + UpstreamBaseURL: upstreamURL, + OAuthAppName: oauthAppName, + } + + mcpProfile, err := getOrCreateMCPProfile(ctx, proxyConfig) + if err != nil { + return err + } + proxyConfig.McpProfile = mcpProfile + return startMCPProxy(ctx, proxyConfig) +} + +func startMCPProxy(ctx *cli.Context, config ProxyConfig) error { + servers, err := ListMCPServers(ctx, config.RegionType) + if err != nil { + return fmt.Errorf("failed to list MCP servers: %w", err) + } + + if len(servers) == 0 { + return fmt.Errorf("no MCP servers found") + } + + config.CallbackManager = NewOAuthCallbackManager() + config.ExistMcpServers = servers + + proxy := NewMCPProxy(config) + go proxy.TokenRefresher.Start() + + printProxyInfo(ctx, proxy) + + // 设置信号处理,捕获 Ctrl+C (SIGINT) 和 SIGTERM + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + // 在 goroutine 中启动服务器 + serverErrChan := make(chan error, 1) + go func() { + if err := proxy.Start(); err != nil { + serverErrChan <- err + } + }() + + // 等待信号、服务器错误或致命错误 + select { + case sig := <-sigChan: + cli.Printf(ctx.Stdout(), "\nReceived signal: %v, shutting down gracefully...\n", sig) + if proxy.TokenRefresher != nil { + proxy.TokenRefresher.Stop() + } + if err := proxy.Stop(); err != nil { + // 如果是超时错误,记录日志但不返回错误,因为服务器已经关闭 + cli.Printf(ctx.Stderr(), "Warning: %v\n", err) + } + cli.Println(ctx.Stdout(), "MCP Proxy stopped successfully") + return nil + case err := <-serverErrChan: + return err + case fatalErr := <-proxy.TokenRefresher.fatalErrCh: + cli.Printf(ctx.Stderr(), "\nFatal error: %v\n", fatalErr) + cli.Printf(ctx.Stdout(), "Shutting down gracefully...\n") + if proxy.TokenRefresher != nil { + proxy.TokenRefresher.Stop() + } + if err := proxy.Stop(); err != nil { + cli.Printf(ctx.Stderr(), "Warning: %v\n", err) + } + return fatalErr + } +} + +func printProxyInfo(ctx *cli.Context, proxy *MCPProxy) { + cli.Printf(ctx.Stdout(), "\nMCP Proxy Server Started\nListen: %s:%d\nRegion: %s\n", + proxy.Host, proxy.Port, proxy.RegionType) + + cli.Println(ctx.Stdout(), "\nAvailable Servers:") + for _, server := range proxy.ExistMcpServers { + cli.Printf(ctx.Stdout(), " - %s\n", server.Name) + if server.Urls.MCP != "" { + if upstreamURL, err := url.Parse(server.Urls.MCP); err == nil { + cli.Printf(ctx.Stdout(), " MCP: http://%s:%d%s\n", proxy.Host, proxy.Port, upstreamURL.Path) + } + } + if server.Urls.SSE != "" { + if upstreamURL, err := url.Parse(server.Urls.SSE); err == nil { + cli.Printf(ctx.Stdout(), " SSE: http://%s:%d%s\n", proxy.Host, proxy.Port, upstreamURL.Path) + } + } + } + + cli.Println(ctx.Stdout(), "\nPress Ctrl+C to stop") +} + +func GetContentFromApiResponse(response map[string]any) ([]byte, error) { + responseBody := response["body"] + if responseBody == nil { + return nil, fmt.Errorf("response body is nil") + } + switch v := responseBody.(type) { + case string: + return []byte(v), nil + case map[string]any, []any: + jsonData, _ := json.Marshal(v) + return jsonData, nil + case []byte: + return v, nil + default: + return []byte(fmt.Sprintf("%v", v)), nil + } +} diff --git a/mcpproxy/command_test.go b/mcpproxy/command_test.go new file mode 100644 index 000000000..ee9e3e8a7 --- /dev/null +++ b/mcpproxy/command_test.go @@ -0,0 +1,298 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "bytes" + "os" + "testing" + + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/stretchr/testify/assert" +) + +func TestNewMCPProxyCommand(t *testing.T) { + cmd := NewMCPProxyCommand() + assert.NotNil(t, cmd) + assert.Equal(t, "mcp-proxy", cmd.Name) + assert.NotEmpty(t, cmd.Short) + assert.NotEmpty(t, cmd.Long) + assert.NotEmpty(t, cmd.Usage) + assert.NotNil(t, cmd.Run) + + // 检查标志 + flags := cmd.Flags() + assert.NotNil(t, flags) + + portFlag := flags.Get("port") + assert.NotNil(t, portFlag) + assert.Equal(t, "8088", portFlag.DefaultValue) + + hostFlag := flags.Get("host") + assert.NotNil(t, hostFlag) + assert.Equal(t, "127.0.0.1", hostFlag.DefaultValue) + + regionFlag := flags.Get("region-type") + assert.NotNil(t, regionFlag) + assert.Equal(t, "CN", regionFlag.DefaultValue) + + scopeFlag := flags.Get("scope") + assert.NotNil(t, scopeFlag) + assert.Equal(t, "/acs/mcp-server", scopeFlag.DefaultValue) +} + +func TestGetContentFromApiResponse(t *testing.T) { + tests := []struct { + name string + response map[string]any + wantErr bool + validate func(t *testing.T, result []byte) + }{ + { + name: "string body", + response: map[string]any{ + "body": "test string", + }, + wantErr: false, + validate: func(t *testing.T, result []byte) { + assert.Equal(t, "test string", string(result)) + }, + }, + { + name: "map body", + response: map[string]any{ + "body": map[string]any{ + "key": "value", + }, + }, + wantErr: false, + validate: func(t *testing.T, result []byte) { + assert.Contains(t, string(result), "key") + assert.Contains(t, string(result), "value") + }, + }, + { + name: "array body", + response: map[string]any{ + "body": []any{"item1", "item2"}, + }, + wantErr: false, + validate: func(t *testing.T, result []byte) { + assert.Contains(t, string(result), "item1") + assert.Contains(t, string(result), "item2") + }, + }, + { + name: "byte slice body", + response: map[string]any{ + "body": []byte("test bytes"), + }, + wantErr: false, + validate: func(t *testing.T, result []byte) { + assert.Equal(t, "test bytes", string(result)) + }, + }, + { + name: "int body", + response: map[string]any{ + "body": 123, + }, + wantErr: false, + validate: func(t *testing.T, result []byte) { + assert.Contains(t, string(result), "123") + }, + }, + { + name: "nil body", + response: map[string]any{ + "body": nil, + }, + wantErr: true, + validate: func(t *testing.T, result []byte) { + assert.Nil(t, result) + }, + }, + { + name: "no body key", + response: map[string]any{ + "other": "value", + }, + wantErr: true, + validate: func(t *testing.T, result []byte) { + assert.Nil(t, result) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := GetContentFromApiResponse(tt.response) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + } + if tt.validate != nil { + tt.validate(t, result) + } + }) + } +} + +func TestRunMCPProxy_InvalidPort(t *testing.T) { + // 保存原始环境变量 + originalHome := os.Getenv("HOME") + originalIgnoreProfile := os.Getenv("ALIBABA_CLOUD_IGNORE_PROFILE") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + if originalIgnoreProfile != "" { + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", originalIgnoreProfile) + } else { + os.Unsetenv("ALIBABA_CLOUD_IGNORE_PROFILE") + } + }() + + // 设置临时目录和忽略配置文件,确保不会使用真实账号 + tmpDir := t.TempDir() + os.Setenv("HOME", tmpDir) + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", "TRUE") + + ctx := cli.NewCommandContext(bytes.NewBuffer(nil), bytes.NewBuffer(nil)) + portFlag := &cli.Flag{ + Name: "port", + DefaultValue: "invalid", + } + portFlag.SetAssigned(true) + portFlag.SetValue("invalid") + ctx.Flags().Add(portFlag) + ctx.Flags().Add(&cli.Flag{ + Name: "host", + DefaultValue: "127.0.0.1", + }) + ctx.Flags().Add(&cli.Flag{ + Name: "region-type", + DefaultValue: "CN", + }) + + err := runMCPProxy(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid port") +} + +func TestRunMCPProxy_InvalidRegionType(t *testing.T) { + originalHome := os.Getenv("HOME") + originalIgnoreProfile := os.Getenv("ALIBABA_CLOUD_IGNORE_PROFILE") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + if originalIgnoreProfile != "" { + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", originalIgnoreProfile) + } else { + os.Unsetenv("ALIBABA_CLOUD_IGNORE_PROFILE") + } + }() + + // 设置临时目录和忽略配置文件,确保不会使用真实账号 + tmpDir := t.TempDir() + os.Setenv("HOME", tmpDir) + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", "TRUE") + + ctx := cli.NewCommandContext(bytes.NewBuffer(nil), bytes.NewBuffer(nil)) + portFlag := &cli.Flag{ + Name: "port", + DefaultValue: "8088", + } + portFlag.SetAssigned(true) + portFlag.SetValue("8088") + ctx.Flags().Add(portFlag) + ctx.Flags().Add(&cli.Flag{ + Name: "host", + DefaultValue: "127.0.0.1", + }) + regionFlag := &cli.Flag{ + Name: "region-type", + DefaultValue: "INVALID", + } + regionFlag.SetAssigned(true) + regionFlag.SetValue("INVALID") + ctx.Flags().Add(regionFlag) + + err := runMCPProxy(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid region type") +} + +func TestRunMCPProxy_ValidRegionTypes(t *testing.T) { + // 保存原始环境变量 + originalHome := os.Getenv("HOME") + originalIgnoreProfile := os.Getenv("ALIBABA_CLOUD_IGNORE_PROFILE") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + if originalIgnoreProfile != "" { + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", originalIgnoreProfile) + } else { + os.Unsetenv("ALIBABA_CLOUD_IGNORE_PROFILE") + } + }() + + regionTypes := []string{"CN", "INTL"} + for _, regionType := range regionTypes { + t.Run(regionType, func(t *testing.T) { + // 设置临时目录和忽略配置文件,确保不会使用真实账号 + tmpDir := t.TempDir() + os.Setenv("HOME", tmpDir) + os.Setenv("ALIBABA_CLOUD_IGNORE_PROFILE", "TRUE") + + ctx := cli.NewCommandContext(bytes.NewBuffer(nil), bytes.NewBuffer(nil)) + portFlag := &cli.Flag{ + Name: "port", + DefaultValue: "8088", + } + portFlag.SetAssigned(true) + portFlag.SetValue("8088") + ctx.Flags().Add(portFlag) + ctx.Flags().Add(&cli.Flag{ + Name: "host", + DefaultValue: "127.0.0.1", + }) + regionFlag := &cli.Flag{ + Name: "region-type", + DefaultValue: regionType, + } + regionFlag.SetAssigned(true) + regionFlag.SetValue(regionType) + ctx.Flags().Add(regionFlag) + + err := runMCPProxy(ctx) + // 预期会因为缺少profile或配置而失败 + assert.Error(t, err) + // 验证错误不是 "invalid region type",说明 region type 解析成功 + assert.NotContains(t, err.Error(), "invalid region type") + assert.Contains(t, err.Error(), "failed to load profile") + }) + } +} diff --git a/mcpproxy/mcp_profile.go b/mcpproxy/mcp_profile.go new file mode 100644 index 000000000..bf0becbe0 --- /dev/null +++ b/mcpproxy/mcp_profile.go @@ -0,0 +1,168 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/aliyun/aliyun-cli/v3/config" + "github.com/aliyun/aliyun-cli/v3/util" +) + +type McpProfile struct { + Name string `json:"name"` + MCPOAuthAppName string `json:"mcp_oauth_app_name,omitempty"` + MCPOAuthAppId string `json:"mcp_oauth_app_id,omitempty"` + MCPOAuthSiteType string `json:"mcp_oauth_site_type,omitempty"` // CN or INTL + MCPOAuthAccessToken string `json:"mcp_oauth_access_token,omitempty"` + MCPOAuthRefreshToken string `json:"mcp_oauth_refresh_token,omitempty"` + MCPOAuthAccessTokenValidity int `json:"mcp_oauth_access_token_validity,omitempty"` + MCPOAuthAccessTokenExpire int64 `json:"mcp_oauth_access_token_expire,omitempty"` + MCPOAuthRefreshTokenValidity int `json:"mcp_oauth_refresh_token_validity,omitempty"` + MCPOAuthRefreshTokenExpire int64 `json:"mcp_oauth_refresh_token_expire,omitempty"` +} + +func getMCPConfigPath() string { + return config.GetConfigPath() + "/.mcpproxy_config" +} + +func NewMcpProfile(name string) *McpProfile { + return &McpProfile{Name: name} +} + +func NewMcpProfileFromBytes(bytes []byte) (profile *McpProfile, err error) { + profile = &McpProfile{} + err = json.Unmarshal(bytes, profile) + if err != nil { + return nil, err + } + return profile, nil +} + +func saveMcpProfile(profile *McpProfile) error { + mcpConfigPath := getMCPConfigPath() + dir := filepath.Dir(mcpConfigPath) + + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create config directory %q: %w", dir, err) + } + + tempFile := mcpConfigPath + ".tmp" + + bytes, err := json.MarshalIndent(profile, "", "\t") + if err != nil { + return fmt.Errorf("failed to marshal profile: %w", err) + } + + if err := os.WriteFile(tempFile, bytes, 0600); err != nil { + return fmt.Errorf("failed to write temp file %q: %w", tempFile, err) + } + + // 原子性地重命名临时文件为目标文件, 避免因各种系统异常直接损坏原文件 + if err := os.Rename(tempFile, mcpConfigPath); err != nil { + _ = os.Remove(tempFile) + return fmt.Errorf("failed to rename temp file to %q: %w", mcpConfigPath, err) + } + return nil +} + +func getOrCreateMCPProfile(ctx *cli.Context, opts ProxyConfig) (*McpProfile, error) { + profile, err := config.LoadProfileWithContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to load profile: %w", err) + } + + // 如果传入了 oauth-app-name,先验证该应用是否存在且合法 + // 如果已经验证过 oauth-app-name,直接使用验证过的 app;否则查找或创建默认的 OAuth 应用 + var validatedApp *OAuthApplication + if opts.OAuthAppName != "" { + app, err := findOAuthApplicationByName(ctx, profile, opts.RegionType, opts.OAuthAppName) + if err != nil { + return nil, fmt.Errorf("failed to find OAuth application '%s': %w", opts.OAuthAppName, err) + } + if app == nil { + return nil, fmt.Errorf("OAuth application '%s' not found", opts.OAuthAppName) + } + + // 验证 Scopes 和 Callback URI + requiredRedirectURI := buildRedirectUri(opts.Host, opts.Port) + if err := validateOAuthApplication(app, opts.Scope, requiredRedirectURI); err != nil { + return nil, fmt.Errorf("OAuth application validation failed: %w", err) + } + + validatedApp = app + cli.Printf(ctx.Stdout(), "Using existing OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId) + } else { + // 查找或创建默认的 OAuth 应用 + mcpConfigPath := getMCPConfigPath() + if bytes, err := os.ReadFile(mcpConfigPath); err == nil { + if mcpProfile, err := NewMcpProfileFromBytes(bytes); err == nil { + log.Println("MCP Profile loaded from file", mcpProfile.Name, "app id", mcpProfile.MCPOAuthAppId) + + // 检查 region type 是否匹配,因为国内和国际站的 OAuth 地址不同, Region type 不匹配则重新创建 profile + if mcpProfile.MCPOAuthSiteType != string(opts.RegionType) { + log.Printf("Region type mismatch: saved=%s, requested=%s, recreating profile", mcpProfile.MCPOAuthSiteType, string(opts.RegionType)) + } else { + err = findOAuthApplicationById(ctx, profile, mcpProfile, opts.RegionType) + if err == nil { + return mcpProfile, nil + } else { + log.Println("Failed to find existing OAuth application", err.Error()) + } + } + } + } + app, err := getOrCreateMCPOAuthApplication(ctx, profile, opts.RegionType, opts.Host, opts.Port, opts.Scope) + if err != nil { + return nil, fmt.Errorf("failed to get or create OAuth application: %w", err) + } + validatedApp = app + } + + cli.Printf(ctx.Stdout(), "Setting up MCPOAuth profile '%s'...\n", DefaultMcpProfileName) + + mcpProfile := NewMcpProfile(DefaultMcpProfileName) + mcpProfile.MCPOAuthSiteType = string(opts.RegionType) + mcpProfile.MCPOAuthAppId = validatedApp.ApplicationId + mcpProfile.MCPOAuthAppName = validatedApp.AppName + // 刷新 token 接口不返回 refresh token 有效期,所以直接在这里设置 + currentTime := util.GetCurrentUnixTime() + mcpProfile.MCPOAuthAccessTokenValidity = validatedApp.AccessTokenValidity + mcpProfile.MCPOAuthRefreshTokenValidity = validatedApp.RefreshTokenValidity + + // noBrowser=true 表示禁用自动打开浏览器,autoOpenBrowser=false + // noBrowser=false 表示启用自动打开浏览器,autoOpenBrowser=true + tokenResult, err := startMCPOAuthFlow(ctx, mcpProfile.MCPOAuthAppId, opts.RegionType, opts.Host, opts.Port, opts.AutoOpenBrowser, opts.Scope) + if err != nil { + return nil, fmt.Errorf("OAuth login failed: %w", err) + } + mcpProfile.MCPOAuthAccessToken = tokenResult.AccessToken + mcpProfile.MCPOAuthRefreshToken = tokenResult.RefreshToken + mcpProfile.MCPOAuthAccessTokenExpire = tokenResult.AccessTokenExpire + mcpProfile.MCPOAuthRefreshTokenExpire = currentTime + int64(validatedApp.RefreshTokenValidity) + + if err = saveMcpProfile(mcpProfile); err != nil { + return nil, fmt.Errorf("failed to save mcp profile: %w", err) + } + + cli.Printf(ctx.Stdout(), "MCP Profile '%s' configured for oauth app '%s' successfully!\n", mcpProfile.Name, mcpProfile.MCPOAuthAppName) + + return mcpProfile, nil +} diff --git a/mcpproxy/mcp_profile_test.go b/mcpproxy/mcp_profile_test.go new file mode 100644 index 000000000..205f8c025 --- /dev/null +++ b/mcpproxy/mcp_profile_test.go @@ -0,0 +1,290 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewMcpProfile(t *testing.T) { + name := "test-profile" + profile := NewMcpProfile(name) + + assert.NotNil(t, profile) + assert.Equal(t, name, profile.Name) + assert.Empty(t, profile.MCPOAuthAccessToken) + assert.Empty(t, profile.MCPOAuthRefreshToken) + assert.Zero(t, profile.MCPOAuthAccessTokenExpire) + assert.Zero(t, profile.MCPOAuthRefreshTokenValidity) + assert.Zero(t, profile.MCPOAuthRefreshTokenExpire) + assert.Empty(t, profile.MCPOAuthSiteType) + assert.Empty(t, profile.MCPOAuthAppId) +} + +func TestNewMcpProfileFromBytes(t *testing.T) { + tests := []struct { + name string + jsonBytes []byte + wantErr bool + validate func(t *testing.T, profile *McpProfile) + }{ + { + name: "valid profile", + jsonBytes: []byte(`{ + "name": "test-profile", + "mcp_oauth_access_token": "test-access-token", + "mcp_oauth_refresh_token": "test-refresh-token", + "mcp_oauth_access_token_expire": 1234567890, + "mcp_oauth_refresh_token_validity": 31536000, + "mcp_oauth_refresh_token_expire": 1234567890, + "mcp_oauth_site_type": "CN", + "mcp_oauth_app_id": "test-app-id" + }`), + wantErr: false, + validate: func(t *testing.T, profile *McpProfile) { + assert.Equal(t, "test-profile", profile.Name) + assert.Equal(t, "test-access-token", profile.MCPOAuthAccessToken) + assert.Equal(t, "test-refresh-token", profile.MCPOAuthRefreshToken) + assert.Equal(t, int64(1234567890), profile.MCPOAuthAccessTokenExpire) + assert.Equal(t, 31536000, profile.MCPOAuthRefreshTokenValidity) + assert.Equal(t, int64(1234567890), profile.MCPOAuthRefreshTokenExpire) + assert.Equal(t, "CN", profile.MCPOAuthSiteType) + assert.Equal(t, "test-app-id", profile.MCPOAuthAppId) + }, + }, + { + name: "minimal profile", + jsonBytes: []byte(`{ + "name": "minimal-profile" + }`), + wantErr: false, + validate: func(t *testing.T, profile *McpProfile) { + assert.Equal(t, "minimal-profile", profile.Name) + }, + }, + { + name: "invalid json", + jsonBytes: []byte(`{invalid json}`), + wantErr: true, + validate: func(t *testing.T, profile *McpProfile) { + // JSON 解析失败时,profile 应该为 nil + assert.Nil(t, profile) + }, + }, + { + name: "empty bytes", + jsonBytes: []byte{}, + wantErr: true, + validate: func(t *testing.T, profile *McpProfile) { + // JSON 解析失败时,profile 应该为 nil + assert.Nil(t, profile) + }, + }, + { + name: "empty name remains empty", + jsonBytes: []byte(`{ + "mcp_oauth_access_token": "test-token" + }`), + wantErr: false, + validate: func(t *testing.T, profile *McpProfile) { + assert.NotNil(t, profile) + assert.Empty(t, profile.Name) + assert.Equal(t, "test-token", profile.MCPOAuthAccessToken) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile, err := NewMcpProfileFromBytes(tt.jsonBytes) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, profile) + } else { + assert.NoError(t, err) + assert.NotNil(t, profile) + } + if tt.validate != nil { + tt.validate(t, profile) + } + }) + } +} + +func TestSaveMcpProfile(t *testing.T) { + tmpDir := t.TempDir() + + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + + // 设置 HOME 环境变量指向临时目录,这样 GetConfigPath() 会返回 tmpDir/.aliyun + os.Setenv("HOME", tmpDir) + + configPath := getMCPConfigPath() + testConfigPath := filepath.Join(tmpDir, ".aliyun", ".mcpproxy_config") + assert.Equal(t, testConfigPath, configPath) + + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-token" + profile.MCPOAuthAppId = "test-app-id" + profile.MCPOAuthSiteType = "CN" + + err := saveMcpProfile(profile) + assert.NoError(t, err) + + // 验证文件是否存在 + _, err = os.Stat(testConfigPath) + assert.NoError(t, err) + + // 读取并验证内容 + bytes, err := os.ReadFile(testConfigPath) + assert.NoError(t, err) + + var loadedProfile McpProfile + err = json.Unmarshal(bytes, &loadedProfile) + assert.NoError(t, err) + assert.Equal(t, profile.Name, loadedProfile.Name) + assert.Equal(t, profile.MCPOAuthAccessToken, loadedProfile.MCPOAuthAccessToken) + assert.Equal(t, profile.MCPOAuthAppId, loadedProfile.MCPOAuthAppId) + assert.Equal(t, profile.MCPOAuthSiteType, loadedProfile.MCPOAuthSiteType) +} + +func TestMcpProfileRegionType(t *testing.T) { + t.Run("region type is saved and loaded", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + + os.Setenv("HOME", tmpDir) + + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = string(RegionCN) + profile.MCPOAuthAppId = "test-app-id" + + err := saveMcpProfile(profile) + assert.NoError(t, err) + + configPath := getMCPConfigPath() + bytes, err := os.ReadFile(configPath) + assert.NoError(t, err) + + var loadedProfile McpProfile + err = json.Unmarshal(bytes, &loadedProfile) + assert.NoError(t, err) + assert.Equal(t, string(RegionCN), loadedProfile.MCPOAuthSiteType) + }) + + t.Run("region type comparison", func(t *testing.T) { + tests := []struct { + name string + savedRegion string + requestedRegion RegionType + shouldMatch bool + }{ + { + name: "CN matches CN", + savedRegion: "CN", + requestedRegion: RegionCN, + shouldMatch: true, + }, + { + name: "INTL matches INTL", + savedRegion: "INTL", + requestedRegion: RegionINTL, + shouldMatch: true, + }, + { + name: "CN does not match INTL", + savedRegion: "CN", + requestedRegion: RegionINTL, + shouldMatch: false, + }, + { + name: "INTL does not match CN", + savedRegion: "INTL", + requestedRegion: RegionCN, + shouldMatch: false, + }, + { + name: "empty does not match", + savedRegion: "", + requestedRegion: RegionCN, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &McpProfile{ + MCPOAuthSiteType: tt.savedRegion, + } + + // 模拟对比逻辑:region type 必须存在且匹配 + matches := profile.MCPOAuthSiteType != "" && profile.MCPOAuthSiteType == string(tt.requestedRegion) + assert.Equal(t, tt.shouldMatch, matches) + }) + } + }) +} + +func TestMcpProfileJSONSerialization(t *testing.T) { + profile := &McpProfile{ + Name: "test-profile", + MCPOAuthSiteType: "CN", + MCPOAuthAppId: "app-id", + MCPOAuthAppName: "app-name", + MCPOAuthAccessToken: "access-token", + MCPOAuthRefreshToken: "refresh-token", + MCPOAuthAccessTokenExpire: 1234567890, + MCPOAuthAccessTokenValidity: 10800, + MCPOAuthRefreshTokenValidity: 31536000, + MCPOAuthRefreshTokenExpire: 1234567890, + } + + jsonBytes, err := json.Marshal(profile) + assert.NoError(t, err) + assert.NotEmpty(t, jsonBytes) + + var loadedProfile McpProfile + err = json.Unmarshal(jsonBytes, &loadedProfile) + assert.NoError(t, err) + assert.Equal(t, profile.Name, loadedProfile.Name) + assert.Equal(t, profile.MCPOAuthAppName, loadedProfile.MCPOAuthAppName) + assert.Equal(t, profile.MCPOAuthAccessTokenValidity, loadedProfile.MCPOAuthAccessTokenValidity) + assert.Equal(t, profile.MCPOAuthAccessToken, loadedProfile.MCPOAuthAccessToken) + assert.Equal(t, profile.MCPOAuthRefreshToken, loadedProfile.MCPOAuthRefreshToken) + assert.Equal(t, profile.MCPOAuthAccessTokenExpire, loadedProfile.MCPOAuthAccessTokenExpire) + assert.Equal(t, profile.MCPOAuthRefreshTokenValidity, loadedProfile.MCPOAuthRefreshTokenValidity) + assert.Equal(t, profile.MCPOAuthRefreshTokenExpire, loadedProfile.MCPOAuthRefreshTokenExpire) + assert.Equal(t, profile.MCPOAuthSiteType, loadedProfile.MCPOAuthSiteType) + assert.Equal(t, profile.MCPOAuthAppId, loadedProfile.MCPOAuthAppId) +} diff --git a/mcpproxy/mcp_server.go b/mcpproxy/mcp_server.go new file mode 100644 index 000000000..07fb25532 --- /dev/null +++ b/mcpproxy/mcp_server.go @@ -0,0 +1,967 @@ +package mcpproxy + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + openapiClient "github.com/alibabacloud-go/darabonba-openapi/v2/client" + openapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils" + "github.com/alibabacloud-go/tea/dara" + "github.com/alibabacloud-go/tea/tea" + + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/aliyun/aliyun-cli/v3/config" + "github.com/aliyun/aliyun-cli/v3/util" +) + +type MCPInfoUrls struct { + SSE string `json:"sse"` + MCP string `json:"mcp"` +} + +type MCPServerInfo struct { + Id string `json:"id"` + Name string `json:"name"` + SourceType string `json:"sourceType"` + Product string `json:"product"` + Urls MCPInfoUrls `json:"urls"` +} + +type ListMCPServersResponse struct { + ApiMcpServers []MCPServerInfo `json:"apiMcpServers"` +} + +func ListMCPServers(ctx *cli.Context, regionType RegionType) ([]MCPServerInfo, error) { + profile, err := config.LoadProfileWithContext(ctx) + if err != nil { + return nil, err + } + client, err := newOpenAPIClient(ctx, profile, EndpointMap[regionType].MCP) + if err != nil { + return nil, err + } + params := &openapiClient.Params{ + Action: tea.String("ListApiMcpServers"), + Version: tea.String("2024-11-30"), + Protocol: tea.String("HTTPS"), + Method: tea.String("GET"), + AuthType: tea.String("AK"), + Style: tea.String("ROA"), + Pathname: tea.String("/apimcpservers"), + ReqBodyType: tea.String("json"), + BodyType: tea.String("json"), + } + runtime := &dara.RuntimeOptions{} + request := &openapiutil.OpenApiRequest{} + response, err := client.CallApi(params, request, runtime) + if err != nil { + return nil, err + } + bodyBytes, err := GetContentFromApiResponse(response) + if err != nil { + return nil, fmt.Errorf("failed to get content from api response: %w", err) + } + var responseList ListMCPServersResponse + if err := json.Unmarshal(bodyBytes, &responseList); err != nil { + return nil, err + } + return responseList.ApiMcpServers, nil +} + +type RuntimeStats struct { + StartTime time.Time + + TotalRequests int64 + SuccessRequests int64 + ErrorRequests int64 + ActiveRequests int64 + + TokenRefreshes int64 + TokenRefreshErrors int64 + LastTokenRefresh int64 + + // 启动时的内存状态 + InitialMemStats runtime.MemStats +} + +type ProxyConfig struct { + Host string + Port int + RegionType RegionType + Scope string + McpProfile *McpProfile + ExistMcpServers []MCPServerInfo + CallbackManager *OAuthCallbackManager + AutoOpenBrowser bool + UpstreamBaseURL string // 用户自定义的上游服务器地址,如果为空则使用 EndpointMap 配置 + OAuthAppName string // 用户自定义的 OAuth 应用名称,如果为空则使用默认的 OAuth 应用 +} + +type MCPProxy struct { + Host string + Port int + RegionType RegionType + Server *http.Server // 只会在 Start() 中赋值一次,如果程序改变执行流,则需要加锁保护 + ExistMcpServers []MCPServerInfo + TokenRefresher *TokenRefresher + stopCh chan struct{} + stats *RuntimeStats + UpstreamBaseURL string // 用户自定义的上游服务器地址,如果为空则使用 EndpointMap 配置 +} + +const ( + MaxSaveFailures = 3 + CheckInterval = 30 * time.Second + AccessTokenRefreshWindow = 7 * time.Minute // Access token 提前刷新窗口 + RefreshTokenRefreshWindow = 13 * time.Minute // Refresh token 提前重新授权窗口 + WaitForRefreshTimeout = 5 * time.Second + WaitForReauthorizationTimeout = 120 * time.Second +) + +type TokenInfo struct { + Token string + ExpiresAt int64 +} + +type TokenRefresher struct { + profile *McpProfile + host string // 代理主机 + port int // 代理端口 + regionType RegionType + scope string // OAuth scope + callbackManager *OAuthCallbackManager + mu sync.RWMutex // 保护刷新操作的读写锁 + refreshing bool // 标记是否正在刷新,防止重复刷新 + reauthorizing bool // 标记是否正在重新授权,防止重复重新授权 + autoOpenBrowser bool // 是否自动打开浏览器(false 表示手动输入 code 模式) + stopCh chan struct{} + tokenCh chan TokenInfo // 用于传递 token 的 channel + ticker *time.Ticker + fatalErrCh chan error // 用于通知致命错误的 channel + stats *RuntimeStats // 运行时统计信息(用于更新 token 刷新统计) +} + +func NewMCPProxy(config ProxyConfig) *MCPProxy { + stats := &RuntimeStats{ + StartTime: time.Now(), + } + // 记录启动时的内存状态 + runtime.ReadMemStats(&stats.InitialMemStats) + + return &MCPProxy{ + Host: config.Host, + Port: config.Port, + RegionType: config.RegionType, + ExistMcpServers: config.ExistMcpServers, + TokenRefresher: &TokenRefresher{ + profile: config.McpProfile, + regionType: config.RegionType, + callbackManager: config.CallbackManager, + host: config.Host, + port: config.Port, + scope: config.Scope, + autoOpenBrowser: config.AutoOpenBrowser, + stopCh: make(chan struct{}), + tokenCh: make(chan TokenInfo, 1), // 带缓冲的 channel,存储最新的 token + fatalErrCh: make(chan error, 1), + stats: stats, + }, + stopCh: make(chan struct{}), + stats: stats, + UpstreamBaseURL: config.UpstreamBaseURL, + } +} + +func (p *MCPProxy) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/callback", p.handleOAuthCallback) + mux.HandleFunc("/health", p.handleHealth) + mux.HandleFunc("/", p.ServeMCPProxyRequest) + + p.Server = &http.Server{ + Addr: fmt.Sprintf("%s:%d", p.Host, p.Port), + Handler: mux, + } + + log.Printf("MCP Proxy starting on %s:%d\n", p.Host, p.Port) + + if err := p.Server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return fmt.Errorf("proxy server failed: %w", err) + } + + return nil +} + +func (p *MCPProxy) Stop() error { + close(p.stopCh) + + if p.Server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := p.Server.Shutdown(ctx); err != nil { + // 如果优雅关闭超时,强制关闭 + if err == context.DeadlineExceeded { + log.Println("Graceful shutdown timeout, forcing close...") + return p.Server.Close() + } + return err + } + } + + return nil +} + +func (p *MCPProxy) handleOAuthCallback(w http.ResponseWriter, r *http.Request) { + showCode := !p.TokenRefresher.autoOpenBrowser + handleOAuthCallbackRequest(w, r, p.TokenRefresher.callbackManager.HandleCallback, showCode) +} + +func (p *MCPProxy) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // 检查基本健康状态 + now := time.Now() + health := map[string]any{ + "status": "healthy", + "timestamp": now.Unix(), + "timestamp_iso": now.Format(time.RFC3339), + "uptime": time.Since(p.stats.StartTime).String(), + "uptime_seconds": time.Since(p.stats.StartTime).Seconds(), + } + + p.TokenRefresher.mu.RLock() + currentTime := util.GetCurrentUnixTime() + tokenExpired := p.TokenRefresher.profile.MCPOAuthAccessTokenExpire <= currentTime + tokenExpiresIn := p.TokenRefresher.profile.MCPOAuthAccessTokenExpire - currentTime + refreshTokenExpired := p.TokenRefresher.profile.MCPOAuthRefreshTokenExpire <= currentTime + refreshTokenExpiresIn := p.TokenRefresher.profile.MCPOAuthRefreshTokenExpire - currentTime + p.TokenRefresher.mu.RUnlock() + + if tokenExpired { + health["status"] = "degraded" + health["token_status"] = "expired" + } else { + health["token_status"] = "valid" + health["token_expires_in"] = tokenExpiresIn + health["token_expires_inh"] = time.Duration(tokenExpiresIn * int64(time.Second)).String() + } + + if refreshTokenExpired { + health["status"] = "degraded" + health["refresh_token_status"] = "expired" + } else { + health["refresh_token_status"] = "valid" + health["refresh_token_expires_in"] = refreshTokenExpiresIn + health["refresh_token_expires_inh"] = time.Duration(refreshTokenExpiresIn * int64(time.Second)).String() + } + + // 检查内存 + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // 计算从启动到现在的内存增量 + initialMem := p.stats.InitialMemStats + allocDelta := int64(m.Alloc) - int64(initialMem.Alloc) + sysDelta := int64(m.Sys) - int64(initialMem.Sys) + heapAllocDelta := int64(m.HeapAlloc) - int64(initialMem.HeapAlloc) + heapInuseDelta := int64(m.HeapInuse) - int64(initialMem.HeapInuse) + + health["memory"] = map[string]interface{}{ + "alloc_mb": m.Alloc / 1024 / 1024, + "sys_mb": m.Sys / 1024 / 1024, + "heap_alloc_mb": m.HeapAlloc / 1024 / 1024, + "heap_inuse_mb": m.HeapInuse / 1024 / 1024, + "num_gc": m.NumGC, + "alloc_delta_mb": allocDelta / 1024 / 1024, + "sys_delta_mb": sysDelta / 1024 / 1024, + "heap_alloc_delta_mb": heapAllocDelta / 1024 / 1024, + "heap_inuse_delta_mb": heapInuseDelta / 1024 / 1024, + } + + // 内存使用超过 500MB 警告 + if m.Alloc > 500*1024*1024 { + health["status"] = "degraded" + health["memory_warning"] = "high memory usage" + } + + health["goroutines"] = runtime.NumGoroutine() + + health["requests"] = map[string]interface{}{ + "total": atomic.LoadInt64(&p.stats.TotalRequests), + "success": atomic.LoadInt64(&p.stats.SuccessRequests), + "error": atomic.LoadInt64(&p.stats.ErrorRequests), + "active": atomic.LoadInt64(&p.stats.ActiveRequests), + } + + tokenRefreshes := map[string]interface{}{ + "total": atomic.LoadInt64(&p.stats.TokenRefreshes), + "errors": atomic.LoadInt64(&p.stats.TokenRefreshErrors), + } + lastRefresh := atomic.LoadInt64(&p.stats.LastTokenRefresh) + if lastRefresh > 0 { + tokenRefreshes["last_refresh"] = lastRefresh + } + health["token_refreshes"] = tokenRefreshes + + statusCode := http.StatusOK + if health["status"] == "degraded" { + statusCode = http.StatusServiceUnavailable + } + + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(health) +} + +func (p *MCPProxy) ServeMCPProxyRequest(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&p.stats.TotalRequests, 1) + atomic.AddInt64(&p.stats.ActiveRequests, 1) + defer atomic.AddInt64(&p.stats.ActiveRequests, -1) + + // 过滤常见的浏览器请求,直接返回 404,不代理到上游 + path := r.URL.Path + if path == "/favicon.ico" || path == "/robots.txt" || path == "/apple-touch-icon.png" { + w.WriteHeader(http.StatusNotFound) + return + } + + // 检查是否正在关闭 + select { + case <-p.stopCh: + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) + return + default: + } + + accessToken, err := p.getMCPAccessToken() + if err != nil { + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + log.Println("MCP Proxy received request url", r.URL.String()) + + // 读取并保存请求 Body,以便在需要重试时使用 + var bodyBytes []byte + if r.Body != nil { + var err error + bodyBytes, err = io.ReadAll(r.Body) + if err != nil { + log.Println("MCP Proxy upstream request body read error", err.Error()) + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + _ = r.Body.Close() + log.Println("MCP Proxy upstream request body content", string(bodyBytes)) + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } else { + log.Println("MCP Proxy upstream request body ") + } + + sendRequest := func(token string) (*http.Response, error) { + upstreamReq, err := p.buildUpstreamRequest(r, token) + if err != nil { + return nil, fmt.Errorf("failed to build upstream request: %w", err) + } + + log.Println("MCP Proxy build upstream request url", upstreamReq.URL.String()) + + if len(bodyBytes) > 0 { + upstreamReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + client := &http.Client{Timeout: 0} + resp, err := client.Do(upstreamReq) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + return resp, nil + } + + // 第一次发送请求 + resp, err := sendRequest(accessToken) + if err != nil { + log.Println("MCP Proxy sends upstream request error", err.Error()) + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + log.Println("MCP Proxy gets mcp server response status code", resp.StatusCode) + + // 如果响应状态码为 401,先尝试刷新 token,然后重试请求 + if resp.StatusCode == http.StatusUnauthorized { + log.Println("MCP Proxy gets mcp server response status code 401, attempting to refresh token") + var refreshErr error + p.TokenRefresher.mu.RLock() + MCPOAuthRefreshTokenExpire := p.TokenRefresher.profile.MCPOAuthRefreshTokenExpire + currentTime := util.GetCurrentUnixTime() + p.TokenRefresher.mu.RUnlock() + if MCPOAuthRefreshTokenExpire > currentTime { + // refresh token 未过期,尝试刷新 access token + log.Println("Received 401, attempting to refresh access token using refresh token") + refreshErr = p.TokenRefresher.refreshAccessToken() + } else { + // refresh token 已过期,需要重新授权 + log.Println("Received 401, refresh token expired, reauthorizing") + refreshErr = p.TokenRefresher.reauthorizeWithProxy() + } + + if refreshErr != nil { + log.Printf("Failed to handle 401: %v", refreshErr) + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, fmt.Sprintf("Authentication failed: %v", refreshErr), http.StatusUnauthorized) + return + } + + log.Println("Token refreshed/reauthorized successfully, retrying request with new token") + + // 关闭第一次请求的响应 + resp.Body.Close() + + if len(bodyBytes) > 0 { + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + newAccessToken, err := p.getMCPAccessToken() + if err != nil { + log.Printf("Failed to get new access token after refresh: %v", err) + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + resp, err = sendRequest(newAccessToken) + if err != nil { + log.Printf("MCP Proxy retry request after token refresh error: %v", err) + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + log.Println("MCP Proxy retry request after token refresh gets mcp server response status code", resp.StatusCode) + + // 如果重试后还是 401,说明 token 刷新失败或服务器持续拒绝,不再重试 + if resp.StatusCode == http.StatusUnauthorized { + log.Printf("MCP Proxy retry request still returns 401, authentication failed") + atomic.AddInt64(&p.stats.ErrorRequests, 1) + http.Error(w, "Authentication failed after token refresh", http.StatusUnauthorized) + return + } + } + + log.Println("MCP Proxy gets mcp server response content type", resp.Header.Get("Content-Type")) + contentType := resp.Header.Get("Content-Type") + if strings.Contains(strings.ToLower(contentType), "text/event-stream") { + p.handleSSE(w, resp) + if resp.StatusCode < 400 { + atomic.AddInt64(&p.stats.SuccessRequests, 1) + } else { + atomic.AddInt64(&p.stats.ErrorRequests, 1) + } + return + } + + p.handleHTTP(w, resp) + if resp.StatusCode < 400 { + atomic.AddInt64(&p.stats.SuccessRequests, 1) + } else { + atomic.AddInt64(&p.stats.ErrorRequests, 1) + } + +} + +func (p *MCPProxy) getMCPAccessToken() (string, error) { + var tokenInfo TokenInfo + select { + case tokenInfo = <-p.TokenRefresher.tokenCh: + default: + // channel 为空,从 profile 读取(加读锁保护) + p.TokenRefresher.mu.RLock() + tokenInfo = TokenInfo{ + Token: p.TokenRefresher.profile.MCPOAuthAccessToken, + ExpiresAt: p.TokenRefresher.profile.MCPOAuthAccessTokenExpire, + } + p.TokenRefresher.mu.RUnlock() + } + + currentTime := util.GetCurrentUnixTime() + // 检查 token 是否过期 + if tokenInfo.ExpiresAt > currentTime { + // Token 有效,将 token 放回 channel(供其他 goroutine 使用) + select { + case p.TokenRefresher.tokenCh <- tokenInfo: + default: + // channel 已满,忽略(说明已经有最新的 token 在 channel 中) + } + return tokenInfo.Token, nil + } + + if err := p.TokenRefresher.ForceRefresh(); err != nil { + return "", fmt.Errorf("failed to refresh access token: %w", err) + } + + select { + case tokenInfo = <-p.TokenRefresher.tokenCh: + return tokenInfo.Token, nil + case <-time.After(5 * time.Second): + return "", fmt.Errorf("timeout waiting for refreshed token") + } +} + +func (p *MCPProxy) buildUpstreamRequest(r *http.Request, accessToken string) (*http.Request, error) { + var upstreamBaseURL string + if p.UpstreamBaseURL != "" { + // 如果用户传入了自定义的上游地址,使用用户传入的 + upstreamBaseURL = p.UpstreamBaseURL + // 如果用户传入的地址没有协议前缀,添加 https:// + if !strings.HasPrefix(upstreamBaseURL, "http://") && !strings.HasPrefix(upstreamBaseURL, "https://") { + upstreamBaseURL = fmt.Sprintf("https://%s", upstreamBaseURL) + } + } else { + // 否则使用 EndpointMap 配置的地址 + upstreamBaseURL = fmt.Sprintf("https://%s", EndpointMap[p.RegionType].MCP) + } + + upstreamURL, err := url.Parse(upstreamBaseURL) + if err != nil { + return nil, err + } + + newURL := *r.URL + newURL.Scheme = upstreamURL.Scheme + newURL.Host = upstreamURL.Host + if newURL.Path == "" { + newURL.Path = "/" + } + + method := r.Method + var body io.ReadCloser = r.Body + + upstreamReq, err := http.NewRequest(method, newURL.String(), body) + if err != nil { + return nil, err + } + + for k, v := range r.Header { + if strings.ToLower(k) != "host" && strings.ToLower(k) != "authorization" { + upstreamReq.Header[k] = v + } + } + + upstreamReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + upstreamReq.Header.Set("User-Agent", fmt.Sprintf("%s/aliyun-cli-mcp-proxy", util.GetAliyunCliUserAgent())) + + return upstreamReq, nil +} + +func (p *MCPProxy) handleSSE(w http.ResponseWriter, resp *http.Response) { + log.Println("MCP Proxy handle SSE response from upstream request url", resp.Request.URL.String()) + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "SSE not supported", http.StatusInternalServerError) + return + } + + for k, v := range resp.Header { + if strings.ToLower(k) == "content-length" { + continue + } + w.Header()[k] = v + } + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", "text/event-stream") + } + + w.WriteHeader(resp.StatusCode) + + reader := bufio.NewReader(resp.Body) + for { + // 检查是否正在关闭 + select { + case <-p.stopCh: + log.Println("MCP Proxy handle SSE connection closed due to server shutdown") + return + default: + } + + line, err := reader.ReadBytes('\n') + if err != nil { + break + } + + if _, err = w.Write(line); err != nil { + break + } + log.Println("MCP Proxy handle SSE response line", string(line)) + + flusher.Flush() + } +} + +func (p *MCPProxy) handleHTTP(w http.ResponseWriter, resp *http.Response) { + log.Println("MCP Proxy handle HTTP response from upstream request url", resp.Request.URL.String()) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + log.Println("MCP Proxy gets mcp server http response error from http request", err.Error()) + + return + } + + // 检查是否正在关闭 + select { + case <-p.stopCh: + log.Println("HTTP response cancelled due to server shutdown") + return + default: + } + + for k, v := range resp.Header { + w.Header()[k] = v + } + + w.WriteHeader(resp.StatusCode) + w.Write(bodyBytes) +} + +func (r *TokenRefresher) Start() { + r.ticker = time.NewTicker(CheckInterval) + defer r.ticker.Stop() + + log.Println("MCP Proxy token refresher started") + + r.sendToken() + for { + select { + case <-r.ticker.C: + r.checkAndRefresh() + case <-r.stopCh: + return + } + } +} + +func (r *TokenRefresher) sendToken() { + r.mu.RLock() + token := r.profile.MCPOAuthAccessToken + expiresAt := r.profile.MCPOAuthAccessTokenExpire + r.mu.RUnlock() + + select { + case r.tokenCh <- TokenInfo{Token: token, ExpiresAt: expiresAt}: + // 成功发送 + default: + // channel 已满,清空旧值后发送新值 + select { + case <-r.tokenCh: + default: + } + r.tokenCh <- TokenInfo{Token: token, ExpiresAt: expiresAt} + } +} + +func (r *TokenRefresher) Stop() { + close(r.stopCh) +} + +func (r *TokenRefresher) checkAndRefresh() { + r.mu.RLock() + currentTime := util.GetCurrentUnixTime() + needRefresh := false + needReauth := false + + // 如果 refresh token 过期,则重新授权 + if r.profile.MCPOAuthRefreshTokenExpire-currentTime < int64(RefreshTokenRefreshWindow.Seconds()) { + needReauth = true + } + // 如果 access token 过期,则刷新 access token + if r.profile.MCPOAuthAccessTokenExpire-currentTime < int64(AccessTokenRefreshWindow.Seconds()) { + needRefresh = true + } + r.mu.RUnlock() + + if needReauth { + if err := r.reauthorizeWithProxy(); err != nil { + r.reportFatalError(fmt.Errorf("re-authorization failed: %v. Please restart aliyun mcp-proxy", err)) + return + } + } else if needRefresh { + if err := r.refreshAccessToken(); err != nil { + r.reportFatalError(fmt.Errorf("refresh access token failed. Please restart aliyun mcp-proxy")) + return + } + } +} + +func (r *TokenRefresher) refreshAccessToken() error { + r.mu.Lock() + + if r.refreshing { + currentTime := util.GetCurrentUnixTime() + currentExpiresAt := r.profile.MCPOAuthAccessTokenExpire + if currentExpiresAt > currentTime { + r.mu.Unlock() + return nil + } + // Token 已过期,必须等待刷新完成 + r.mu.Unlock() + return r.waitForRefresh(currentExpiresAt) + } + + r.refreshing = true + endpoint := EndpointMap[r.regionType].OAuth + clientId := r.profile.MCPOAuthAppId + refreshToken := r.profile.MCPOAuthRefreshToken + r.mu.Unlock() + + // 执行网络请求(不持有锁,避免阻塞) + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("client_id", clientId) + data.Set("refresh_token", refreshToken) + // fmt.Println("refresh access token data", data.Encode()) + // fmt.Println("refresh access token endpoint", endpoint) + // fmt.Println("refresh access token clientId", clientId) + // fmt.Println("refresh access token refreshToken", refreshToken) + + newTokens, err := oauthRefresh(endpoint, data) + if err != nil { + r.mu.Lock() + r.refreshing = false + r.mu.Unlock() + atomic.AddInt64(&r.stats.TokenRefreshErrors, 1) + return fmt.Errorf("oauth refresh failed: %w", err) + } + + log.Println("Access token refresh request successfully") + + r.mu.Lock() + currentTime := util.GetCurrentUnixTime() + r.profile.MCPOAuthAccessToken = newTokens.AccessToken + r.profile.MCPOAuthRefreshToken = newTokens.RefreshToken + r.profile.MCPOAuthAccessTokenExpire = currentTime + newTokens.ExpiresIn + r.refreshing = false + + retrySaveProfile( + r.atomicSaveProfile, + MaxSaveFailures, + func() { + r.mu.Unlock() + r.reportFatalError(fmt.Errorf("critical: failed to save refreshed tokens after %d attempts. "+ + "Please re-login with: aliyun configure and run 'aliyun mcp-proxy' again", MaxSaveFailures)) + }, + ) + r.mu.Unlock() + + log.Println("Access token refresh process completed successfully") + + atomic.AddInt64(&r.stats.TokenRefreshes, 1) + atomic.StoreInt64(&r.stats.LastTokenRefresh, time.Now().Unix()) + + r.sendToken() + return nil +} + +func (r *TokenRefresher) waitForRefresh(currentExpiresAt int64) error { + deadline := time.Now().Add(WaitForRefreshTimeout) + for time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + + r.mu.RLock() + if !r.refreshing && r.profile.MCPOAuthAccessTokenExpire > currentExpiresAt { + r.mu.RUnlock() + return nil + } + r.mu.RUnlock() + } + + return fmt.Errorf("timeout waiting for token refresh") +} + +func (r *TokenRefresher) waitForReauthorization(currentRefreshTokenExpire int64) error { + deadline := time.Now().Add(WaitForReauthorizationTimeout) + for time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + + r.mu.RLock() + if !r.reauthorizing && r.profile.MCPOAuthRefreshTokenExpire > currentRefreshTokenExpire { + r.mu.RUnlock() + return nil + } + r.mu.RUnlock() + } + + return fmt.Errorf("timeout waiting for reauthorization") +} + +func (r *TokenRefresher) ForceRefresh() error { + return r.refreshAccessToken() +} + +func (r *TokenRefresher) atomicSaveProfile() error { + return saveMcpProfile(r.profile) +} + +func deleteMcpConfigFile() { + configPath := getMCPConfigPath() + if bytes, err := os.ReadFile(configPath); err == nil { + if profile, err := NewMcpProfileFromBytes(bytes); err == nil { + log.Printf("MCP Config with issue:") + log.Printf(" Profile Name: %s", profile.Name) + log.Printf(" OAuth App Name: %s", profile.MCPOAuthAppName) + log.Printf(" OAuth App ID: %s", profile.MCPOAuthAppId) + log.Printf(" OAuth Site Type: %s", profile.MCPOAuthSiteType) + log.Printf(" Access Token Validity: %d seconds", profile.MCPOAuthAccessTokenValidity) + log.Printf(" Access Token Expire: %d", profile.MCPOAuthAccessTokenExpire) + log.Printf(" Refresh Token Validity: %d seconds", profile.MCPOAuthRefreshTokenValidity) + log.Printf(" Refresh Token Expire: %d", profile.MCPOAuthRefreshTokenExpire) + + // 打印脱敏后的 token + maskToken := func(token string) string { + if len(token) <= 8 { + return "***" + } + return token[:4] + "..." + token[len(token)-4:] + } + if len(profile.MCPOAuthAccessToken) > 0 { + log.Printf(" Access Token: %s", maskToken(profile.MCPOAuthAccessToken)) + } + if len(profile.MCPOAuthRefreshToken) > 0 { + log.Printf(" Refresh Token: %s", maskToken(profile.MCPOAuthRefreshToken)) + } + + currentTime := util.GetCurrentUnixTime() + if profile.MCPOAuthAccessTokenExpire > 0 { + accessTokenRemaining := profile.MCPOAuthAccessTokenExpire - currentTime + log.Printf(" Access Token remaining: %d seconds (%.1f minutes)", + accessTokenRemaining, float64(accessTokenRemaining)/60) + } + if profile.MCPOAuthRefreshTokenExpire > 0 { + refreshTokenRemaining := profile.MCPOAuthRefreshTokenExpire - currentTime + log.Printf(" Refresh Token remaining: %d seconds (%.1f hours)", + refreshTokenRemaining, float64(refreshTokenRemaining)/3600) + } + } else { + log.Printf("Failed to parse mcp config before close: %v", err) + } + } else if !os.IsNotExist(err) { + log.Printf("Failed to read mcp config before close: %v", err) + } + + if err := os.Remove(configPath); err != nil { + if !os.IsNotExist(err) { + log.Printf("Failed to delete mcp config file %q: %v", configPath, err) + } + } else { + log.Printf("Deleted mcp config file: %q", configPath) + } +} + +func (r *TokenRefresher) reportFatalError(err error) { + deleteMcpConfigFile() + select { + case r.fatalErrCh <- err: + default: + // channel 已满,说明已经有错误在等待处理,忽略新的错误 + } +} + +func retrySaveProfile(saveFn func() error, maxAttempts int, onMaxFailures func()) { + for attempt := 1; attempt <= maxAttempts; attempt++ { + if err := saveFn(); err == nil { + return + } + if attempt == maxAttempts { + onMaxFailures() + } + } +} + +func (r *TokenRefresher) reauthorizeWithProxy() error { + r.mu.Lock() + + if r.reauthorizing { + currentTime := util.GetCurrentUnixTime() + currentRefreshTokenExpire := r.profile.MCPOAuthRefreshTokenExpire + if currentRefreshTokenExpire > currentTime { + r.mu.Unlock() + return nil + } + // Refresh token 已过期,必须等待重新授权完成 + r.mu.Unlock() + return r.waitForReauthorization(currentRefreshTokenExpire) + } + + r.reauthorizing = true + clientId := r.profile.MCPOAuthAppId + refreshTokenValidity := r.profile.MCPOAuthRefreshTokenValidity + r.mu.Unlock() + + // 执行 OAuth 流程(不持有锁,避免阻塞) + oauthScope := r.scope + if oauthScope == "" { + oauthScope = "/acs/mcp-server" + } + stderr := getStderrWriter(nil) + tokenResult, err := executeOAuthFlow(nil, clientId, r.regionType, r.callbackManager, r.host, r.port, r.autoOpenBrowser, oauthScope, func(authURL string) { + cli.Printf(stderr, "OAuth Re-authorization Required. Please visit: %s\n", authURL) + }) + if err != nil { + r.mu.Lock() + r.reauthorizing = false + r.mu.Unlock() + atomic.AddInt64(&r.stats.TokenRefreshErrors, 1) + return err + } + log.Println("OAuth re-authorization request successfully") + + r.mu.Lock() + currentTime := util.GetCurrentUnixTime() + r.profile.MCPOAuthAccessToken = tokenResult.AccessToken + r.profile.MCPOAuthRefreshToken = tokenResult.RefreshToken + r.profile.MCPOAuthAccessTokenExpire = tokenResult.AccessTokenExpire + r.profile.MCPOAuthRefreshTokenExpire = currentTime + int64(refreshTokenValidity) + r.reauthorizing = false + + retrySaveProfile( + r.atomicSaveProfile, + MaxSaveFailures, + func() { + r.mu.Unlock() + r.reportFatalError(fmt.Errorf("critical: failed to save reauthorized tokens after %d attempts. "+ + "Please re-login with: aliyun configure and run 'aliyun mcp-proxy' again", MaxSaveFailures)) + }, + ) + r.mu.Unlock() + log.Println("OAuth re-authorization process completed successfully") + + atomic.AddInt64(&r.stats.TokenRefreshes, 1) + atomic.StoreInt64(&r.stats.LastTokenRefresh, time.Now().Unix()) + + r.sendToken() + return nil +} diff --git a/mcpproxy/mcp_server_test.go b/mcpproxy/mcp_server_test.go new file mode 100644 index 000000000..f77c86989 --- /dev/null +++ b/mcpproxy/mcp_server_test.go @@ -0,0 +1,433 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewMCPProxy(t *testing.T) { + host := "127.0.0.1" + port := 8088 + regionType := RegionCN + scope := "/acs/mcp-server" + mcpProfile := NewMcpProfile("test-profile") + servers := []MCPServerInfo{ + { + Id: "server1", + Name: "Test Server", + SourceType: "api", + Product: "ecs", + Urls: MCPInfoUrls{ + MCP: "/mcp/server1", + SSE: "/sse/server1", + }, + }, + } + manager := NewOAuthCallbackManager() + autoOpenBrowser := true + + config := ProxyConfig{ + Host: host, + Port: port, + RegionType: regionType, + Scope: scope, + McpProfile: mcpProfile, + ExistMcpServers: servers, + CallbackManager: manager, + AutoOpenBrowser: autoOpenBrowser, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + assert.NotNil(t, proxy) + assert.Equal(t, host, proxy.Host) + assert.Equal(t, port, proxy.Port) + assert.Equal(t, regionType, proxy.RegionType) + assert.Equal(t, servers, proxy.ExistMcpServers) + assert.NotNil(t, proxy.TokenRefresher) + assert.NotNil(t, proxy.stopCh) + assert.NotNil(t, proxy.stats) + assert.Equal(t, mcpProfile, proxy.TokenRefresher.profile) + assert.Equal(t, regionType, proxy.TokenRefresher.regionType) + assert.Equal(t, manager, proxy.TokenRefresher.callbackManager) + assert.Equal(t, host, proxy.TokenRefresher.host) + assert.Equal(t, port, proxy.TokenRefresher.port) + assert.Equal(t, scope, proxy.TokenRefresher.scope) + assert.Equal(t, autoOpenBrowser, proxy.TokenRefresher.autoOpenBrowser) +} + +func TestMCPProxy_Stop(t *testing.T) { + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 0, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: NewMcpProfile("test"), + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + err := proxy.Stop() + assert.NoError(t, err) + + select { + case <-proxy.stopCh: + // 正常,channel 已关闭 + default: + t.Error("stopCh should be closed") + } +} + +func TestMCPProxy_handleHealth(t *testing.T) { + profile := NewMcpProfile("test-profile") + currentTime := time.Now().Unix() + profile.MCPOAuthAccessToken = "test-token" + profile.MCPOAuthAccessTokenExpire = currentTime + 3600 + profile.MCPOAuthRefreshToken = "refresh-token" + profile.MCPOAuthRefreshTokenExpire = currentTime + 86400 + + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: profile, + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + + proxy.handleHealth(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + + var health map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &health) + assert.NoError(t, err) + assert.Equal(t, "healthy", health["status"]) + assert.NotNil(t, health["timestamp"]) + assert.NotNil(t, health["uptime"]) + assert.NotNil(t, health["memory"]) + assert.NotNil(t, health["requests"]) + assert.NotNil(t, health["token_refreshes"]) +} + +func TestMCPProxy_handleHealth_ExpiredToken(t *testing.T) { + profile := NewMcpProfile("test-profile") + // 设置过期的 token + profile.MCPOAuthAccessToken = "test-token" + profile.MCPOAuthAccessTokenExpire = time.Now().Unix() - 100 + profile.MCPOAuthRefreshToken = "refresh-token" + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() + 86400 + + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: profile, + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + + proxy.handleHealth(w, req) + + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + + var health map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &health) + assert.NoError(t, err) + assert.Equal(t, "degraded", health["status"]) + assert.Equal(t, "expired", health["token_status"]) +} + +func TestMCPProxy_ServeHTTP_ShuttingDown(t *testing.T) { + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: NewMcpProfile("test"), + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + // 关闭 stopCh 模拟正在关闭 + close(proxy.stopCh) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + proxy.ServeMCPProxyRequest(w, req) + + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + assert.Contains(t, w.Body.String(), "Server is shutting down") + assert.Equal(t, int64(1), atomic.LoadInt64(&proxy.stats.ErrorRequests)) +} + +func TestMCPProxy_buildUpstreamRequest(t *testing.T) { + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-access-token" + profile.MCPOAuthAccessTokenExpire = time.Now().Unix() + 3600 + + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: profile, + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: "", + } + proxy := NewMCPProxy(config) + + body := bytes.NewBufferString("test body") + req := httptest.NewRequest("POST", "/test/path?query=value", body) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Host", "localhost") + req.Header.Set("Authorization", "Bearer old-token") + + upstreamReq, err := proxy.buildUpstreamRequest(req, "new-access-token") + assert.NoError(t, err) + assert.NotNil(t, upstreamReq) + + assert.Equal(t, "POST", upstreamReq.Method) + assert.Contains(t, upstreamReq.URL.String(), EndpointMap[RegionCN].MCP) + assert.Contains(t, upstreamReq.URL.Path, "/test/path") + assert.Equal(t, "value", upstreamReq.URL.Query().Get("query")) + assert.Equal(t, "Bearer new-access-token", upstreamReq.Header.Get("Authorization")) + assert.Equal(t, "application/json", upstreamReq.Header.Get("Content-Type")) + assert.NotEqual(t, "localhost", upstreamReq.Header.Get("Host")) +} + +func TestMCPProxy_buildUpstreamRequest_WithCustomURL(t *testing.T) { + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-access-token" + profile.MCPOAuthAccessTokenExpire = time.Now().Unix() + 3600 + + // 测试使用自定义的 upstream URL + customURL := "https://custom-mcp.example.com" + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: profile, + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: customURL, + } + proxy := NewMCPProxy(config) + + body := bytes.NewBufferString("test body") + req := httptest.NewRequest("POST", "/test/path?query=value", body) + req.Header.Set("Content-Type", "application/json") + + upstreamReq, err := proxy.buildUpstreamRequest(req, "new-access-token") + assert.NoError(t, err) + assert.NotNil(t, upstreamReq) + + assert.Equal(t, "POST", upstreamReq.Method) + assert.Contains(t, upstreamReq.URL.String(), "https://custom-mcp.example.com") + assert.Contains(t, upstreamReq.URL.Path, "/test/path") + assert.Equal(t, "Bearer new-access-token", upstreamReq.Header.Get("Authorization")) +} + +func TestMCPProxy_buildUpstreamRequest_WithCustomURL_NoProtocol(t *testing.T) { + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-access-token" + profile.MCPOAuthAccessTokenExpire = time.Now().Unix() + 3600 + + // 测试使用自定义的 upstream URL(没有协议前缀) + customURL := "custom-mcp.example.com" + config := ProxyConfig{ + Host: "127.0.0.1", + Port: 8088, + RegionType: RegionCN, + Scope: "/acs/mcp-server", + McpProfile: profile, + ExistMcpServers: nil, + CallbackManager: NewOAuthCallbackManager(), + AutoOpenBrowser: false, + UpstreamBaseURL: customURL, + } + proxy := NewMCPProxy(config) + + body := bytes.NewBufferString("test body") + req := httptest.NewRequest("POST", "/test/path", body) + + upstreamReq, err := proxy.buildUpstreamRequest(req, "new-access-token") + assert.NoError(t, err) + assert.NotNil(t, upstreamReq) + + // 应该自动添加 https:// 前缀 + assert.Contains(t, upstreamReq.URL.String(), "https://custom-mcp.example.com") +} + +func TestTokenRefresher_Stop(t *testing.T) { + refresher := &TokenRefresher{ + stopCh: make(chan struct{}), + stats: &RuntimeStats{ + StartTime: time.Now(), + }, + } + + refresher.Stop() + + select { + case <-refresher.stopCh: + // 正常,channel 已关闭 + default: + t.Error("stopCh should be closed") + } +} + +func TestTokenRefresher_sendToken(t *testing.T) { + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-token" + profile.MCPOAuthAccessTokenExpire = time.Now().Unix() + 3600 + + refresher := &TokenRefresher{ + profile: profile, + tokenCh: make(chan TokenInfo, 1), + stats: &RuntimeStats{ + StartTime: time.Now(), + }, + } + + refresher.sendToken() + + select { + case tokenInfo := <-refresher.tokenCh: + assert.Equal(t, "test-token", tokenInfo.Token) + assert.Equal(t, profile.MCPOAuthAccessTokenExpire, tokenInfo.ExpiresAt) + case <-time.After(100 * time.Millisecond): + t.Error("token should be sent to channel") + } +} + +func TestTokenRefresher_atomicSaveProfile(t *testing.T) { + tmpDir := t.TempDir() + originalHome := getHomeEnv() + defer restoreHomeEnv(originalHome) + + setHomeEnv(tmpDir) + + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessToken = "test-token" + + refresher := &TokenRefresher{ + profile: profile, + stats: &RuntimeStats{ + StartTime: time.Now(), + }, + } + + err := refresher.atomicSaveProfile() + assert.NoError(t, err) + + configPath := getMCPConfigPath() + _, err = os.Stat(configPath) + assert.NoError(t, err) +} + +func TestRetrySaveProfile(t *testing.T) { + attempts := int32(0) + maxAttempts := 3 + + retrySaveProfile(func() error { + atomic.AddInt32(&attempts, 1) + return nil + }, maxAttempts, func() { + t.Error("onMaxFailures should not be called") + }) + + assert.Equal(t, int32(1), atomic.LoadInt32(&attempts)) + + attempts = 0 + onMaxFailuresCalled := false + + retrySaveProfile(func() error { + atomic.AddInt32(&attempts, 1) + return assert.AnError + }, maxAttempts, func() { + onMaxFailuresCalled = true + }) + + assert.Equal(t, int32(maxAttempts), atomic.LoadInt32(&attempts)) + assert.True(t, onMaxFailuresCalled) +} + +func TestGetContentFromApiResponse_Integration(t *testing.T) { + response := map[string]any{ + "body": map[string]any{ + "key": "value", + }, + } + + content, err := GetContentFromApiResponse(response) + assert.NoError(t, err) + assert.NotNil(t, content) + assert.Contains(t, string(content), "key") + assert.Contains(t, string(content), "value") +} + +// 辅助函数 +func getHomeEnv() string { + return os.Getenv("HOME") +} + +func setHomeEnv(value string) { + os.Setenv("HOME", value) +} + +func restoreHomeEnv(value string) { + if value != "" { + os.Setenv("HOME", value) + } else { + os.Unsetenv("HOME") + } +} diff --git a/mcpproxy/oauth_app.go b/mcpproxy/oauth_app.go new file mode 100644 index 000000000..51c096ab7 --- /dev/null +++ b/mcpproxy/oauth_app.go @@ -0,0 +1,998 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "bufio" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html/template" + "io" + "log" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "sync" + "time" + + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/aliyun/aliyun-cli/v3/config" + "github.com/aliyun/aliyun-cli/v3/util" + + openapiClient "github.com/alibabacloud-go/darabonba-openapi/v2/client" + openapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils" + "github.com/alibabacloud-go/tea/dara" + "github.com/alibabacloud-go/tea/tea" +) + +type RegionType string + +const ( + RegionCN RegionType = "CN" + RegionINTL RegionType = "INTL" + DefaultMcpProfileName = "default-mcp" + MCPOAuthAppName = "aliyun-cli-mcp-proxy" + MCPOAuthDisplayName = "AliyunCLI-MCP-Proxy" +) + +type EndpointConfig struct { + SignIn string + OAuth string + IMS string + MCP string +} + +// 国内/国际站端点映射 +var EndpointMap = map[RegionType]EndpointConfig{ + RegionCN: { + SignIn: "https://signin.aliyun.com", + OAuth: "https://oauth.aliyun.com", + IMS: "ims.aliyuncs.com", + MCP: "openapi-mcp.cn-hangzhou.aliyuncs.com", + }, + RegionINTL: { + SignIn: "https://signin.alibabacloud.com", + OAuth: "https://oauth.alibabacloud.com", + IMS: "ims.aliyuncs.com", + MCP: "openapi-mcp.ap-southeast-1.aliyuncs.com", + }, +} + +const ( + OAuthTimeout = 5 * time.Minute + AccessTokenValiditySec = 10800 // 3 hours + RefreshTokenValiditySec = 31536000 // 365 days (1 year) + + AliyunCLIHomepageURL = "https://help.aliyun.com/zh/cli/what-is-alibaba-cloud-cli" + RedirectCountdownSeconds = 10 // 自动跳转倒计时(秒) + ManualModeCloseDelayMs = 3000 // 手动模式自动关闭延迟(毫秒) +) + +var oauthTimeout = OAuthTimeout + +const ( + oauthErrorPageHTML = ` + + + + OAuth Authorization Error + + + +

Authorization Error

+

No authorization code received. Please try again.

+ +` + oauthSuccessPageManualHTML = ` + + + + OAuth Authorization Success + + + +
+

✓ Authorization Successful

+

Your authorization code is:

+
+ {{.Code}} +
+ +

You can close this window now.

+
+ + +` + oauthSuccessPageAutoHTML = ` + + + + OAuth Authorization Success + + + +
+

✓ Authorization Successful

+

Redirecting to Aliyun CLI homepage in {{.Countdown}} seconds...

+

Or click here to visit now.

+
+ + +` +) + +var ( + oauthErrorPageTemplate *template.Template + oauthSuccessPageManualTemplate *template.Template + oauthSuccessPageAutoTemplate *template.Template +) + +func init() { + var err error + + oauthErrorPageTemplate, err = template.New("error").Parse(oauthErrorPageHTML) + if err != nil { + panic(fmt.Sprintf("failed to parse oauth error page template: %v", err)) + } + + oauthSuccessPageManualTemplate, err = template.New("successManual").Parse(oauthSuccessPageManualHTML) + if err != nil { + panic(fmt.Sprintf("failed to parse oauth success manual page template: %v", err)) + } + + oauthSuccessPageAutoTemplate, err = template.New("successAuto").Funcs(template.FuncMap{ + "js": func(s string) template.JS { + return template.JS(fmt.Sprintf("%q", s)) + }, + }).Parse(oauthSuccessPageAutoHTML) + if err != nil { + panic(fmt.Sprintf("failed to parse oauth success auto page template: %v", err)) + } +} + +type OAuthPageData struct { + Code string + Countdown int + HomepageURL string + CloseDelayMs int +} + +type OAuthCallbackManager struct { + mu sync.RWMutex + pendingAuth chan string // 用于传递授权码 + errorCh chan error // 用于传递错误 + isWaiting bool // 是否正在等待回调 +} + +type OAuthTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshExpiresIn int64 `json:"refresh_expires_in"` + TokenType string `json:"token_type"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +func NewOAuthCallbackManager() *OAuthCallbackManager { + return &OAuthCallbackManager{ + pendingAuth: make(chan string, 1), + errorCh: make(chan error, 1), + isWaiting: false, + } +} + +func (m *OAuthCallbackManager) StartWaiting() { + m.mu.Lock() + defer m.mu.Unlock() + m.isWaiting = true + select { + case <-m.pendingAuth: + default: + } + select { + case <-m.errorCh: + default: + } +} + +func (m *OAuthCallbackManager) StopWaiting() { + m.mu.Lock() + defer m.mu.Unlock() + m.isWaiting = false +} + +func (m *OAuthCallbackManager) IsWaiting() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.isWaiting +} + +func (m *OAuthCallbackManager) HandleCallback(code string, err error) bool { + if !m.IsWaiting() { + return false + } + + if err != nil { + select { + case m.errorCh <- err: + default: + } + return true + } + + if code != "" { + select { + case m.pendingAuth <- code: + default: + } + return true + } + + return false +} + +func (m *OAuthCallbackManager) WaitForCode() (string, error) { + select { + case code := <-m.pendingAuth: + return code, nil + case err := <-m.errorCh: + return "", err + case <-time.After(oauthTimeout): + return "", fmt.Errorf("timeout waiting for authorization") + } +} + +func handleOAuthCallbackRequest(w http.ResponseWriter, r *http.Request, handler func(string, error) bool, showCode bool) bool { + if r.URL.Path != "/callback" { + return false + } + + code := r.URL.Query().Get("code") + if code == "" { + handler("", fmt.Errorf("no authorization code received")) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + if err := oauthErrorPageTemplate.Execute(w, nil); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + return true + } + + handler(code, nil) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + + if showCode { + data := OAuthPageData{ + Code: code, + CloseDelayMs: ManualModeCloseDelayMs, + } + if err := oauthSuccessPageManualTemplate.Execute(w, data); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + } else { + data := OAuthPageData{ + Countdown: RedirectCountdownSeconds, + HomepageURL: AliyunCLIHomepageURL, + } + if err := oauthSuccessPageAutoTemplate.Execute(w, data); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + } + return true +} + +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +func oauthRefresh(endpoint string, data url.Values) (*OAuthTokenResponse, error) { + req, err := http.NewRequest("POST", endpoint+"/v1/token", strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := util.NewHttpClient().Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("refresh failed: status %d", resp.StatusCode) + } + + var tokenResp OAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + if tokenResp.Error != "" { + return nil, fmt.Errorf("%s: %s", tokenResp.Error, tokenResp.ErrorDescription) + } + + return &tokenResp, nil +} + +type OAuthTokenResult struct { + AccessToken string + RefreshToken string + AccessTokenExpire int64 +} + +func exchangeCodeForTokenWithPKCE(clientId, code, codeVerifier, redirectURI, oauthEndpoint string) (*OAuthTokenResult, error) { + log.Println("Start to exchange code for token with PKCE") + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("client_id", clientId) + data.Set("redirect_uri", redirectURI) + data.Set("code_verifier", codeVerifier) + + tokenResp, err := oauthRefresh(oauthEndpoint, data) + if err != nil { + log.Println("Exchange code for token with PKCE failed:", err) + return nil, fmt.Errorf("oauth refresh failed: %w", err) + } + log.Println("Exchange code for token with PKCE successfully") + + currentTime := util.GetCurrentUnixTime() + return &OAuthTokenResult{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + AccessTokenExpire: currentTime + tokenResp.ExpiresIn, + }, nil +} + +func buildOAuthURL(clientId string, region RegionType, host string, port int, codeChallenge string, scope string) string { + redirectURI := buildRedirectUri(host, port) + return fmt.Sprintf("%s/oauth2/v1/auth?client_id=%s&response_type=code&scope=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256", + EndpointMap[region].SignIn, clientId, url.QueryEscape(scope), redirectURI, codeChallenge) +} + +func executeOAuthFlow(ctx *cli.Context, clientId string, regionType RegionType, manager *OAuthCallbackManager, + host string, port int, autoOpenBrowser bool, scope string, logAuthURL func(string)) (*OAuthTokenResult, error) { + stderr := getStderrWriter(ctx) + codeVerifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate code verifier: %w", err) + } + codeChallenge := generateCodeChallenge(codeVerifier) + + redirectURI := buildRedirectUri(host, port) + authURL := buildOAuthURL(clientId, regionType, host, port, codeChallenge, scope) + + if logAuthURL != nil { + logAuthURL(authURL) + } + waitStarted := false + stopWaiting := func() { + if waitStarted { + manager.StopWaiting() + waitStarted = false + } + } + defer stopWaiting() + + var code string + + if autoOpenBrowser { + if err := OpenBrowser(authURL); err != nil { + // 错误信息输出到 stderr,确保用户能看到 + cli.Printf(stderr, "Failed to open browser automatically: %v\n", err) + cli.Printf(stderr, "Falling back to manual code input mode...\n") + if !isInteractiveInput() { + return nil, fmt.Errorf("manual authorization required but standard input is not interactive") + } + reader := bufio.NewReader(os.Stdin) + code, err = promptAuthorizationCode(stderr, reader) + if err != nil { + return nil, fmt.Errorf("failed to read authorization code: %w", err) + } + } else { + manager.StartWaiting() + waitStarted = true + code, err = manager.WaitForCode() + if err != nil { + return nil, fmt.Errorf("failed to get authorization code: %w", err) + } + } + } else { + if !isInteractiveInput() { + return nil, fmt.Errorf("manual authorization required but standard input is not interactive") + } + reader := bufio.NewReader(os.Stdin) + code, err = promptAuthorizationCode(stderr, reader) + if err != nil { + return nil, fmt.Errorf("failed to read authorization code: %w", err) + } + } + + if code == "" { + return nil, fmt.Errorf("authorization code is empty") + } + log.Println("Oauth authorization successfully, code received:", code) + + return exchangeCodeForTokenWithPKCE(clientId, code, codeVerifier, redirectURI, EndpointMap[regionType].OAuth) +} + +func startMCPOAuthFlowWithManager(ctx *cli.Context, clientId string, region RegionType, + manager *OAuthCallbackManager, host string, port int, autoOpenBrowser bool, scope string) (*OAuthTokenResult, error) { + stderr := getStderrWriter(ctx) + tokenResult, err := executeOAuthFlow(ctx, clientId, region, manager, host, port, autoOpenBrowser, scope, func(authURL string) { + cli.Printf(stderr, "Opening browser for OAuth login...\nURL: %s\n\n", authURL) + }) + if err != nil { + log.Println("Execute OAuth flow failed:", err) + return nil, err + } + + cli.Println(stderr, "OAuth login successful!") + return tokenResult, nil +} + +func startMCPOAuthFlow(ctx *cli.Context, clientId string, region RegionType, host string, port int, autoOpenBrowser bool, scope string) (*OAuthTokenResult, error) { + manager := NewOAuthCallbackManager() + + server := &http.Server{Addr: fmt.Sprintf("%s:%d", host, port)} + http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + // 手动输入模式需要显示授权码(autoOpenBrowser=false 表示需要显示) + showCode := !autoOpenBrowser + handleOAuthCallbackRequest(w, r, manager.HandleCallback, showCode) + }) + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + manager.HandleCallback("", err) + } + }() + + defer server.Close() + + return startMCPOAuthFlowWithManager(ctx, clientId, region, manager, host, port, autoOpenBrowser, scope) +} + +func isStderrRedirected() bool { + info, err := os.Stderr.Stat() + if err != nil { + return true + } + return (info.Mode() & os.ModeCharDevice) == 0 +} + +type teeWriter struct { + writers []io.Writer +} + +func (t *teeWriter) Write(p []byte) (n int, err error) { + for _, w := range t.writers { + n, err = w.Write(p) + if err != nil { + return n, err + } + } + return len(p), nil +} + +// 获取 stderr writer用于交互式提示 +func getStderrWriter(ctx *cli.Context) io.Writer { + var stderrWriter io.Writer + if ctx != nil && ctx.Stderr() != nil { + stderrWriter = ctx.Stderr() + } else { + stderrWriter = os.Stderr + } + + if isStderrRedirected() { + if tty, err := os.OpenFile("/dev/tty", os.O_WRONLY, 0); err == nil { + return &teeWriter{writers: []io.Writer{stderrWriter, tty}} + } + return stderrWriter + } + return stderrWriter +} + +func isInteractiveInput() bool { + info, err := os.Stdin.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func promptAuthorizationCode(stderr io.Writer, reader *bufio.Reader) (string, error) { + cli.Println(stderr, "\nPlease open the authorization URL on a machine with a browser and complete the sign-in.") + cli.Println(stderr, "") + cli.Println(stderr, "After authorization, the browser will redirect to a callback URL.") + cli.Println(stderr, "Even if the page fails to load (connection error), the authorization code is in the URL.") + cli.Println(stderr, "Please copy the value of the `code` parameter from the browser's address bar.") + cli.Println(stderr, "") + cli.Println(stderr, "Example: If the URL is:") + cli.Println(stderr, " http://127.0.0.1:8088/callback?code=abc123xyz&state=...") + cli.Println(stderr, " Then copy only: abc123xyz") + cli.Println(stderr, "") + + for { + cli.Print(stderr, "Enter authorization code: ") + line, err := reader.ReadString('\n') + if err != nil { + return "", err + } + line = strings.TrimSpace(line) + if line == "" { + cli.Println(stderr, "Input is empty. Please try again.") + continue + } + + if strings.HasPrefix(strings.ToLower(line), "http://") || + strings.HasPrefix(strings.ToLower(line), "https://") || + strings.Contains(line, "?") || + strings.Contains(strings.ToLower(line), "code=") { + cli.Println(stderr, "Please paste the authorization code only, not the entire URL.") + continue + } + + return line, nil + } +} + +func OpenBrowser(url string) error { + // return errors.New("not implemented") + var cmd *exec.Cmd + switch runtime.GOOS { + case "linux": + cmd = exec.Command("xdg-open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + case "darwin": + cmd = exec.Command("open", url) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + + return cmd.Start() +} + +type OAuthApplication struct { + ApplicationId string `json:"ApplicationId"` + AppName string `json:"AppName"` + DisplayName string `json:"DisplayName"` + AppType string `json:"AppType"` + RedirectUris []string `json:"RedirectUris"` + Scopes []string `json:"Scopes"` + AccessTokenValidity int `json:"AccessTokenValidity"` + RefreshTokenValidity int `json:"RefreshTokenValidity"` +} + +type IMSApplication struct { + AppId string `json:"AppId"` + AppName string `json:"AppName"` + DisplayName string `json:"DisplayName"` + AppType string `json:"AppType"` + RedirectUris struct { + RedirectUri []string `json:"RedirectUri"` + } `json:"RedirectUris"` + DelegatedScope struct { + PredefinedScopes struct { + PredefinedScope []struct { + Name string `json:"Name"` + } `json:"PredefinedScope"` + } `json:"PredefinedScopes"` + } `json:"DelegatedScope"` + AccessTokenValidity int `json:"AccessTokenValidity"` + RefreshTokenValidity int `json:"RefreshTokenValidity"` +} + +// 创建应用响应 +type CreateApplicationResponse struct { + Application IMSApplication `json:"Application"` +} + +type ListApplicationsResponse struct { + Applications struct { + Application []IMSApplication `json:"Application"` + } `json:"Applications"` +} + +type GetApplicationResponse struct { + Application IMSApplication `json:"Application"` +} + +func newOpenAPIClient(ctx *cli.Context, profile config.Profile, endpoint string) (*openapiClient.Client, error) { + credential, err := profile.GetCredential(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to get credential: %w", err) + } + + conf := &openapiClient.Config{ + Credential: credential, + RegionId: tea.String(profile.RegionId), + Endpoint: tea.String(endpoint), + UserAgent: tea.String(util.GetAliyunCliUserAgent()), + } + + client, err := openapiClient.NewClient(conf) + if err != nil { + return nil, err + } + + return client, nil +} + +func getOrCreateMCPOAuthApplication(ctx *cli.Context, profile config.Profile, region RegionType, host string, port int, scope string) (*OAuthApplication, error) { + app, err := findOAuthApplicationByName(ctx, profile, region, MCPOAuthAppName) + if err != nil { + return nil, err + } + + if app != nil { + return app, nil + } + + return createMCPOauthApplication(ctx, profile, region, host, port, scope) +} + +func findOAuthApplicationById(ctx *cli.Context, profile config.Profile, mcpProfile *McpProfile, region RegionType) error { + client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) + if err != nil { + return err + } + params := &openapiClient.Params{ + Action: tea.String("GetApplication"), + Version: tea.String("2019-08-15"), + Protocol: tea.String("HTTPS"), + Method: tea.String("GET"), + AuthType: tea.String("AK"), + Style: tea.String("RPC"), + Pathname: tea.String("/"), + } + runtime := &dara.RuntimeOptions{} + request := &openapiutil.OpenApiRequest{ + Query: map[string]*string{ + "AppId": tea.String(mcpProfile.MCPOAuthAppId), + }, + } + response, err := client.CallApi(params, request, runtime) + if err != nil { + return err + } + bodyBytes, err := GetContentFromApiResponse(response) + if err != nil { + return fmt.Errorf("failed to get content from api response: %w", err) + } + var responseGet GetApplicationResponse + if err := json.Unmarshal(bodyBytes, &responseGet); err != nil { + return err + } + return nil +} + +func findOAuthApplicationByName(ctx *cli.Context, profile config.Profile, region RegionType, appName string) (*OAuthApplication, error) { + client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) + if err != nil { + return nil, err + } + params := &openapiClient.Params{ + Action: tea.String("ListApplications"), + Version: tea.String("2019-08-15"), + Protocol: tea.String("HTTPS"), + Method: tea.String("POST"), + AuthType: tea.String("AK"), + Style: tea.String("RPC"), + Pathname: tea.String("/"), + ReqBodyType: tea.String("json"), + BodyType: tea.String("json"), + } + runtime := &dara.RuntimeOptions{} + request := &openapiutil.OpenApiRequest{} + response, err := client.CallApi(params, request, runtime) + if err != nil { + return nil, err + } + bodyBytes, err := GetContentFromApiResponse(response) + if err != nil { + return nil, fmt.Errorf("failed to get content from api response: %w", err) + } + var responseList ListApplicationsResponse + if err := json.Unmarshal(bodyBytes, &responseList); err != nil { + return nil, err + } + + for _, app := range responseList.Applications.Application { + if app.AppName == appName { + scopes := make([]string, 0, len(app.DelegatedScope.PredefinedScopes.PredefinedScope)) + for _, s := range app.DelegatedScope.PredefinedScopes.PredefinedScope { + scopes = append(scopes, s.Name) + } + + return &OAuthApplication{ + ApplicationId: app.AppId, + AppName: app.AppName, + DisplayName: app.DisplayName, + AppType: app.AppType, + RedirectUris: app.RedirectUris.RedirectUri, + Scopes: scopes, + AccessTokenValidity: app.AccessTokenValidity, + RefreshTokenValidity: app.RefreshTokenValidity, + }, nil + } + } + return nil, nil +} + +// validateOAuthApplication 验证 OAuth 应用的 Scopes 和 Callback URI 是否符合要求 +func validateOAuthApplication(app *OAuthApplication, requiredScope string, requiredRedirectURI string) error { + if app == nil { + return fmt.Errorf("OAuth application is nil") + } + + // 验证 Scopes + scopeFound := false + for _, scope := range app.Scopes { + if scope == requiredScope { + scopeFound = true + break + } + } + if !scopeFound { + return fmt.Errorf("OAuth application '%s' does not have required scope '%s'. Available scopes: %v", + app.AppName, requiredScope, app.Scopes) + } + + // 验证 Callback URI + redirectURIFound := false + for _, uri := range app.RedirectUris { + if uri == requiredRedirectURI { + redirectURIFound = true + break + } + } + if !redirectURIFound { + return fmt.Errorf("OAuth application '%s' does not have required redirect URI '%s'. Available redirect URIs: %v", + app.AppName, requiredRedirectURI, app.RedirectUris) + } + + return nil +} + +func buildRedirectUri(host string, port int) string { + return fmt.Sprintf("http://%s:%d/callback", host, port) +} + +func createMCPOauthApplication(ctx *cli.Context, profile config.Profile, region RegionType, host string, port int, scope string) (*OAuthApplication, error) { + client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) + if err != nil { + return nil, err + } + + redirectUri := buildRedirectUri(host, port) + + params := &openapiClient.Params{ + Action: tea.String("CreateApplication"), + Version: tea.String("2019-08-15"), + Protocol: tea.String("HTTPS"), + Method: tea.String("POST"), + AuthType: tea.String("AK"), + Style: tea.String("RPC"), + Pathname: tea.String("/"), + ReqBodyType: tea.String("json"), + BodyType: tea.String("json"), + } + + request := &openapiutil.OpenApiRequest{ + Query: map[string]*string{ + "AppName": tea.String(MCPOAuthAppName), + "AppType": tea.String("NativeApp"), + "DisplayName": tea.String(MCPOAuthDisplayName), + "PredefinedScopes": tea.String(scope), + "ProtocolVersion": tea.String("2.1"), + "AccessTokenValidity": tea.String(fmt.Sprintf("%d", AccessTokenValiditySec)), + "RefreshTokenValidity": tea.String(fmt.Sprintf("%d", RefreshTokenValiditySec)), + "RedirectUris": tea.String(redirectUri), + }, + } + + runtime := &dara.RuntimeOptions{} + response, err := client.CallApi(params, request, runtime) + if err != nil { + return nil, fmt.Errorf("create application failed: %w", err) + } + + bodyBytes, err := GetContentFromApiResponse(response) + if err != nil { + return nil, fmt.Errorf("failed to get content from api response: %w", err) + } + + var responseCreate CreateApplicationResponse + if err := json.Unmarshal(bodyBytes, &responseCreate); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + scopes := make([]string, 0, len(responseCreate.Application.DelegatedScope.PredefinedScopes.PredefinedScope)) + for _, s := range responseCreate.Application.DelegatedScope.PredefinedScopes.PredefinedScope { + scopes = append(scopes, s.Name) + } + + return &OAuthApplication{ + ApplicationId: responseCreate.Application.AppId, + AppName: responseCreate.Application.AppName, + DisplayName: responseCreate.Application.DisplayName, + AppType: responseCreate.Application.AppType, + RedirectUris: responseCreate.Application.RedirectUris.RedirectUri, + Scopes: scopes, + AccessTokenValidity: responseCreate.Application.AccessTokenValidity, + RefreshTokenValidity: responseCreate.Application.RefreshTokenValidity, + }, nil +} diff --git a/mcpproxy/oauth_app_test.go b/mcpproxy/oauth_app_test.go new file mode 100644 index 000000000..38a3ee5d1 --- /dev/null +++ b/mcpproxy/oauth_app_test.go @@ -0,0 +1,447 @@ +// Copyright (c) 2009-present, Alibaba Cloud All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcpproxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewOAuthCallbackManager(t *testing.T) { + manager := NewOAuthCallbackManager() + assert.NotNil(t, manager) + assert.NotNil(t, manager.pendingAuth) + assert.NotNil(t, manager.errorCh) + assert.False(t, manager.isWaiting) +} + +func TestOAuthCallbackManager_StartWaiting(t *testing.T) { + manager := NewOAuthCallbackManager() + + manager.StartWaiting() + assert.True(t, manager.IsWaiting()) + + // 清空channels + select { + case <-manager.pendingAuth: + default: + } + select { + case <-manager.errorCh: + default: + } +} + +func TestOAuthCallbackManager_StopWaiting(t *testing.T) { + manager := NewOAuthCallbackManager() + manager.StartWaiting() + assert.True(t, manager.IsWaiting()) + + manager.StopWaiting() + assert.False(t, manager.IsWaiting()) +} + +func TestOAuthCallbackManager_HandleCallback(t *testing.T) { + tests := []struct { + name string + setup func(*OAuthCallbackManager) + code string + err error + expectHandled bool + }{ + { + name: "not waiting", + setup: func(m *OAuthCallbackManager) { + // 不调用 StartWaiting + }, + code: "test-code", + err: nil, + expectHandled: false, + }, + { + name: "waiting with code", + setup: func(m *OAuthCallbackManager) { + m.StartWaiting() + }, + code: "test-code", + err: nil, + expectHandled: true, + }, + { + name: "waiting with error", + setup: func(m *OAuthCallbackManager) { + m.StartWaiting() + }, + code: "", + err: assert.AnError, + expectHandled: true, + }, + { + name: "waiting with empty code", + setup: func(m *OAuthCallbackManager) { + m.StartWaiting() + }, + code: "", + err: nil, + expectHandled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := NewOAuthCallbackManager() + tt.setup(manager) + // 有error 或者 有code, 即handled,否则均为非handled + handled := manager.HandleCallback(tt.code, tt.err) + assert.Equal(t, tt.expectHandled, handled) + }) + } +} + +func TestOAuthCallbackManager_WaitForCode(t *testing.T) { + t.Run("receive code", func(t *testing.T) { + manager := NewOAuthCallbackManager() + manager.StartWaiting() + + go func() { + time.Sleep(10 * time.Millisecond) + manager.HandleCallback("test-code", nil) + }() + + code, err := manager.WaitForCode() + assert.NoError(t, err) + assert.Equal(t, "test-code", code) + }) + + t.Run("receive error", func(t *testing.T) { + manager := NewOAuthCallbackManager() + manager.StartWaiting() + + testErr := assert.AnError + go func() { + time.Sleep(10 * time.Millisecond) + manager.HandleCallback("", testErr) + }() + + code, err := manager.WaitForCode() + assert.Error(t, err) + assert.Empty(t, code) + }) + + t.Run("timeout", func(t *testing.T) { + // 临时缩短超时时间用于测试 + originalTimeout := oauthTimeout + oauthTimeout = 100 * time.Millisecond + defer func() { + oauthTimeout = originalTimeout + }() + + manager := NewOAuthCallbackManager() + manager.StartWaiting() + + code, err := manager.WaitForCode() + assert.Error(t, err) + assert.Contains(t, err.Error(), "timeout") + assert.Empty(t, code) + }) +} + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := generateCodeVerifier() + assert.NoError(t, err) + assert.NotEmpty(t, verifier) + assert.GreaterOrEqual(t, len(verifier), 32) // base64编码后应该至少有32个字符 + + // 测试多次生成应该不同 + verifier2, err := generateCodeVerifier() + assert.NoError(t, err) + assert.NotEqual(t, verifier, verifier2) +} + +func TestGenerateCodeChallenge(t *testing.T) { + verifier := "test-verifier-string-for-pkce" + challenge := generateCodeChallenge(verifier) + + assert.NotEmpty(t, challenge) + assert.NotEqual(t, verifier, challenge) + + // 相同verifier应该生成相同challenge + challenge2 := generateCodeChallenge(verifier) + assert.Equal(t, challenge, challenge2) + + // 不同verifier应该生成不同challenge + challenge3 := generateCodeChallenge("different-verifier") + assert.NotEqual(t, challenge, challenge3) +} + +func TestBuildRedirectUri(t *testing.T) { + tests := []struct { + name string + host string + port int + expected string + }{ + { + name: "localhost", + host: "127.0.0.1", + port: 8088, + expected: "http://127.0.0.1:8088/callback", + }, + { + name: "all interfaces", + host: "0.0.0.0", + port: 9000, + expected: "http://0.0.0.0:9000/callback", + }, + { + name: "custom host", + host: "example.com", + port: 443, + expected: "http://example.com:443/callback", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildRedirectUri(tt.host, tt.port) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildOAuthURL(t *testing.T) { + host := "127.0.0.1" + port := 8088 + codeChallenge := "test-challenge" + scope := "/acs/mcp-server" + + // 测试CN区域 + urlCN := buildOAuthURL("test-app-id", RegionCN, host, port, codeChallenge, scope) + assert.Contains(t, urlCN, EndpointMap[RegionCN].SignIn) + assert.Contains(t, urlCN, "test-app-id") + assert.Contains(t, urlCN, "test-challenge") + assert.Contains(t, urlCN, "code_challenge_method=S256") + assert.Contains(t, urlCN, "response_type=code") + + // 测试INTL区域 + urlINTL := buildOAuthURL("test-app-id", RegionINTL, host, port, codeChallenge, scope) + assert.Contains(t, urlINTL, EndpointMap[RegionINTL].SignIn) + + // 验证URL可以解析 + parsedURL, err := url.Parse(urlCN) + assert.NoError(t, err) + assert.Equal(t, "https", parsedURL.Scheme) + assert.Equal(t, EndpointMap[RegionCN].SignIn, parsedURL.Scheme+"://"+parsedURL.Host) + + query := parsedURL.Query() + assert.Equal(t, "test-app-id", query.Get("client_id")) + assert.Equal(t, "code", query.Get("response_type")) + assert.Equal(t, codeChallenge, query.Get("code_challenge")) + assert.Equal(t, "S256", query.Get("code_challenge_method")) +} + +func TestHandleOAuthCallbackRequest(t *testing.T) { + tests := []struct { + name string + path string + code string + showCode bool + expectHandled bool + expectStatus int + }{ + { + name: "valid callback with code", + path: "/callback", + code: "test-code-123", + showCode: false, + expectHandled: true, + expectStatus: http.StatusOK, + }, + { + name: "valid callback with code (show code)", + path: "/callback", + code: "test-code-456", + showCode: true, + expectHandled: true, + expectStatus: http.StatusOK, + }, + { + name: "callback without code", + path: "/callback", + code: "", + showCode: false, + expectHandled: true, + expectStatus: http.StatusBadRequest, + }, + { + name: "wrong path", + path: "/other", + code: "test-code", + showCode: false, + expectHandled: false, + expectStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := NewOAuthCallbackManager() + manager.StartWaiting() + + req := httptest.NewRequest("GET", tt.path+"?code="+tt.code, nil) + w := httptest.NewRecorder() + + handled := handleOAuthCallbackRequest(w, req, manager.HandleCallback, tt.showCode) + assert.Equal(t, tt.expectHandled, handled) + assert.Equal(t, tt.expectStatus, w.Code) + + if tt.expectHandled && tt.code != "" { + select { + case code := <-manager.pendingAuth: + assert.Equal(t, tt.code, code) + case <-time.After(100 * time.Millisecond): + t.Error("code not received in channel") + } + } + }) + } +} + +func TestOAuthRefresh(t *testing.T) { + t.Run("successful refresh", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/v1/token", r.URL.Path) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer" + }`)) + })) + defer server.Close() + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", "old-refresh-token") + data.Set("client_id", "test-app-id") + + tokenResp, err := oauthRefresh(server.URL, data) + assert.NoError(t, err) + assert.NotNil(t, tokenResp) + assert.Equal(t, "new-access-token", tokenResp.AccessToken) + assert.Equal(t, "new-refresh-token", tokenResp.RefreshToken) + assert.Equal(t, int64(3600), tokenResp.ExpiresIn) + assert.Equal(t, "Bearer", tokenResp.TokenType) + }) + + t.Run("error response", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{ + "error": "invalid_grant", + "error_description": "Refresh token expired" + }`)) + })) + defer server.Close() + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("client_id", "test-app-id") + data.Set("refresh_token", "expired-token") + + tokenResp, err := oauthRefresh(server.URL, data) + assert.Error(t, err) + assert.Nil(t, tokenResp) + }) + + t.Run("non-200 status", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + data := url.Values{} + tokenResp, err := oauthRefresh(server.URL, data) + assert.Error(t, err) + assert.Contains(t, err.Error(), "status 500") + assert.Nil(t, tokenResp) + }) +} + +func TestExchangeCodeForTokenWithPKCE(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, "authorization_code", r.Form.Get("grant_type")) + assert.Equal(t, "test-code", r.Form.Get("code")) + assert.Equal(t, "test-app-id", r.Form.Get("client_id")) + assert.Equal(t, "http://127.0.0.1:8088/callback", r.Form.Get("redirect_uri")) + assert.Equal(t, "test-verifier", r.Form.Get("code_verifier")) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": 3600, + "token_type": "Bearer" + }`)) + })) + defer server.Close() + + tokenResult, err := exchangeCodeForTokenWithPKCE("test-app-id", "test-code", "test-verifier", "http://127.0.0.1:8088/callback", server.URL) + assert.NoError(t, err) + assert.Equal(t, "access-token", tokenResult.AccessToken) + assert.Equal(t, "refresh-token", tokenResult.RefreshToken) + assert.NotZero(t, tokenResult.AccessTokenExpire) +} + +func TestRegionTypeConstants(t *testing.T) { + assert.Equal(t, RegionType("CN"), RegionCN) + assert.Equal(t, RegionType("INTL"), RegionINTL) + assert.Equal(t, "default-mcp", DefaultMcpProfileName) +} + +func TestEndpointMap(t *testing.T) { + assert.NotNil(t, EndpointMap[RegionCN]) + assert.NotNil(t, EndpointMap[RegionINTL]) + + assert.NotEmpty(t, EndpointMap[RegionCN].SignIn) + assert.NotEmpty(t, EndpointMap[RegionCN].OAuth) + assert.NotEmpty(t, EndpointMap[RegionCN].IMS) + assert.NotEmpty(t, EndpointMap[RegionCN].MCP) + + assert.NotEmpty(t, EndpointMap[RegionINTL].SignIn) + assert.NotEmpty(t, EndpointMap[RegionINTL].OAuth) + assert.NotEmpty(t, EndpointMap[RegionINTL].IMS) + assert.NotEmpty(t, EndpointMap[RegionINTL].MCP) + + assert.True(t, strings.HasPrefix(EndpointMap[RegionCN].SignIn, "https://")) + assert.True(t, strings.HasPrefix(EndpointMap[RegionCN].OAuth, "https://")) + assert.True(t, strings.HasPrefix(EndpointMap[RegionINTL].SignIn, "https://")) + assert.True(t, strings.HasPrefix(EndpointMap[RegionINTL].OAuth, "https://")) +} diff --git a/util/util.go b/util/util.go index ff3249633..dd5ddcf94 100644 --- a/util/util.go +++ b/util/util.go @@ -12,6 +12,7 @@ import ( "runtime" "time" + "github.com/aliyun/aliyun-cli/v3/cli" "github.com/aliyun/aliyun-cli/v3/i18n" ) @@ -114,3 +115,11 @@ func CopyFileAndRemoveSource(sourceFile, destFile string) error { _ = os.Remove(sourceFile) return nil } + +func GetAliyunCliUserAgent() string { + ua := "Aliyun-CLI/" + cli.GetVersion() + if vendorEnv, ok := os.LookupEnv("ALIBABA_CLOUD_VENDOR"); ok { + ua += " vendor/" + vendorEnv + } + return ua +} diff --git a/util/util_test.go b/util/util_test.go index c8a199b49..7c6b45273 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -541,3 +541,93 @@ func TestCopyFileAndRemoveSource(t *testing.T) { assert.NoError(t, err) }) } + +func TestGetAliyunCliUserAgent(t *testing.T) { + // 保存原始环境变量 + originalVendor := os.Getenv("ALIBABA_CLOUD_VENDOR") + defer func() { + if originalVendor != "" { + os.Setenv("ALIBABA_CLOUD_VENDOR", originalVendor) + } else { + os.Unsetenv("ALIBABA_CLOUD_VENDOR") + } + }() + + t.Run("without vendor environment variable", func(t *testing.T) { + os.Unsetenv("ALIBABA_CLOUD_VENDOR") + + ua := GetAliyunCliUserAgent() + + assert.NotEmpty(t, ua) + assert.Contains(t, ua, "Aliyun-CLI/") + assert.NotContains(t, ua, "vendor/") + }) + + t.Run("with vendor environment variable", func(t *testing.T) { + testVendor := "test-vendor" + os.Setenv("ALIBABA_CLOUD_VENDOR", testVendor) + + ua := GetAliyunCliUserAgent() + + assert.NotEmpty(t, ua) + assert.Contains(t, ua, "Aliyun-CLI/") + assert.Contains(t, ua, "vendor/") + assert.Contains(t, ua, testVendor) + assert.True(t, strings.HasSuffix(ua, "vendor/"+testVendor)) + }) + + t.Run("with empty vendor environment variable", func(t *testing.T) { + os.Setenv("ALIBABA_CLOUD_VENDOR", "") + + ua := GetAliyunCliUserAgent() + + assert.NotEmpty(t, ua) + assert.Contains(t, ua, "Aliyun-CLI/") + // os.LookupEnv 返回 ok=true 即使值为空字符串,所以会包含 vendor/ + assert.Contains(t, ua, "vendor/") + assert.True(t, strings.HasSuffix(ua, "vendor/")) + }) + + t.Run("with different vendor values", func(t *testing.T) { + testCases := []string{ + "alibaba-cloud", + "aliyun", + "custom-vendor-123", + "vendor-with-special-chars-!@#", + } + + for _, vendor := range testCases { + t.Run("vendor_"+vendor, func(t *testing.T) { + os.Setenv("ALIBABA_CLOUD_VENDOR", vendor) + + ua := GetAliyunCliUserAgent() + + assert.Contains(t, ua, "Aliyun-CLI/") + assert.Contains(t, ua, "vendor/") + assert.Contains(t, ua, vendor) + assert.True(t, strings.HasSuffix(ua, "vendor/"+vendor)) + }) + } + }) + + t.Run("format consistency", func(t *testing.T) { + os.Setenv("ALIBABA_CLOUD_VENDOR", "test-vendor") + + ua := GetAliyunCliUserAgent() + + // 验证格式:Aliyun-CLI/ vendor/ + parts := strings.Split(ua, " ") + assert.Len(t, parts, 2) + assert.True(t, strings.HasPrefix(parts[0], "Aliyun-CLI/")) + assert.Equal(t, "vendor/test-vendor", parts[1]) + }) + + t.Run("multiple calls return consistent result", func(t *testing.T) { + os.Setenv("ALIBABA_CLOUD_VENDOR", "consistent-vendor") + + ua1 := GetAliyunCliUserAgent() + ua2 := GetAliyunCliUserAgent() + + assert.Equal(t, ua1, ua2) + }) +}