diff --git a/ai-assistant/src/langchain/LangChainManager.ts b/ai-assistant/src/langchain/LangChainManager.ts
index ceddb1bb69..bbb47a4805 100644
--- a/ai-assistant/src/langchain/LangChainManager.ts
+++ b/ai-assistant/src/langchain/LangChainManager.ts
@@ -103,69 +103,83 @@ export default class LangChainManager extends AIManager {
   }
 
   private createModel(providerId: string, config: Record): BaseChatModel {
+    const sanitize = (value: any) => (typeof value === 'string' ? value.trim() : value);
+    const sanitizedConfig = {
+      ...config,
+      apiKey: sanitize(config.apiKey),
+      endpoint: sanitize(config.endpoint),
+      baseUrl: sanitize(config.baseUrl),
+      deploymentName: sanitize(config.deploymentName),
+      model: sanitize(config.model),
+    };
+
     try {
       switch (providerId) {
         case 'openai':
-          if (!config.apiKey) {
+          if (!sanitizedConfig.apiKey) {
            throw new Error('API key is required for OpenAI');
           }
           return new ChatOpenAI({
-            apiKey: config.apiKey,
-            modelName: config.model,
+            apiKey: sanitizedConfig.apiKey,
+            modelName: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
         case 'azure':
-          if (!config.apiKey || !config.endpoint || !config.deploymentName) {
+          if (
+            !sanitizedConfig.apiKey ||
+            !sanitizedConfig.endpoint ||
+            !sanitizedConfig.deploymentName
+          ) {
             throw new Error('Incomplete Azure OpenAI configuration');
           }
           return new AzureChatOpenAI({
-            azureOpenAIEndpoint: config.endpoint.replace(/\/+\$/, ''),
-            azureOpenAIApiKey: config.apiKey,
-            azureOpenAIApiDeploymentName: config.deploymentName,
+            azureOpenAIEndpoint: sanitizedConfig.endpoint.replace(/\/+$/, ''),
+            azureOpenAIApiKey: sanitizedConfig.apiKey,
+            azureOpenAIApiDeploymentName: sanitizedConfig.deploymentName,
             azureOpenAIApiVersion: '2024-12-01-preview',
-            modelName: config.model,
+            modelName: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
         case 'anthropic':
-          if (!config.apiKey) {
+          if (!sanitizedConfig.apiKey) {
             throw new Error('API key is required for Anthropic');
           }
           return new ChatAnthropic({
-            apiKey: config.apiKey,
-            modelName: config.model,
+            apiKey: sanitizedConfig.apiKey,
+            modelName: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
         case 'mistral':
-          if (!config.apiKey) {
+          if (!sanitizedConfig.apiKey) {
             throw new Error('API key is required for Mistral AI');
           }
           return new ChatMistralAI({
-            apiKey: config.apiKey,
-            modelName: config.model,
+            apiKey: sanitizedConfig.apiKey,
+            modelName: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
         case 'gemini': {
-          if (!config.apiKey) {
+          if (!sanitizedConfig.apiKey) {
             throw new Error('API key is required for Google Gemini');
           }
           return new ChatGoogleGenerativeAI({
-            apiKey: config.apiKey,
-            model: config.model,
+            apiKey: sanitizedConfig.apiKey,
+            model: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
         }
         case 'local':
-          if (!config.baseUrl) {
+          if (!sanitizedConfig.baseUrl) {
             throw new Error('Base URL is required for local models');
           }
           return new ChatOllama({
-            baseUrl: config.baseUrl,
-            model: config.model,
+            baseUrl: sanitizedConfig.baseUrl,
+            model: sanitizedConfig.model,
             dangerouslyAllowBrowser: true,
             verbose: true,
           });
diff --git a/ai-assistant/src/modal.tsx b/ai-assistant/src/modal.tsx
index 1882f20819..3349b72bb6 100644
--- a/ai-assistant/src/modal.tsx
+++ b/ai-assistant/src/modal.tsx
@@ -142,7 +142,7 @@ export default function AIPrompt(props: {
     if (active) {
       setActiveConfig(active);
       // Set the default model for the active provider
-      const defaultModel = getProviderModels(active)[0] || 'default';
+      const defaultModel = resolveSelectedModel(active);
       setSelectedModel(defaultModel);
 
       // Update global state with all providers and active one
@@ -178,7 +178,7 @@ export default function AIPrompt(props: {
     _pluginSetting.setActiveProvider(newActive);
 
     // Set the default model for the new provider
-    const defaultModel = getProviderModels(newActive)[0] || 'default';
+    const defaultModel = resolveSelectedModel(newActive);
     setSelectedModel(defaultModel);
 
     // Clear history and show provider change message
@@ -224,6 +224,20 @@ export default function AIPrompt(props: {
   // Handle changing the active configuration
   const [selectedModel, setSelectedModel] = useState('default');
 
+  const resolveSelectedModel = (config: StoredProviderConfig, explicitModel?: string) => {
+    if (explicitModel && explicitModel.trim().length > 0) {
+      return explicitModel;
+    }
+
+    const savedModel = config.config?.model;
+    if (savedModel && savedModel.trim().length > 0) {
+      return savedModel;
+    }
+
+    const providerModels = getProviderModels(config);
+    return providerModels[0] || 'default';
+  };
+
   const handleChangeConfig = (config: StoredProviderConfig, model?: string) => {
     if (!config) return;
     if (
@@ -233,11 +247,9 @@
       JSON.stringify(activeConfig.config) !== JSON.stringify(config.config) ||
       selectedModel !== model
     ) {
-      setPromptHistory([]);
-      setPromptVal('');
       setApiError(null);
       setActiveConfig(config);
-      setSelectedModel(model || getProviderModels(config)[0] || 'default');
+      setSelectedModel(resolveSelectedModel(config, model));
       _pluginSetting.setActiveProvider(config);
       if (aiManager) {
         aiManager.reset();
@@ -245,19 +257,19 @@
       setTimeout(() => {
         const providerName =
           config.displayName || getProviderById(config.providerId)?.name || config.providerId;
-        setPromptHistory([
+        setPromptHistory(prev => [
+          ...prev,
           {
             role: 'system',
-            content: `Switched to ${providerName}${
-              model ? ' / ' + model : ''
-            }. History has been cleared.`,
+            content: `Switched to ${providerName}${model ? ' / ' + model : ''}.`,
           },
         ]);
       }, 100);
     } else {
       const providerName =
         config.displayName || getProviderById(config.providerId)?.name || config.providerId;
-      setPromptHistory([
+      setPromptHistory(prev => [
+        ...prev,
         {
           role: 'system',
           content: `Using ${providerName}${model ? ' / ' + model : ''}.`,