diff --git a/mcpproxy/mcp_profile.go b/mcpproxy/mcp_profile.go index bf0becbe0..c3e16c50b 100644 --- a/mcpproxy/mcp_profile.go +++ b/mcpproxy/mcp_profile.go @@ -64,6 +64,9 @@ func saveMcpProfile(profile *McpProfile) error { return fmt.Errorf("failed to create config directory %q: %w", dir, err) } + // log.Printf("saveMcpProfile: Before marshaling - RefreshToken length=%d, RefreshToken empty=%v", + // len(profile.MCPOAuthRefreshToken), profile.MCPOAuthRefreshToken == "") + tempFile := mcpConfigPath + ".tmp" bytes, err := json.MarshalIndent(profile, "", "\t") @@ -71,6 +74,11 @@ func saveMcpProfile(profile *McpProfile) error { return fmt.Errorf("failed to marshal profile: %w", err) } + // jsonStr := string(bytes) + // hasRefreshToken := strings.Contains(jsonStr, "mcp_oauth_refresh_token") + // log.Printf("saveMcpProfile: After marshaling - JSON contains 'mcp_oauth_refresh_token'=%v, JSON length=%d", + // hasRefreshToken, len(jsonStr)) + if err := os.WriteFile(tempFile, bytes, 0600); err != nil { return fmt.Errorf("failed to write temp file %q: %w", tempFile, err) } @@ -80,64 +88,121 @@ func saveMcpProfile(profile *McpProfile) error { _ = os.Remove(tempFile) return fmt.Errorf("failed to rename temp file to %q: %w", mcpConfigPath, err) } + + log.Printf("saveMcpProfile: Successfully saved MCP profile") return nil } +// loadExistingMCPProfile 加载并验证已有的 MCP profile,如果有效则返回,避免重复拉起 OAuth +func loadExistingMCPProfile(ctx *cli.Context, profile config.Profile, opts ProxyConfig, desiredAppName string) *McpProfile { + mcpConfigPath := getMCPConfigPath() + bytes, err := os.ReadFile(mcpConfigPath) + if err != nil { + return nil + } + mcpProfile, err := NewMcpProfileFromBytes(bytes) + if err != nil { + return nil + } + + if mcpProfile.MCPOAuthSiteType != string(opts.RegionType) { + log.Printf("Region type mismatch: saved=%s, requested=%s, ignoring local profile", mcpProfile.MCPOAuthSiteType, string(opts.RegionType)) + return nil + } + + if mcpProfile.MCPOAuthAppName != desiredAppName { + log.Printf("App name mismatch: saved=%s, requested=%s, ignoring local profile", mcpProfile.MCPOAuthAppName, desiredAppName) + return nil + } + + if mcpProfile.MCPOAuthAppId == "" { + log.Printf("MCP profile with AppId is empty, ignoring local profile") + return nil + } + + if mcpProfile.MCPOAuthRefreshToken == "" { + log.Printf("MCP profile with RefreshToken is empty, ignoring local profile") + return nil + } + + if mcpProfile.MCPOAuthRefreshTokenExpire <= util.GetCurrentUnixTime() { + log.Printf("MCP profile with RefreshTokenExpire is expired, ignoring local profile") + return nil + } + + app, err := findOAuthApplicationById(ctx, profile, mcpProfile.MCPOAuthAppId, opts.RegionType) + if err != nil { + log.Printf("Failed to reuse existing MCP profile (app: %s): %v, ignoring local profile", mcpProfile.MCPOAuthAppName, err) + return nil + } + if app == nil { + log.Printf("OAuth application with AppId '%s' not found, ignoring local profile", mcpProfile.MCPOAuthAppId) + return nil + } + + if err := validateOAuthApplication(app, opts.Scope, opts.Host, opts.Port); err != nil { + log.Printf("Reused existing MCP profile validation failed: %v, ignoring local profile", err) + return nil + } + + // 根据远端 app 信息更新 mcp profile 中的相关字段,其他字段(如 token)保持不变 + mcpProfile.MCPOAuthAppName = app.AppName + mcpProfile.MCPOAuthAppId = app.ApplicationId + mcpProfile.MCPOAuthAccessTokenValidity = app.AccessTokenValidity + mcpProfile.MCPOAuthRefreshTokenValidity = app.RefreshTokenValidity + + log.Printf("Reused existing MCP profile with app '%s' (AppId: %s)", app.AppName, app.ApplicationId) + + return mcpProfile +} + func getOrCreateMCPProfile(ctx *cli.Context, opts ProxyConfig) (*McpProfile, error) { profile, err := config.LoadProfileWithContext(ctx) if err != nil { return nil, fmt.Errorf("failed to load profile: %w", err) } - // 如果传入了 oauth-app-name,先验证该应用是否存在且合法 - // 如果已经验证过 oauth-app-name,直接使用验证过的 app;否则查找或创建默认的 OAuth 应用 - var validatedApp *OAuthApplication - if opts.OAuthAppName != "" { - app, err := findOAuthApplicationByName(ctx, profile, opts.RegionType, opts.OAuthAppName) - if err != nil { - return nil, fmt.Errorf("failed to find OAuth application '%s': %w", opts.OAuthAppName, err) - } - if app == nil { - return nil, fmt.Errorf("OAuth application '%s' not found", opts.OAuthAppName) - } + // 如果未显式指定 app name,则使用默认的 MCPOAuthAppName,便于复用本地 profile + desiredAppName := opts.OAuthAppName + if desiredAppName == "" { + desiredAppName = MCPOAuthAppName + } - // 验证 Scopes 和 Callback URI - requiredRedirectURI := buildRedirectUri(opts.Host, opts.Port) - if err := validateOAuthApplication(app, opts.Scope, requiredRedirectURI); err != nil { - return nil, fmt.Errorf("OAuth application validation failed: %w", err) + existingMcpProfile := loadExistingMCPProfile(ctx, profile, opts, desiredAppName) + if existingMcpProfile != nil { + // mcpprofile might change, save it again to ensure the latest state is saved + if err := saveMcpProfile(existingMcpProfile); err != nil { + return nil, fmt.Errorf("failed to save mcp profile: %w", err) } + return existingMcpProfile, nil + } - validatedApp = app - cli.Printf(ctx.Stdout(), "Using existing OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId) - } else { - // 查找或创建默认的 OAuth 应用 - mcpConfigPath := getMCPConfigPath() - if bytes, err := os.ReadFile(mcpConfigPath); err == nil { - if mcpProfile, err := NewMcpProfileFromBytes(bytes); err == nil { - log.Println("MCP Profile loaded from file", mcpProfile.Name, "app id", mcpProfile.MCPOAuthAppId) - - // 检查 region type 是否匹配,因为国内和国际站的 OAuth 地址不同, Region type 不匹配则重新创建 profile - if mcpProfile.MCPOAuthSiteType != string(opts.RegionType) { - log.Printf("Region type mismatch: saved=%s, requested=%s, recreating profile", mcpProfile.MCPOAuthSiteType, string(opts.RegionType)) - } else { - err = findOAuthApplicationById(ctx, profile, mcpProfile, opts.RegionType) - if err == nil { - return mcpProfile, nil - } else { - log.Println("Failed to find existing OAuth application", err.Error()) - } - } - } + app, err := findOAuthApplicationByName(ctx, profile, opts.RegionType, desiredAppName) + if err != nil { + return nil, fmt.Errorf("failed to find OAuth application '%s': %w", desiredAppName, err) + } + + if app == nil { + if opts.OAuthAppName != "" { + // if user provide app name, but not found, return error + return nil, fmt.Errorf("OAuth application '%s' not found", opts.OAuthAppName) } - app, err := getOrCreateMCPOAuthApplication(ctx, profile, opts.RegionType, opts.Host, opts.Port, opts.Scope) + cli.Printf(ctx.Stdout(), "Creating new default MCP profile '%s'...\n", DefaultMcpProfileName) + app, err = createDefaultMCPOauthApplication(ctx, profile, opts.RegionType, opts.Host, opts.Port, opts.Scope) if err != nil { - return nil, fmt.Errorf("failed to get or create OAuth application: %w", err) + return nil, fmt.Errorf("failed to create default OAuth application: %w", err) } - validatedApp = app + cli.Printf(ctx.Stdout(), "Created new default OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId) + } else { + cli.Printf(ctx.Stdout(), "Using existing OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId) } - cli.Printf(ctx.Stdout(), "Setting up MCPOAuth profile '%s'...\n", DefaultMcpProfileName) + if err := validateOAuthApplication(app, opts.Scope, opts.Host, opts.Port); err != nil { + return nil, fmt.Errorf("OAuth application validation failed: %w", err) + } + validatedApp := app + cli.Printf(ctx.Stdout(), "Setting up MCPOAuth profile '%s'...\n", DefaultMcpProfileName) mcpProfile := NewMcpProfile(DefaultMcpProfileName) mcpProfile.MCPOAuthSiteType = string(opts.RegionType) mcpProfile.MCPOAuthAppId = validatedApp.ApplicationId @@ -153,9 +218,20 @@ func getOrCreateMCPProfile(ctx *cli.Context, opts ProxyConfig) (*McpProfile, err if err != nil { return nil, fmt.Errorf("OAuth login failed: %w", err) } + + log.Printf("OAuth flow completed: AccessToken length=%d, RefreshToken length=%d, AccessTokenExpire=%d", + len(tokenResult.AccessToken), len(tokenResult.RefreshToken), tokenResult.AccessTokenExpire) + if tokenResult.RefreshToken == "" { + return nil, fmt.Errorf("OAuth flow returned empty RefreshToken (Region=%s, AppId=%s). "+ + "Please delete this application and let the system create a new NativeApp, or manually create a NativeApp", + opts.RegionType, mcpProfile.MCPOAuthAppId) + } + mcpProfile.MCPOAuthAccessToken = tokenResult.AccessToken mcpProfile.MCPOAuthRefreshToken = tokenResult.RefreshToken mcpProfile.MCPOAuthAccessTokenExpire = tokenResult.AccessTokenExpire + // refresh token will be updated each time latest access token is refreshed, + // however the validity and expiration time is the same as the original when finishing oauth flow mcpProfile.MCPOAuthRefreshTokenExpire = currentTime + int64(validatedApp.RefreshTokenValidity) if err = saveMcpProfile(mcpProfile); err != nil { diff --git a/mcpproxy/mcp_profile_test.go b/mcpproxy/mcp_profile_test.go index 205f8c025..9e9dda16b 100644 --- a/mcpproxy/mcp_profile_test.go +++ b/mcpproxy/mcp_profile_test.go @@ -15,11 +15,16 @@ package mcpproxy import ( + "bytes" "encoding/json" + "fmt" "os" "path/filepath" "testing" + "time" + "github.com/aliyun/aliyun-cli/v3/cli" + "github.com/aliyun/aliyun-cli/v3/config" "github.com/stretchr/testify/assert" ) @@ -288,3 +293,394 @@ func TestMcpProfileJSONSerialization(t *testing.T) { assert.Equal(t, profile.MCPOAuthSiteType, loadedProfile.MCPOAuthSiteType) assert.Equal(t, profile.MCPOAuthAppId, loadedProfile.MCPOAuthAppId) } + +func TestLoadExistingMCPProfile(t *testing.T) { + t.Run("config file not exists", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + profile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, profile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when config file does not exist") + }) + + t.Run("invalid json in config file", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建无效的配置文件 + configPath := getMCPConfigPath() + err := os.MkdirAll(filepath.Dir(configPath), 0755) + assert.NoError(t, err) + err = os.WriteFile(configPath, []byte("{invalid json}"), 0644) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + profile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, profile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when config file has invalid JSON") + }) + + t.Run("region type mismatch", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建配置文件,region type 为 CN + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + profile.MCPOAuthAppId = "test-app-id" + profile.MCPOAuthRefreshToken = "test-refresh-token" + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() + 3600 + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionINTL, // 请求 INTL,但保存的是 CN + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when region type mismatch") + }) + + t.Run("app name mismatch", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建配置文件,app name 为 "default-mcp" + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + profile.MCPOAuthAppId = "test-app-id" + profile.MCPOAuthRefreshToken = "test-refresh-token" + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() + 3600 + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + OAuthAppName: "different-app", // 请求不同的 app name + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "different-app") + assert.Nil(t, result, "should return nil when app name mismatch") + }) + + t.Run("empty AppId", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建配置文件,AppId 为空 + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + profile.MCPOAuthAppId = "" // 空的 AppId + profile.MCPOAuthRefreshToken = "test-refresh-token" + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() + 3600 + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when AppId is empty") + }) + + t.Run("empty RefreshToken", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建配置文件,RefreshToken 为空 + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + profile.MCPOAuthAppId = "test-app-id" + profile.MCPOAuthRefreshToken = "" // 空的 RefreshToken + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() + 3600 + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when RefreshToken is empty") + }) + + t.Run("expired RefreshToken", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建配置文件,RefreshToken 已过期 + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + profile.MCPOAuthAppId = "test-app-id" + profile.MCPOAuthRefreshToken = "test-refresh-token" + profile.MCPOAuthRefreshTokenExpire = time.Now().Unix() - 100 // 已过期 + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when RefreshToken is expired") + }) + + t.Run("valid profile but missing required fields", func(t *testing.T) { + tmpDir := t.TempDir() + originalHome := os.Getenv("HOME") + defer func() { + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + }() + os.Setenv("HOME", tmpDir) + + // 创建最小化的配置文件(缺少必要字段) + profile := NewMcpProfile("test-profile") + profile.MCPOAuthSiteType = "CN" + profile.MCPOAuthAppName = "default-mcp" + // 缺少 AppId 和 RefreshToken + err := saveMcpProfile(profile) + assert.NoError(t, err) + + ctx := cli.NewCommandContext(&bytes.Buffer{}, &bytes.Buffer{}) + configProfile := config.NewProfile("test") + opts := ProxyConfig{ + RegionType: RegionCN, + Host: "127.0.0.1", + Port: 8088, + Scope: "/acs/mcp-server", + } + + result := loadExistingMCPProfile(ctx, configProfile, opts, "default-mcp") + assert.Nil(t, result, "should return nil when required fields are missing") + }) +} + +func TestGetOrCreateMCPProfile_FindOAuthApplicationLogic(t *testing.T) { + + t.Run("validateOAuthApplication with nil app", func(t *testing.T) { + err := validateOAuthApplication(nil, "/acs/mcp-server", "127.0.0.1", 8088) + assert.Error(t, err) + assert.Contains(t, err.Error(), "OAuth application is nil") + }) + + t.Run("validateOAuthApplication with wrong AppType", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "WebApp", // 错误的类型 + Scopes: []string{"/acs/mcp-server"}, + RedirectUris: []string{"http://127.0.0.1:8088/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be 'NativeApp'") + assert.Contains(t, err.Error(), "WebApp") + }) + + t.Run("validateOAuthApplication with missing scope", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "NativeApp", + Scopes: []string{"/other/scope"}, // 缺少 required scope + RedirectUris: []string{"http://127.0.0.1:8088/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not have required scope") + assert.Contains(t, err.Error(), "/acs/mcp-server") + }) + + t.Run("validateOAuthApplication with wrong redirect URI", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "NativeApp", + Scopes: []string{"/acs/mcp-server"}, + RedirectUris: []string{"http://127.0.0.1:9999/callback"}, // 错误的端口 + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not have required redirect URI") + assert.Contains(t, err.Error(), "http://127.0.0.1:8088/callback") + }) + + t.Run("validateOAuthApplication with valid app", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "NativeApp", + Scopes: []string{"/acs/mcp-server"}, + RedirectUris: []string{"http://127.0.0.1:8088/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.NoError(t, err) + }) + + t.Run("validateOAuthApplication with multiple scopes", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "NativeApp", + Scopes: []string{"/other/scope", "/acs/mcp-server", "/another/scope"}, + RedirectUris: []string{"http://127.0.0.1:8088/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.NoError(t, err, "should pass validation when required scope is present among multiple scopes") + }) + + t.Run("validateOAuthApplication with multiple redirect URIs", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "NativeApp", + Scopes: []string{"/acs/mcp-server"}, + RedirectUris: []string{"http://0.0.0.0:8088/callback", "http://127.0.0.1:8088/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.NoError(t, err, "should pass validation when required redirect URI is present among multiple URIs") + }) + + t.Run("validateOAuthApplication error message format", func(t *testing.T) { + app := &OAuthApplication{ + AppName: "test-app", + ApplicationId: "app-id", + AppType: "WebApp", + Scopes: []string{"/other/scope"}, + RedirectUris: []string{"http://127.0.0.1:9999/callback"}, + AccessTokenValidity: 10800, + RefreshTokenValidity: 31536000, + } + + err := validateOAuthApplication(app, "/acs/mcp-server", "127.0.0.1", 8088) + assert.Error(t, err) + wrappedErr := fmt.Errorf("OAuth application validation failed: %w", err) + assert.Contains(t, wrappedErr.Error(), "validation failed") + assert.Contains(t, wrappedErr.Error(), "must be 'NativeApp'") + }) +} diff --git a/mcpproxy/mcp_server.go b/mcpproxy/mcp_server.go index 07fb25532..d3a7ac346 100644 --- a/mcpproxy/mcp_server.go +++ b/mcpproxy/mcp_server.go @@ -89,9 +89,11 @@ type RuntimeStats struct { ErrorRequests int64 ActiveRequests int64 - TokenRefreshes int64 - TokenRefreshErrors int64 - LastTokenRefresh int64 + TokenRefreshes int64 + TokenRefreshErrors int64 + LastTokenRefresh int64 + HealthCheckCounter int64 // 用于定期输出健康检查日志 + LastHealthCheckTime int64 // 启动时的内存状态 InitialMemStats runtime.MemStats @@ -691,6 +693,8 @@ func (r *TokenRefresher) Stop() { func (r *TokenRefresher) checkAndRefresh() { r.mu.RLock() currentTime := util.GetCurrentUnixTime() + accessTokenRemaining := r.profile.MCPOAuthAccessTokenExpire - currentTime + refreshTokenRemaining := r.profile.MCPOAuthRefreshTokenExpire - currentTime needRefresh := false needReauth := false @@ -704,14 +708,32 @@ func (r *TokenRefresher) checkAndRefresh() { } r.mu.RUnlock() + // 定期输出健康检查日志(每 120 次检查,即每小时) + checkCount := atomic.AddInt64(&r.stats.HealthCheckCounter, 1) + if checkCount%120 == 0 { + log.Printf("Token health check #%d: access_token_remaining=%dm (%.1fh), refresh_token_remaining=%dh (%.1fd), need_refresh=%v, need_reauth=%v", + checkCount, + accessTokenRemaining/60, + float64(accessTokenRemaining)/3600, + refreshTokenRemaining/3600, + float64(refreshTokenRemaining)/86400, + needRefresh, + needReauth) + atomic.StoreInt64(&r.stats.LastHealthCheckTime, currentTime) + } + if needReauth { + log.Printf("Token check: refresh token expiring soon (remaining: %dm), triggering re-authorization", refreshTokenRemaining/60) if err := r.reauthorizeWithProxy(); err != nil { - r.reportFatalError(fmt.Errorf("re-authorization failed: %v. Please restart aliyun mcp-proxy", err)) + log.Printf("MCP Proxy token refresher re-authorization failed: %v", err) + r.reportFatalError(fmt.Errorf("re-authorization failed: %w", err)) return } } else if needRefresh { + log.Printf("Token check: access token expiring soon (remaining: %dm), triggering refresh access token", accessTokenRemaining/60) if err := r.refreshAccessToken(); err != nil { - r.reportFatalError(fmt.Errorf("refresh access token failed. Please restart aliyun mcp-proxy")) + log.Printf("MCP Proxy token refresher refresh access token failed: %v", err) + r.reportFatalError(fmt.Errorf("refresh access token failed: %w", err)) return } } @@ -736,6 +758,7 @@ func (r *TokenRefresher) refreshAccessToken() error { endpoint := EndpointMap[r.regionType].OAuth clientId := r.profile.MCPOAuthAppId refreshToken := r.profile.MCPOAuthRefreshToken + accessTokenExpire := r.profile.MCPOAuthAccessTokenExpire r.mu.Unlock() // 执行网络请求(不持有锁,避免阻塞) @@ -743,21 +766,81 @@ func (r *TokenRefresher) refreshAccessToken() error { data.Set("grant_type", "refresh_token") data.Set("client_id", clientId) data.Set("refresh_token", refreshToken) - // fmt.Println("refresh access token data", data.Encode()) - // fmt.Println("refresh access token endpoint", endpoint) - // fmt.Println("refresh access token clientId", clientId) - // fmt.Println("refresh access token refreshToken", refreshToken) - newTokens, err := oauthRefresh(endpoint, data) + // 重试逻辑:最多重试 3 次,使用指数退避,避免因为临时网络问题导致服务直接关闭 + var newTokens *OAuthTokenResponse + var err error + maxRetries := 3 + successAttempt := 0 + + for attempt := 1; attempt <= maxRetries; attempt++ { + if attempt > 1 { + backoffDuration := time.Duration(1< 0 { + log.Printf("Temporary refresh failure, access token still valid for %ds, continuing to use current token (will retry on next check in 30s)", accessTimeRemaining) + return nil + } + + // Access token 已经失效,报告致命错误 + log.Printf("Access token expired and refresh failed, service is unavailable") + log.Printf("Reporting fatal error - service requires restart or re-authorization") + return fmt.Errorf("oauth refresh failed and access token expired: %w", err) } - log.Println("Access token refresh request successfully") + log.Printf("Access token refresh request successfully after %d attempt(s), access token length=%d, refresh token length=%d, new access token expires in %d seconds", + successAttempt, len(newTokens.AccessToken), len(newTokens.RefreshToken), newTokens.ExpiresIn) r.mu.Lock() currentTime := util.GetCurrentUnixTime() @@ -935,9 +1018,11 @@ func (r *TokenRefresher) reauthorizeWithProxy() error { r.reauthorizing = false r.mu.Unlock() atomic.AddInt64(&r.stats.TokenRefreshErrors, 1) - return err + log.Printf("OAuth re-authorization request failed: %v", err) + return fmt.Errorf("OAuth re-authorization failed: %w", err) } - log.Println("OAuth re-authorization request successfully") + log.Printf("OAuth re-authorization request successfully: AccessToken length=%d, RefreshToken length=%d, ExpiresIn=%d", + len(tokenResult.AccessToken), len(tokenResult.RefreshToken), refreshTokenValidity) r.mu.Lock() currentTime := util.GetCurrentUnixTime() diff --git a/mcpproxy/mcp_server_test.go b/mcpproxy/mcp_server_test.go index f77c86989..db304d63e 100644 --- a/mcpproxy/mcp_server_test.go +++ b/mcpproxy/mcp_server_test.go @@ -415,6 +415,48 @@ func TestGetContentFromApiResponse_Integration(t *testing.T) { assert.Contains(t, string(content), "value") } +func TestTokenRefresher_refreshAccessToken_PermanentError(t *testing.T) { + // 测试永久性错误时立即停止重试 + // 这里主要测试逻辑:永久性错误应该立即返回,不应该重试 + // 实际的网络调用测试在 oauth_app_test.go 中 + permanentErr := &OAuthPermanentError{ + StatusCode: 400, + ErrorCode: "invalid_grant", + Message: "OAuth permanent error: invalid_grant (status 400)", + } + + assert.True(t, IsPermanentError(permanentErr)) + + // 验证错误信息 + assert.Contains(t, permanentErr.Error(), "invalid_grant") + assert.Equal(t, 400, permanentErr.StatusCode) + assert.Equal(t, "invalid_grant", permanentErr.ErrorCode) +} + +func TestTokenRefresher_refreshAccessToken_AccessTokenStillValid(t *testing.T) { + // 测试当 access token 还有效时,临时错误应该返回 nil + currentTime := time.Now().Unix() + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessTokenExpire = currentTime + 3600 // access token 还有 1 小时有效 + + // 验证 access token 还有效时的逻辑 + accessTimeRemaining := profile.MCPOAuthAccessTokenExpire - currentTime + assert.Greater(t, accessTimeRemaining, int64(0), "access token should still be valid") + assert.GreaterOrEqual(t, accessTimeRemaining, int64(3600), "access token should have at least 1 hour remaining") +} + +func TestTokenRefresher_refreshAccessToken_AccessTokenExpired(t *testing.T) { + // 测试当 access token 已过期时,应该返回错误 + currentTime := time.Now().Unix() + profile := NewMcpProfile("test-profile") + profile.MCPOAuthAccessTokenExpire = currentTime - 100 // access token 已过期 + + // 验证 access token 已过期时的逻辑 + accessTimeRemaining := profile.MCPOAuthAccessTokenExpire - currentTime + assert.LessOrEqual(t, accessTimeRemaining, int64(0), "access token should be expired") + assert.Less(t, accessTimeRemaining, int64(0), "access token should be expired (negative remaining time)") +} + // 辅助函数 func getHomeEnv() string { return os.Getenv("HOME") diff --git a/mcpproxy/oauth_app.go b/mcpproxy/oauth_app.go index 51c096ab7..5e9bf08e5 100644 --- a/mcpproxy/oauth_app.go +++ b/mcpproxy/oauth_app.go @@ -43,6 +43,24 @@ import ( "github.com/alibabacloud-go/tea/tea" ) +// OAuthPermanentError 表示不可重试的 OAuth 错误(如 refresh token 失效、应用被删除等) +type OAuthPermanentError struct { + StatusCode int + ErrorCode string + Message string + Body string +} + +func (e *OAuthPermanentError) Error() string { + return e.Message +} + +// IsPermanentError 检查是否是永久性错误 +func IsPermanentError(err error) bool { + _, ok := err.(*OAuthPermanentError) + return ok +} + type RegionType string const ( @@ -453,36 +471,65 @@ func generateCodeChallenge(verifier string) string { } func oauthRefresh(endpoint string, data url.Values) (*OAuthTokenResponse, error) { - req, err := http.NewRequest("POST", endpoint+"/v1/token", strings.NewReader(data.Encode())) + fullURL := endpoint + "/v1/token" + log.Printf("OAuth refresh: attempting to refresh token at %s", fullURL) + + req, err := http.NewRequest("POST", fullURL, strings.NewReader(data.Encode())) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + startTime := time.Now() resp, err := util.NewHttpClient().Do(req) + duration := time.Since(startTime) + if err != nil { - return nil, err + log.Printf("OAuth refresh: HTTP request to %s failed after %v: %v", fullURL, duration, err) + return nil, fmt.Errorf("http request to %s failed after %v: %w", fullURL, duration, err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read response body: %w", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("refresh failed: status %d", resp.StatusCode) + log.Printf("OAuth refresh: received non-OK status %d after %v, body: %s", resp.StatusCode, duration, string(body)) + + // 检查是否是 OAuth 认证错误(不可重试的错误) + if resp.StatusCode == http.StatusBadRequest || resp.StatusCode == http.StatusUnauthorized { + var errorResp map[string]interface{} + if json.Unmarshal(body, &errorResp) == nil { + if errorCode, ok := errorResp["error"].(string); ok { + // 永久性错误,不应该重试 + if errorCode == "invalid_grant" || errorCode == "invalid_client" || + errorCode == "unauthorized_client" || errorCode == "invalid_token" { + return nil, &OAuthPermanentError{ + StatusCode: resp.StatusCode, + ErrorCode: errorCode, + Message: fmt.Sprintf("OAuth permanent error: %s (status %d)", errorCode, resp.StatusCode), + Body: string(body), + } + } + } + } + } + + return nil, fmt.Errorf("refresh failed: status %d, body: %s", resp.StatusCode, string(body)) } var tokenResp OAuthTokenResponse if err := json.Unmarshal(body, &tokenResp); err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse response JSON: %w", err) } if tokenResp.Error != "" { - return nil, fmt.Errorf("%s: %s", tokenResp.Error, tokenResp.ErrorDescription) + return nil, fmt.Errorf("oauth error: %s: %s", tokenResp.Error, tokenResp.ErrorDescription) } + log.Printf("OAuth refresh: token refreshed successfully in %v", duration) return &tokenResp, nil } @@ -508,12 +555,21 @@ func exchangeCodeForTokenWithPKCE(clientId, code, codeVerifier, redirectURI, oau } log.Println("Exchange code for token with PKCE successfully") + log.Printf("OAuth token response for authorization code exchange: AccessToken length=%d, RefreshToken length=%d, ExpiresIn=%d", + len(tokenResp.AccessToken), len(tokenResp.RefreshToken), tokenResp.ExpiresIn) + if tokenResp.RefreshToken == "" { + log.Printf("WARNING: OAuth response has empty RefreshToken! TokenType=%s, Error=%s", + tokenResp.TokenType, tokenResp.Error) + } + currentTime := util.GetCurrentUnixTime() - return &OAuthTokenResult{ + tokenResult := &OAuthTokenResult{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, AccessTokenExpire: currentTime + tokenResp.ExpiresIn, - }, nil + } + + return tokenResult, nil } func buildOAuthURL(clientId string, region RegionType, host string, port int, codeChallenge string, scope string) string { @@ -791,23 +847,11 @@ func newOpenAPIClient(ctx *cli.Context, profile config.Profile, endpoint string) return client, nil } -func getOrCreateMCPOAuthApplication(ctx *cli.Context, profile config.Profile, region RegionType, host string, port int, scope string) (*OAuthApplication, error) { - app, err := findOAuthApplicationByName(ctx, profile, region, MCPOAuthAppName) - if err != nil { - return nil, err - } - - if app != nil { - return app, nil - } - - return createMCPOauthApplication(ctx, profile, region, host, port, scope) -} - -func findOAuthApplicationById(ctx *cli.Context, profile config.Profile, mcpProfile *McpProfile, region RegionType) error { +// err means process error, application not found is not error, whether app exists based on the first return value +func findOAuthApplicationById(ctx *cli.Context, profile config.Profile, appId string, region RegionType) (*OAuthApplication, error) { client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) if err != nil { - return err + return nil, err } params := &openapiClient.Params{ Action: tea.String("GetApplication"), @@ -821,24 +865,41 @@ func findOAuthApplicationById(ctx *cli.Context, profile config.Profile, mcpProfi runtime := &dara.RuntimeOptions{} request := &openapiutil.OpenApiRequest{ Query: map[string]*string{ - "AppId": tea.String(mcpProfile.MCPOAuthAppId), + "AppId": tea.String(appId), }, } response, err := client.CallApi(params, request, runtime) if err != nil { - return err + return nil, err } bodyBytes, err := GetContentFromApiResponse(response) if err != nil { - return fmt.Errorf("failed to get content from api response: %w", err) + return nil, fmt.Errorf("failed to get content from api response: %w", err) } var responseGet GetApplicationResponse if err := json.Unmarshal(bodyBytes, &responseGet); err != nil { - return err + return nil, err } - return nil + + app := responseGet.Application + scopes := make([]string, 0, len(app.DelegatedScope.PredefinedScopes.PredefinedScope)) + for _, s := range app.DelegatedScope.PredefinedScopes.PredefinedScope { + scopes = append(scopes, s.Name) + } + + return &OAuthApplication{ + ApplicationId: app.AppId, + AppName: app.AppName, + DisplayName: app.DisplayName, + AppType: app.AppType, + RedirectUris: app.RedirectUris.RedirectUri, + Scopes: scopes, + AccessTokenValidity: app.AccessTokenValidity, + RefreshTokenValidity: app.RefreshTokenValidity, + }, nil } +// err means process error, application not found is not error, whether app exists based on the first return value func findOAuthApplicationByName(ctx *cli.Context, profile config.Profile, region RegionType, appName string) (*OAuthApplication, error) { client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) if err != nil { @@ -892,12 +953,22 @@ func findOAuthApplicationByName(ctx *cli.Context, profile config.Profile, region return nil, nil } -// validateOAuthApplication 验证 OAuth 应用的 Scopes 和 Callback URI 是否符合要求 -func validateOAuthApplication(app *OAuthApplication, requiredScope string, requiredRedirectURI string) error { +// validateOAuthApplication 验证 OAuth 应用的 Scopes, AppType 和 Callback URI 是否符合要求 +func validateOAuthApplication(app *OAuthApplication, requiredScope string, host string, port int) error { if app == nil { return fmt.Errorf("OAuth application is nil") } + log.Printf("Validating OAuth application: Name=%s, AppType=%s, AccessTokenValidity=%d, RefreshTokenValidity=%d", + app.AppName, app.AppType, app.AccessTokenValidity, app.RefreshTokenValidity) + + if app.AppType != "NativeApp" { + log.Printf("WARNING: OAuth application type is '%s', not 'NativeApp', refresh token is not supported!", + app.AppType) + return fmt.Errorf("OAuth application type is '%s', must be 'NativeApp' to get refresh token. "+ + "Please delete this application and let the system create a new NativeApp, or manually create a NativeApp", app.AppType) + } + // 验证 Scopes scopeFound := false for _, scope := range app.Scopes { @@ -911,7 +982,7 @@ func validateOAuthApplication(app *OAuthApplication, requiredScope string, requi app.AppName, requiredScope, app.Scopes) } - // 验证 Callback URI + requiredRedirectURI := buildRedirectUri(host, port) redirectURIFound := false for _, uri := range app.RedirectUris { if uri == requiredRedirectURI { @@ -931,7 +1002,7 @@ func buildRedirectUri(host string, port int) string { return fmt.Sprintf("http://%s:%d/callback", host, port) } -func createMCPOauthApplication(ctx *cli.Context, profile config.Profile, region RegionType, host string, port int, scope string) (*OAuthApplication, error) { +func createDefaultMCPOauthApplication(ctx *cli.Context, profile config.Profile, region RegionType, host string, port int, scope string) (*OAuthApplication, error) { client, err := newOpenAPIClient(ctx, profile, EndpointMap[region].IMS) if err != nil { return nil, err diff --git a/mcpproxy/oauth_app_test.go b/mcpproxy/oauth_app_test.go index 38a3ee5d1..0b2cf7882 100644 --- a/mcpproxy/oauth_app_test.go +++ b/mcpproxy/oauth_app_test.go @@ -445,3 +445,117 @@ func TestEndpointMap(t *testing.T) { assert.True(t, strings.HasPrefix(EndpointMap[RegionINTL].SignIn, "https://")) assert.True(t, strings.HasPrefix(EndpointMap[RegionINTL].OAuth, "https://")) } + +func TestOAuthPermanentError(t *testing.T) { + t.Run("error message", func(t *testing.T) { + err := &OAuthPermanentError{ + StatusCode: 400, + ErrorCode: "invalid_grant", + Message: "OAuth permanent error: invalid_grant (status 400)", + Body: `{"error":"invalid_grant"}`, + } + assert.Equal(t, "OAuth permanent error: invalid_grant (status 400)", err.Error()) + }) + + t.Run("IsPermanentError with OAuthPermanentError", func(t *testing.T) { + err := &OAuthPermanentError{ + StatusCode: 400, + ErrorCode: "invalid_grant", + Message: "test error", + } + assert.True(t, IsPermanentError(err)) + }) + + t.Run("IsPermanentError with regular error", func(t *testing.T) { + err := assert.AnError + assert.False(t, IsPermanentError(err)) + }) + + t.Run("IsPermanentError with nil", func(t *testing.T) { + assert.False(t, IsPermanentError(nil)) + }) + + t.Run("IsPermanentError with wrapped error", func(t *testing.T) { + permanentErr := &OAuthPermanentError{ + StatusCode: 400, + ErrorCode: "invalid_client", + Message: "test error", + } + wrappedErr := assert.AnError + assert.False(t, IsPermanentError(wrappedErr)) + assert.True(t, IsPermanentError(permanentErr)) + }) +} + +func TestOAuthRefresh_PermanentErrors(t *testing.T) { + permanentErrorCodes := []string{"invalid_grant", "invalid_client", "unauthorized_client", "invalid_token"} + + for _, errorCode := range permanentErrorCodes { + t.Run("permanent error: "+errorCode, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{ + "error": "` + errorCode + `", + "error_description": "Test error" + }`)) + })) + defer server.Close() + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("client_id", "test-app-id") + data.Set("refresh_token", "test-token") + + tokenResp, err := oauthRefresh(server.URL, data) + assert.Error(t, err) + assert.Nil(t, tokenResp) + assert.True(t, IsPermanentError(err), "error should be permanent: %v", err) + + permanentErr, ok := err.(*OAuthPermanentError) + assert.True(t, ok) + assert.Equal(t, http.StatusBadRequest, permanentErr.StatusCode) + assert.Equal(t, errorCode, permanentErr.ErrorCode) + }) + } + + t.Run("temporary error (500)", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) + })) + defer server.Close() + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("client_id", "test-app-id") + data.Set("refresh_token", "test-token") + + tokenResp, err := oauthRefresh(server.URL, data) + assert.Error(t, err) + assert.Nil(t, tokenResp) + assert.False(t, IsPermanentError(err), "500 error should not be permanent") + }) + + t.Run("non-permanent 400 error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{ + "error": "invalid_request", + "error_description": "Missing parameter" + }`)) + })) + defer server.Close() + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("client_id", "test-app-id") + data.Set("refresh_token", "test-token") + + tokenResp, err := oauthRefresh(server.URL, data) + assert.Error(t, err) + assert.Nil(t, tokenResp) + assert.False(t, IsPermanentError(err), "invalid_request should not be permanent") + }) +}