diff --git a/backend/app/models/subtask.py b/backend/app/models/subtask.py index 141e4baf..02f29faa 100644 --- a/backend/app/models/subtask.py +++ b/backend/app/models/subtask.py @@ -43,6 +43,7 @@ class Subtask(Base): progress = Column(Integer, nullable=False, default=0) result = Column(JSON) error_message = Column(Text) + override_model_config = Column(JSON, nullable=True) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) completed_at = Column(DateTime) diff --git a/backend/app/schemas/task.py b/backend/app/schemas/task.py index 74b6a134..60fe9a0a 100644 --- a/backend/app/schemas/task.py +++ b/backend/app/schemas/task.py @@ -53,6 +53,7 @@ class TaskCreate(BaseModel): task_type: Optional[str] = "chat" # chat、code auto_delete_executor: Optional[str] = "false" # true、fasle source: Optional[str] = "web" + override_model: Optional[str] = None class TaskUpdate(BaseModel): diff --git a/backend/app/services/adapters/task_kinds.py b/backend/app/services/adapters/task_kinds.py index 23b189c5..7654c583 100644 --- a/backend/app/services/adapters/task_kinds.py +++ b/backend/app/services/adapters/task_kinds.py @@ -230,7 +230,7 @@ def create_task_or_append( db.add(task) # Create subtasks for the task - self._create_subtasks(db, task, team, user.id, obj_in.prompt) + self._create_subtasks(db, task, team, user.id, obj_in.prompt, obj_in.override_model) db.commit() db.refresh(task) @@ -436,6 +436,17 @@ def get_task_detail( # Convert subtasks to dict and replace bot_ids with bot objects subtasks_dict = [] for subtask in subtasks: + # Apply override_model_config to bot configs if present + subtask_bots = [] + for bot_id in subtask.bot_ids: + if bot_id in bots: + bot_dict = bots[bot_id].copy() + # If subtask has override_model_config, use it to override agent_config + if subtask.override_model_config: + bot_dict["agent_config"] = subtask.override_model_config + logger.info(f"Applied override_model_config to bot {bot_id} in subtask {subtask.id}: {subtask.override_model_config}") + subtask_bots.append(bot_dict) + # Convert subtask to dict subtask_dict = { # Subtask base fields @@ -458,8 +469,8 @@ def get_task_detail( "created_at": subtask.created_at, "updated_at": subtask.updated_at, "completed_at": subtask.completed_at, - # Add bot objects as dict for each bot_id - "bots": [bots.get(bot_id) for bot_id in subtask.bot_ids if bot_id in bots] + # Add bot objects with potentially overridden config + "bots": subtask_bots } subtasks_dict.append(subtask_dict) @@ -832,11 +843,11 @@ def _convert_team_to_dict(self, team: Kind, db: Session, user_id: int) -> Dict[s "updated_at": team.updated_at, } - def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, user_prompt: str) -> None: + def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, user_prompt: str, override_model: Optional[str] = None) -> None: """ Create subtasks based on team's workflow configuration """ - logger.info(f"_create_subtasks called with task_id={task.id}, team_id={team.id}, user_id={user_id}") + logger.info(f"_create_subtasks called with task_id={task.id}, team_id={team.id}, user_id={user_id}, override_model={override_model}") team_crd = Team.model_validate(team.json) task_crd = Task.model_validate(task.json) @@ -902,6 +913,12 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us ) db.add(user_subtask) + # Prepare override_model_config if override_model is provided + override_model_config = None + if override_model: + override_model_config = {"private_model": override_model} + logger.info(f"Using override_model_config: {override_model_config}") + # Update id of next message and parent if parent_id == 0: parent_id = 1 @@ -944,6 +961,7 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us error_message="", completed_at=datetime.now(), result=None, + override_model_config=override_model_config, ) # Update id of next message and parent @@ -977,6 +995,7 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us error_message="", completed_at=datetime.now(), result=None, + override_model_config=override_model_config, ) db.add(assistant_subtask) diff --git a/frontend/src/apis/tasks.ts b/frontend/src/apis/tasks.ts index 09813204..d36c7b6a 100644 --- a/frontend/src/apis/tasks.ts +++ b/frontend/src/apis/tasks.ts @@ -19,6 +19,7 @@ export interface CreateTaskRequest { batch: number; user_id: number; user_name: string; + override_model?: string; } export interface UpdateTaskRequest { diff --git a/frontend/src/features/tasks/components/ChatArea.tsx b/frontend/src/features/tasks/components/ChatArea.tsx index b4370d32..13e4a0d3 100644 --- a/frontend/src/features/tasks/components/ChatArea.tsx +++ b/frontend/src/features/tasks/components/ChatArea.tsx @@ -9,6 +9,7 @@ import { ArrowTurnDownLeftIcon } from '@heroicons/react/24/outline'; import MessagesArea from './MessagesArea'; import ChatInput from './ChatInput'; import TeamSelector from './TeamSelector'; +import ModelSelector from './ModelSelector'; import RepositorySelector from './RepositorySelector'; import BranchSelector from './BranchSelector'; import type { Team, GitRepoInfo, GitBranch } from '@/types/api'; @@ -18,7 +19,7 @@ import { useTaskContext } from '../contexts/taskContext'; import { App, Button } from 'antd'; import QuotaUsage from './QuotaUsage'; import { useMediaQuery } from '@/hooks/useMediaQuery'; -import { saveLastTeam, getLastTeamId, saveLastRepo } from '@/utils/userPreferences'; +import { saveLastTeam, getLastTeamId, saveLastRepo, saveLastModel, getLastModel } from '@/utils/userPreferences'; const SHOULD_HIDE_QUOTA_NAME_LIMIT = 18; @@ -47,6 +48,7 @@ export default function ChatArea({ } const [selectedTeam, setSelectedTeam] = useState(null); + const [selectedModel, setSelectedModel] = useState(null); const [selectedRepo, setSelectedRepo] = useState(null); const [selectedBranch, setSelectedBranch] = useState(null); const [hasRestoredPreferences, setHasRestoredPreferences] = useState(false); @@ -114,6 +116,38 @@ export default function ChatArea({ setHasRestoredPreferences(true); }, [teams, hasRestoredPreferences]); + // Restore model preference from localStorage on mount (only for new tasks) + useEffect(() => { + if (!hasMessages) { + const lastModel = getLastModel(); + if (lastModel) { + console.log('[ChatArea] Restoring model from localStorage:', lastModel); + setSelectedModel(lastModel); + } + } + }, [hasMessages]); + + // Read current model from task detail when appending messages + useEffect(() => { + if (hasMessages && selectedTaskDetail?.subtasks) { + // Find the most recent ASSISTANT subtask + const assistantSubtasks = selectedTaskDetail.subtasks.filter( + (st: any) => st.role === 'ASSISTANT' + ); + if (assistantSubtasks.length > 0) { + const latestAssistant = assistantSubtasks[assistantSubtasks.length - 1]; + if (latestAssistant.bots && latestAssistant.bots.length > 0) { + const bot = latestAssistant.bots[0]; + const privateModel = bot.agent_config?.private_model; + if (privateModel) { + console.log('[ChatArea] Setting model from current task:', privateModel); + setSelectedModel(privateModel); + } + } + } + } + }, [hasMessages, selectedTaskDetail]); + // Handle external team selection for new tasks (from team sharing) useEffect(() => { if (selectedTeamForNewTask && !hasMessages) { @@ -142,6 +176,17 @@ export default function ChatArea({ } }; + const handleModelChange = (model: string | null) => { + console.log('[ChatArea] handleModelChange called:', model || 'null'); + setSelectedModel(model); + + // Save model preference to localStorage + if (model) { + console.log('[ChatArea] Saving model to localStorage:', model); + saveLastModel(model); + } + }; + // Save repository preference when it changes useEffect(() => { if (selectedRepo) { @@ -159,6 +204,7 @@ export default function ChatArea({ branch: showRepositorySelector ? selectedBranch : null, task_id: selectedTaskDetail?.id, taskType: taskType, + selectedModel: selectedModel, }); if (error) { message.error(error); @@ -358,15 +404,23 @@ export default function ChatArea({ /> {/* Team Selector and Send Button */}
-
+
{teams.length > 0 && ( - + <> + + + )}
diff --git a/frontend/src/features/tasks/components/ModelSelector.tsx b/frontend/src/features/tasks/components/ModelSelector.tsx new file mode 100644 index 00000000..e75ef760 --- /dev/null +++ b/frontend/src/features/tasks/components/ModelSelector.tsx @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: 2025 Weibo, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +'use client'; + +import React, { useEffect, useState, useMemo } from 'react'; +import { Select, theme } from 'antd'; +import { CpuChipIcon } from '@heroicons/react/24/outline'; +import { apiClient } from '@/apis/client'; +import { useTranslation } from '@/hooks/useTranslation'; +import { useMediaQuery } from '@/hooks/useMediaQuery'; + +interface ModelSelectorProps { + selectedModel: string | null; + setSelectedModel: (model: string | null) => void; + disabled?: boolean; + isLoading?: boolean; +} + +interface ModelOption { + name: string; +} + +export default function ModelSelector({ + selectedModel, + setSelectedModel, + disabled = false, + isLoading = false, +}: ModelSelectorProps) { + const { t } = useTranslation('common'); + const { token } = theme.useToken(); + const isMobile = useMediaQuery('(max-width: 767px)'); + const [models, setModels] = useState([]); + const [modelsLoading, setModelsLoading] = useState(false); + const [error, setError] = useState(null); + + // Fetch models from API + useEffect(() => { + const fetchModels = async () => { + setModelsLoading(true); + setError(null); + try { + const response = await apiClient.get<{ data: ModelOption[] }>('/models/names', { + agent_name: 'ClaudeCode', + }); + setModels(response.data || []); + } catch (err) { + console.error('Failed to fetch models:', err); + setError('Failed to load models'); + setModels([]); + } finally { + setModelsLoading(false); + } + }; + + fetchModels(); + }, []); + + // Validate selectedModel exists in models list + useEffect(() => { + if (selectedModel && models.length > 0) { + const exists = models.some(m => m.name === selectedModel); + if (!exists) { + console.warn( + `Selected model "${selectedModel}" not found in models list, clearing selection` + ); + setSelectedModel(null); + } + } + }, [selectedModel, models, setSelectedModel]); + + const handleChange = (value: string | null) => { + setSelectedModel(value); + }; + + const modelOptions = useMemo(() => { + return models.map(model => ({ + label: ( + + {model.name} + + ), + value: model.name, + })); + }, [models]); + + const filterOption = (input: string, option?: { label: React.ReactNode; value: string }) => { + if (!option) return false; + return option.value.toLowerCase().includes(input.toLowerCase()); + }; + + if (error) { + return null; // Hide selector if models failed to load + } + + return ( +
+ +