feat: Add model selector feature to chat interface #82
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -230,7 +230,7 @@ def create_task_or_append( | |
| db.add(task) | ||
|
|
||
| # Create subtasks for the task | ||
| self._create_subtasks(db, task, team, user.id, obj_in.prompt) | ||
| self._create_subtasks(db, task, team, user.id, obj_in.prompt, obj_in.override_model) | ||
|
|
||
| db.commit() | ||
| db.refresh(task) | ||
|
|
@@ -436,6 +436,17 @@ def get_task_detail( | |
| # Convert subtasks to dict and replace bot_ids with bot objects | ||
| subtasks_dict = [] | ||
| for subtask in subtasks: | ||
| # Apply override_model_config to bot configs if present | ||
| subtask_bots = [] | ||
| for bot_id in subtask.bot_ids: | ||
| if bot_id in bots: | ||
| bot_dict = bots[bot_id].copy() | ||
| # If subtask has override_model_config, use it to override agent_config | ||
| if subtask.override_model_config: | ||
| bot_dict["agent_config"] = subtask.override_model_config | ||
| logger.info(f"Applied override_model_config to bot {bot_id} in subtask {subtask.id}: {subtask.override_model_config}") | ||
| subtask_bots.append(bot_dict) | ||
|
|
||
| # Convert subtask to dict | ||
| subtask_dict = { | ||
| # Subtask base fields | ||
|
|
@@ -458,8 +469,8 @@ def get_task_detail( | |
| "created_at": subtask.created_at, | ||
| "updated_at": subtask.updated_at, | ||
| "completed_at": subtask.completed_at, | ||
| # Add bot objects as dict for each bot_id | ||
| "bots": [bots.get(bot_id) for bot_id in subtask.bot_ids if bot_id in bots] | ||
| # Add bot objects with potentially overridden config | ||
| "bots": subtask_bots | ||
| } | ||
| subtasks_dict.append(subtask_dict) | ||
|
|
||
|
|
@@ -832,11 +843,11 @@ def _convert_team_to_dict(self, team: Kind, db: Session, user_id: int) -> Dict[s | |
| "updated_at": team.updated_at, | ||
| } | ||
|
|
||
| def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, user_prompt: str) -> None: | ||
| def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, user_prompt: str, override_model: Optional[str] = None) -> None: | ||
| """ | ||
| Create subtasks based on team's workflow configuration | ||
| """ | ||
| logger.info(f"_create_subtasks called with task_id={task.id}, team_id={team.id}, user_id={user_id}") | ||
| logger.info(f"_create_subtasks called with task_id={task.id}, team_id={team.id}, user_id={user_id}, override_model={override_model}") | ||
| team_crd = Team.model_validate(team.json) | ||
| task_crd = Task.model_validate(task.json) | ||
|
|
||
|
|
@@ -902,6 +913,12 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us | |
| ) | ||
| db.add(user_subtask) | ||
|
|
||
| # Prepare override_model_config if override_model is provided | ||
| override_model_config = None | ||
| if override_model: | ||
| override_model_config = {"private_model": override_model} | ||
| logger.info(f"Using override_model_config: {override_model_config}") | ||
|
Comment on lines +916 to +920

🧩 Analysis chain

Add validation for override_model value.

The code creates override_model_config from any non-empty override_model string without validating it. Add validation before creating the config:

  # Prepare override_model_config if override_model is provided
  override_model_config = None
  if override_model:
+     # Validate override_model is not empty
+     if not override_model.strip():
+         raise HTTPException(
+             status_code=400,
+             detail="override_model cannot be empty"
+         )
+     # TODO: Consider validating against available models from the models API
      override_model_config = {"private_model": override_model}
      logger.info(f"Using override_model_config: {override_model_config}")

Additionally, verify if there's a models service that can validate model names:

🏁 Script executed:

#!/bin/bash
# Description: Check for model validation services or available models lists

# Search for models API endpoints or services
echo "=== Checking for models API or validation services ==="
rg -n "def.*models" --type=py -g "backend/**" -A3

# Search for model name validation
echo "=== Checking for model name validation ==="
rg -n "validate.*model|model.*valid" --type=py -g "backend/**" -C3

# Look for /models/names endpoint implementation
echo "=== Checking /models/names endpoint ==="
rg -n "/models/names|models/names" --type=py -g "backend/**" -C5

Length of output: 44837

🏁 Script executed:

#!/bin/bash
# Examine context around lines 916-920 in task_kinds.py
echo "=== Code context around lines 916-920 ==="
sed -n '900,930p' backend/app/services/adapters/task_kinds.py

echo -e "\n=== Searching for override_model usage ==="
rg -n "override_model" --type=py -C5 -g "backend/**"

echo -e "\n=== Checking function signature where override_model is used ==="
rg -n "def.*override_model" --type=py -C3 -g "backend/**"

Length of output: 10999

🏁 Script executed:

#!/bin/bash
# Check where override_model_config is used downstream and if validation exists
echo "=== Searching for where override_model_config is processed ==="
rg -n "override_model_config" --type=py -C3 -g "backend/**"

echo -e "\n=== Checking models endpoint for validation capabilities ==="
cat backend/app/api/endpoints/adapter/models.py | head -150

echo -e "\n=== Checking if there's a method to validate model names ==="
rg -n "def.*valid.*model|def.*check.*model" --type=py -C2 -g "backend/**"

Length of output: 8660

🏁 Script executed:

#!/bin/bash
# Check list_model_names implementation and related context
echo "=== Checking list_model_names implementation ==="
rg -n "def list_model_names" --type=py -C15 -g "backend/**"

echo -e "\n=== Checking if validation is done elsewhere for similar parameters ==="
rg -n "HTTPException.*400|raise.*ValueError" --type=py -C2 -g "backend/app/services/adapters/task_kinds.py" | head -50

echo -e "\n=== Check if override_model is ever validated at API level ==="
rg -n "override_model" --type=py backend/app/api/ -C3

Length of output: 5483

Add validation for override_model before creating config.

The code accepts any string for override_model without validating it. At minimum, validate that the value is not empty or whitespace-only:

  # Prepare override_model_config if override_model is provided
  override_model_config = None
  if override_model:
+     if not override_model.strip():
+         raise HTTPException(
+             status_code=400,
+             detail="override_model cannot be empty"
+         )
+     # TODO: Consider validating against available models for the shells used in this team
      override_model_config = {"private_model": override_model}
      logger.info(f"Using override_model_config: {override_model_config}")
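Beyond the empty-string check, a minimal sketch of what validating against the available model names could look like; the validate_override_model helper, the available_names parameter, and the commented list_model_names() wiring are assumptions for illustration, not something this diff defines:

```python
from typing import Iterable

from fastapi import HTTPException


def validate_override_model(override_model: str, available_names: Iterable[str]) -> None:
    """Reject empty or unknown model names before building override_model_config."""
    if not override_model.strip():
        raise HTTPException(status_code=400, detail="override_model cannot be empty")
    if override_model not in set(available_names):
        raise HTTPException(
            status_code=400,
            detail=f"override_model '{override_model}' is not an available model",
        )


# Hypothetical wiring inside _create_subtasks (names assumed, not from this PR):
# if override_model:
#     validate_override_model(override_model, available_names=list_model_names())
#     override_model_config = {"private_model": override_model}
```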
||
|
|
||
| # Update id of next message and parent | ||
| if parent_id == 0: | ||
| parent_id = 1 | ||
|
|
@@ -944,6 +961,7 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us | |
| error_message="", | ||
| completed_at=datetime.now(), | ||
| result=None, | ||
| override_model_config=override_model_config, | ||
| ) | ||
|
|
||
| # Update id of next message and parent | ||
|
|
@@ -977,6 +995,7 @@ def _create_subtasks(self, db: Session, task: Kind, team: Kind, user_id: int, us | |
| error_message="", | ||
| completed_at=datetime.now(), | ||
| result=None, | ||
| override_model_config=override_model_config, | ||
| ) | ||
| db.add(assistant_subtask) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| // SPDX-FileCopyrightText: 2025 Weibo, Inc. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| 'use client'; | ||
|
|
||
| import React, { useEffect, useState, useMemo } from 'react'; | ||
| import { Select, theme } from 'antd'; | ||
| import { CpuChipIcon } from '@heroicons/react/24/outline'; | ||
| import { apiClient } from '@/apis/client'; | ||
| import { useTranslation } from '@/hooks/useTranslation'; | ||
| import { useMediaQuery } from '@/hooks/useMediaQuery'; | ||
|
|
||
| interface ModelSelectorProps { | ||
| selectedModel: string | null; | ||
| setSelectedModel: (model: string | null) => void; | ||
| disabled?: boolean; | ||
| isLoading?: boolean; | ||
| } | ||
|
|
||
| interface ModelOption { | ||
| name: string; | ||
| } | ||
|
|
||
| export default function ModelSelector({ | ||
| selectedModel, | ||
| setSelectedModel, | ||
| disabled = false, | ||
| isLoading = false, | ||
| }: ModelSelectorProps) { | ||
| const { t } = useTranslation('common'); | ||
| const { token } = theme.useToken(); | ||
| const isMobile = useMediaQuery('(max-width: 767px)'); | ||
| const [models, setModels] = useState<ModelOption[]>([]); | ||
| const [modelsLoading, setModelsLoading] = useState(false); | ||
| const [error, setError] = useState<string | null>(null); | ||
|
|
||
| // Fetch models from API | ||
| useEffect(() => { | ||
| const fetchModels = async () => { | ||
| setModelsLoading(true); | ||
| setError(null); | ||
| try { | ||
| const response = await apiClient.get<{ data: ModelOption[] }>('/models/names', { | ||
| agent_name: 'ClaudeCode', | ||
| }); | ||
| setModels(response.data || []); | ||
| } catch (err) { | ||
| console.error('Failed to fetch models:', err); | ||
| setError('Failed to load models'); | ||
| setModels([]); | ||
| } finally { | ||
| setModelsLoading(false); | ||
| } | ||
| }; | ||
|
|
||
| fetchModels(); | ||
| }, []); | ||
|
|
||
| // Validate selectedModel exists in models list | ||
| useEffect(() => { | ||
| if (selectedModel && models.length > 0) { | ||
| const exists = models.some(m => m.name === selectedModel); | ||
| if (!exists) { | ||
| console.warn( | ||
| `Selected model "${selectedModel}" not found in models list, clearing selection` | ||
| ); | ||
| setSelectedModel(null); | ||
| } | ||
| } | ||
| }, [selectedModel, models, setSelectedModel]); | ||
|
|
||
| const handleChange = (value: string | null) => { | ||
| setSelectedModel(value); | ||
| }; | ||
|
|
||
| const modelOptions = useMemo(() => { | ||
| return models.map(model => ({ | ||
| label: ( | ||
| <span className="font-medium text-xs text-text-primary truncate" title={model.name}> | ||
| {model.name} | ||
| </span> | ||
| ), | ||
| value: model.name, | ||
| })); | ||
| }, [models]); | ||
|
|
||
| const filterOption = (input: string, option?: { label: React.ReactNode; value: string }) => { | ||
| if (!option) return false; | ||
| return option.value.toLowerCase().includes(input.toLowerCase()); | ||
| }; | ||
|
|
||
| if (error) { | ||
| return null; // Hide selector if models failed to load | ||
| } | ||
|
|
||
| return ( | ||
| <div className="flex items-baseline space-x-1 min-w-0"> | ||
| <CpuChipIcon | ||
| className={`w-3 h-3 text-text-muted flex-shrink-0 ${modelsLoading || isLoading ? 'animate-pulse' : ''}`} | ||
| /> | ||
| <Select | ||
| showSearch | ||
| allowClear | ||
| value={selectedModel} | ||
| placeholder={ | ||
| <span className="text-sx truncate h-2"> | ||
| {modelsLoading ? t('chat.model_loading') || 'Loading...' : t('chat.select_model') || 'Select Model'} | ||
| </span> | ||
| } | ||
| className="repository-selector min-w-0 truncate" | ||
| style={{ | ||
| width: 'auto', | ||
| maxWidth: isMobile ? 150 : 200, | ||
| display: 'inline-block', | ||
| paddingRight: 20, | ||
| }} | ||
| popupMatchSelectWidth={false} | ||
| styles={{ popup: { root: { maxWidth: 280 } } }} | ||
| classNames={{ popup: { root: 'repository-selector-dropdown custom-scrollbar' } }} | ||
| disabled={disabled || modelsLoading} | ||
| loading={modelsLoading} | ||
| size="small" | ||
| filterOption={filterOption} | ||
| onChange={handleChange} | ||
| notFoundContent={ | ||
| <div className="px-3 py-2 text-sm text-text-muted"> | ||
| {t('chat.no_model_found') || 'No model found'} | ||
| </div> | ||
| } | ||
| options={modelOptions} | ||
| /> | ||
| </div> | ||
| ); | ||
| } |
Override replaces entire agent_config instead of merging.

Line 446 replaces the entire agent_config with override_model_config. This means if the original agent_config had other important fields (e.g., temperature, max_tokens, top_p), they will be lost; only private_model will remain.

Consider merging the override config with the existing config:

  if bot_id in bots:
      bot_dict = bots[bot_id].copy()
      # If subtask has override_model_config, use it to override agent_config
      if subtask.override_model_config:
-         bot_dict["agent_config"] = subtask.override_model_config
+         # Merge override config with existing config to preserve other settings
+         if bot_dict.get("agent_config"):
+             bot_dict["agent_config"] = {**bot_dict["agent_config"], **subtask.override_model_config}
+         else:
+             bot_dict["agent_config"] = subtask.override_model_config
          logger.info(f"Applied override_model_config to bot {bot_id} in subtask {subtask.id}: {subtask.override_model_config}")
      subtask_bots.append(bot_dict)

This preserves other configuration parameters while overriding only the private_model field.
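For reference, the suggested shallow dict merge keeps every existing key and only overwrites the keys present in the override. A tiny standalone illustration (the sample values are made up, not taken from this PR):

```python
# Shallow dict merge: keys from the override win; everything else is preserved.
agent_config = {"private_model": "default-model", "temperature": 0.2, "max_tokens": 4096}
override_model_config = {"private_model": "user-picked-model"}

merged = {**agent_config, **override_model_config}
print(merged)
# -> {'private_model': 'user-picked-model', 'temperature': 0.2, 'max_tokens': 4096}
```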