Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 116 additions & 40 deletions mcpproxy/mcp_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,21 @@ func saveMcpProfile(profile *McpProfile) error {
return fmt.Errorf("failed to create config directory %q: %w", dir, err)
}

// log.Printf("saveMcpProfile: Before marshaling - RefreshToken length=%d, RefreshToken empty=%v",
// len(profile.MCPOAuthRefreshToken), profile.MCPOAuthRefreshToken == "")

tempFile := mcpConfigPath + ".tmp"

bytes, err := json.MarshalIndent(profile, "", "\t")
if err != nil {
return fmt.Errorf("failed to marshal profile: %w", err)
}

// jsonStr := string(bytes)
// hasRefreshToken := strings.Contains(jsonStr, "mcp_oauth_refresh_token")
// log.Printf("saveMcpProfile: After marshaling - JSON contains 'mcp_oauth_refresh_token'=%v, JSON length=%d",
// hasRefreshToken, len(jsonStr))

if err := os.WriteFile(tempFile, bytes, 0600); err != nil {
return fmt.Errorf("failed to write temp file %q: %w", tempFile, err)
}
Expand All @@ -80,64 +88,121 @@ func saveMcpProfile(profile *McpProfile) error {
_ = os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file to %q: %w", mcpConfigPath, err)
}

log.Printf("saveMcpProfile: Successfully saved MCP profile")
return nil
}

// loadExistingMCPProfile 加载并验证已有的 MCP profile,如果有效则返回,避免重复拉起 OAuth
func loadExistingMCPProfile(ctx *cli.Context, profile config.Profile, opts ProxyConfig, desiredAppName string) *McpProfile {
mcpConfigPath := getMCPConfigPath()
bytes, err := os.ReadFile(mcpConfigPath)
if err != nil {
return nil
}
mcpProfile, err := NewMcpProfileFromBytes(bytes)
if err != nil {
return nil
}

if mcpProfile.MCPOAuthSiteType != string(opts.RegionType) {
log.Printf("Region type mismatch: saved=%s, requested=%s, ignoring local profile", mcpProfile.MCPOAuthSiteType, string(opts.RegionType))
return nil
}

if mcpProfile.MCPOAuthAppName != desiredAppName {
log.Printf("App name mismatch: saved=%s, requested=%s, ignoring local profile", mcpProfile.MCPOAuthAppName, desiredAppName)
return nil
}

if mcpProfile.MCPOAuthAppId == "" {
log.Printf("MCP profile with AppId is empty, ignoring local profile")
return nil
}

if mcpProfile.MCPOAuthRefreshToken == "" {
log.Printf("MCP profile with RefreshToken is empty, ignoring local profile")
return nil
}

if mcpProfile.MCPOAuthRefreshTokenExpire <= util.GetCurrentUnixTime() {
log.Printf("MCP profile with RefreshTokenExpire is expired, ignoring local profile")
return nil
}

app, err := findOAuthApplicationById(ctx, profile, mcpProfile.MCPOAuthAppId, opts.RegionType)
if err != nil {
log.Printf("Failed to reuse existing MCP profile (app: %s): %v, ignoring local profile", mcpProfile.MCPOAuthAppName, err)
return nil
}
if app == nil {
log.Printf("OAuth application with AppId '%s' not found, ignoring local profile", mcpProfile.MCPOAuthAppId)
return nil
}

if err := validateOAuthApplication(app, opts.Scope, opts.Host, opts.Port); err != nil {
log.Printf("Reused existing MCP profile validation failed: %v, ignoring local profile", err)
return nil
}

// 根据远端 app 信息更新 mcp profile 中的相关字段,其他字段(如 token)保持不变
mcpProfile.MCPOAuthAppName = app.AppName
mcpProfile.MCPOAuthAppId = app.ApplicationId
mcpProfile.MCPOAuthAccessTokenValidity = app.AccessTokenValidity
mcpProfile.MCPOAuthRefreshTokenValidity = app.RefreshTokenValidity

log.Printf("Reused existing MCP profile with app '%s' (AppId: %s)", app.AppName, app.ApplicationId)

return mcpProfile
}

func getOrCreateMCPProfile(ctx *cli.Context, opts ProxyConfig) (*McpProfile, error) {
profile, err := config.LoadProfileWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load profile: %w", err)
}

// 如果传入了 oauth-app-name,先验证该应用是否存在且合法
// 如果已经验证过 oauth-app-name,直接使用验证过的 app;否则查找或创建默认的 OAuth 应用
var validatedApp *OAuthApplication
if opts.OAuthAppName != "" {
app, err := findOAuthApplicationByName(ctx, profile, opts.RegionType, opts.OAuthAppName)
if err != nil {
return nil, fmt.Errorf("failed to find OAuth application '%s': %w", opts.OAuthAppName, err)
}
if app == nil {
return nil, fmt.Errorf("OAuth application '%s' not found", opts.OAuthAppName)
}
// 如果未显式指定 app name,则使用默认的 MCPOAuthAppName,便于复用本地 profile
desiredAppName := opts.OAuthAppName
if desiredAppName == "" {
desiredAppName = MCPOAuthAppName
}

// 验证 Scopes 和 Callback URI
requiredRedirectURI := buildRedirectUri(opts.Host, opts.Port)
if err := validateOAuthApplication(app, opts.Scope, requiredRedirectURI); err != nil {
return nil, fmt.Errorf("OAuth application validation failed: %w", err)
existingMcpProfile := loadExistingMCPProfile(ctx, profile, opts, desiredAppName)
if existingMcpProfile != nil {
// mcpprofile might change, save it again to ensure the latest state is saved
if err := saveMcpProfile(existingMcpProfile); err != nil {
return nil, fmt.Errorf("failed to save mcp profile: %w", err)
}
return existingMcpProfile, nil
}

validatedApp = app
cli.Printf(ctx.Stdout(), "Using existing OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId)
} else {
// 查找或创建默认的 OAuth 应用
mcpConfigPath := getMCPConfigPath()
if bytes, err := os.ReadFile(mcpConfigPath); err == nil {
if mcpProfile, err := NewMcpProfileFromBytes(bytes); err == nil {
log.Println("MCP Profile loaded from file", mcpProfile.Name, "app id", mcpProfile.MCPOAuthAppId)

// 检查 region type 是否匹配,因为国内和国际站的 OAuth 地址不同, Region type 不匹配则重新创建 profile
if mcpProfile.MCPOAuthSiteType != string(opts.RegionType) {
log.Printf("Region type mismatch: saved=%s, requested=%s, recreating profile", mcpProfile.MCPOAuthSiteType, string(opts.RegionType))
} else {
err = findOAuthApplicationById(ctx, profile, mcpProfile, opts.RegionType)
if err == nil {
return mcpProfile, nil
} else {
log.Println("Failed to find existing OAuth application", err.Error())
}
}
}
app, err := findOAuthApplicationByName(ctx, profile, opts.RegionType, desiredAppName)
if err != nil {
return nil, fmt.Errorf("failed to find OAuth application '%s': %w", desiredAppName, err)
}

if app == nil {
if opts.OAuthAppName != "" {
// if user provide app name, but not found, return error
return nil, fmt.Errorf("OAuth application '%s' not found", opts.OAuthAppName)
}
app, err := getOrCreateMCPOAuthApplication(ctx, profile, opts.RegionType, opts.Host, opts.Port, opts.Scope)
cli.Printf(ctx.Stdout(), "Creating new default MCP profile '%s'...\n", DefaultMcpProfileName)
app, err = createDefaultMCPOauthApplication(ctx, profile, opts.RegionType, opts.Host, opts.Port, opts.Scope)
if err != nil {
return nil, fmt.Errorf("failed to get or create OAuth application: %w", err)
return nil, fmt.Errorf("failed to create default OAuth application: %w", err)
}
validatedApp = app
cli.Printf(ctx.Stdout(), "Created new default OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId)
} else {
cli.Printf(ctx.Stdout(), "Using existing OAuth application '%s' (AppId: %s)\n", app.AppName, app.ApplicationId)
}

cli.Printf(ctx.Stdout(), "Setting up MCPOAuth profile '%s'...\n", DefaultMcpProfileName)
if err := validateOAuthApplication(app, opts.Scope, opts.Host, opts.Port); err != nil {
return nil, fmt.Errorf("OAuth application validation failed: %w", err)
}
validatedApp := app

cli.Printf(ctx.Stdout(), "Setting up MCPOAuth profile '%s'...\n", DefaultMcpProfileName)
mcpProfile := NewMcpProfile(DefaultMcpProfileName)
mcpProfile.MCPOAuthSiteType = string(opts.RegionType)
mcpProfile.MCPOAuthAppId = validatedApp.ApplicationId
Expand All @@ -153,9 +218,20 @@ func getOrCreateMCPProfile(ctx *cli.Context, opts ProxyConfig) (*McpProfile, err
if err != nil {
return nil, fmt.Errorf("OAuth login failed: %w", err)
}

log.Printf("OAuth flow completed: AccessToken length=%d, RefreshToken length=%d, AccessTokenExpire=%d",
len(tokenResult.AccessToken), len(tokenResult.RefreshToken), tokenResult.AccessTokenExpire)
if tokenResult.RefreshToken == "" {
return nil, fmt.Errorf("OAuth flow returned empty RefreshToken (Region=%s, AppId=%s). "+
"Please delete this application and let the system create a new NativeApp, or manually create a NativeApp",
opts.RegionType, mcpProfile.MCPOAuthAppId)
}

mcpProfile.MCPOAuthAccessToken = tokenResult.AccessToken
mcpProfile.MCPOAuthRefreshToken = tokenResult.RefreshToken
mcpProfile.MCPOAuthAccessTokenExpire = tokenResult.AccessTokenExpire
// refresh token will be updated each time latest access token is refreshed,
// however the validity and expiration time is the same as the original when finishing oauth flow
mcpProfile.MCPOAuthRefreshTokenExpire = currentTime + int64(validatedApp.RefreshTokenValidity)

if err = saveMcpProfile(mcpProfile); err != nil {
Expand Down
Loading
Loading