22 changes: 22 additions & 0 deletions backend/alembic/versions/054_add_table_select_answer.py
@@ -0,0 +1,22 @@
"""add table_select_answer column to chat_record

Revision ID: 054_table_select
Revises: 5755c0b95839
Create Date: 2025-12-23

"""
from alembic import op
import sqlalchemy as sa

revision = '054_table_select'
down_revision = '5755c0b95839'
branch_labels = None
depends_on = None


def upgrade():
op.add_column('chat_record', sa.Column('table_select_answer', sa.Text(), nullable=True))


def downgrade():
op.drop_column('chat_record', 'table_select_answer')
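If useful for a local check, the migration can be exercised through Alembic's command API; a minimal sketch (the alembic.ini path is an assumption, not part of this PR):

```python
# Optional local verification of this revision (assumes alembic.ini lives in
# the backend directory): apply the migration, then revert it.
from alembic import command
from alembic.config import Config

cfg = Config("backend/alembic.ini")  # path is an assumption
command.upgrade(cfg, "054_table_select")   # adds chat_record.table_select_answer
command.downgrade(cfg, "5755c0b95839")     # drops the column again
```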
53 changes: 53 additions & 0 deletions backend/apps/chat/curd/chat.py
@@ -649,6 +649,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s
return result


def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
"""保存 LLM 表选择的结果到 ChatRecord"""
if not record_id:
raise Exception("Record id cannot be None")
record = get_chat_record_by_id(session, record_id)

record.table_select_answer = answer

result = ChatRecord(**record.model_dump())

stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
table_select_answer=record.table_select_answer,
)

session.execute(stmt)
session.commit()

return result


def save_recommend_question_answer(session: SessionDep, record_id: int,
answer: dict = None) -> ChatRecord:
if not record_id:
@@ -839,3 +859,36 @@ def get_old_questions(session: SessionDep, datasource: int):
for r in result:
records.append(r.question)
return records


def get_chat_history_questions(session: SessionDep, chat_id: int, limit: int = 3) -> List[str]:
"""
Fetch the history questions of the current chat (chronological order, oldest first).

Args:
session: database session
chat_id: current chat ID
limit: number of history questions to fetch

Returns:
List of history questions, in chronological order (oldest first)
"""
stmt = (
select(ChatRecord.question)
.where(
and_(
ChatRecord.chat_id == chat_id,
ChatRecord.question.isnot(None),
ChatRecord.question != '',
ChatRecord.error.is_(None)
)
)
.order_by(ChatRecord.create_time.desc())
.limit(limit)
)

result = session.execute(stmt)
questions = [row.question for row in result if row.question and row.question.strip()]

# Reverse the list so the oldest question comes first
return list(reversed(questions))
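Downstream, these history questions are expected to be folded into the embedding query via build_context_query (added in table_embedding.py below). A minimal sketch of the assumed wiring — build_query_for_chat is a hypothetical helper, not part of this PR:

```python
# Minimal sketch of the assumed wiring (build_query_for_chat is hypothetical);
# it chains the two helpers introduced in this diff.
from apps.chat.curd.chat import get_chat_history_questions
from apps.datasource.embedding.table_embedding import build_context_query


def build_query_for_chat(session, chat_id: int, current_question: str) -> str:
    # Oldest-first history, capped at the caller's limit
    history = get_chat_history_questions(session, chat_id, limit=3)
    # "old question | ... | current question" becomes the embedded query text
    return build_context_query(current_question, history)
```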
3 changes: 3 additions & 0 deletions backend/apps/chat/models/chat_model.py
@@ -40,6 +40,7 @@ class OperationEnum(Enum):
GENERATE_SQL_WITH_PERMISSIONS = '5'
CHOOSE_DATASOURCE = '6'
GENERATE_DYNAMIC_SQL = '7'
SELECT_TABLE = '8'  # LLM table selection


class ChatFinishStep(Enum):
@@ -112,6 +113,7 @@ class ChatRecord(SQLModel, table=True):
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
table_select_answer: str = Field(sa_column=Column(Text, nullable=True))
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
error: str = Field(sa_column=Column(Text, nullable=True))
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
@@ -137,6 +139,7 @@ class ChatRecordResult(BaseModel):
predict_data: Optional[str] = None
recommended_question: Optional[str] = None
datasource_select_answer: Optional[str] = None
table_select_answer: Optional[str] = None
finish: Optional[bool] = None
error: Optional[str] = None
analysis_record_id: Optional[int] = None
48 changes: 42 additions & 6 deletions backend/apps/chat/task/llm.py
@@ -31,7 +31,7 @@
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate, get_chat_predict_data, \
get_chat_chart_config
get_chat_chart_config, get_chat_history_questions
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep, AxisObj
from apps.data_training.curd.data_training import get_training_template
@@ -101,6 +101,12 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
chat: Chat | None = session.get(Chat, chat_id)
if not chat:
raise SingleMessageError(f"Chat with id {chat_id} not found")

# Fetch history questions (used for multi-turn embedding)
history_questions = []
if settings.MULTI_TURN_EMBEDDING_ENABLED:
history_questions = get_chat_history_questions(session, chat_id, settings.MULTI_TURN_HISTORY_COUNT)

ds: CoreDatasource | AssistantOutDsSchema | None = None
if chat.datasource:
# Get available datasource
@@ -116,8 +122,16 @@
if not ds:
raise SingleMessageError("No available datasource configuration found")
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
question=chat_question.question, embedding=embedding)
# Defer the get_table_schema call until after init_record so the LLM table-selection log can be recorded
self._pending_schema_params = {
'session': session,
'current_user': current_user,
'ds': ds,
'question': chat_question.question,
'embedding': embedding,
'history_questions': history_questions,
'config': config
}

self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
@@ -224,6 +238,22 @@ def init_messages(self):

def init_record(self, session: Session) -> ChatRecord:
self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question)

# If a schema fetch was deferred, run it now (the record exists, so the LLM table-selection log can be written)
if hasattr(self, '_pending_schema_params') and self._pending_schema_params:
params = self._pending_schema_params
self.chat_question.db_schema = get_table_schema(
session=params['session'],
current_user=params['current_user'],
ds=params['ds'],
question=params['question'],
embedding=params['embedding'],
history_questions=params['history_questions'],
config=params['config'],
record_id=self.record.id
)
self._pending_schema_params = None

return self.record

def get_record(self):
@@ -349,7 +379,9 @@ def generate_recommend_questions_task(self, _session: Session):
session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
embedding=False)
embedding=False,
config=self.config,
record_id=self.record.id)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
@@ -494,7 +526,9 @@ def select_datasource(self, _session: Session):
self.ds)
self.chat_question.db_schema = get_table_schema(session=_session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question)
question=self.chat_question.question,
config=self.config,
record_id=self.record.id)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type_name
# save chat
@@ -997,7 +1031,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
session=_session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question)
question=self.chat_question.question,
config=self.config,
record_id=self.record.id)
else:
self.validate_history_ds(_session)

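The deferral in __init__/init_record boils down to capturing the get_table_schema arguments until a record_id exists for logging. A stripped-down illustration of that pattern — a hypothetical class, not the actual LLMService:

```python
# Stripped-down illustration of the deferred-call pattern above (hypothetical
# class): arguments are captured at construction time and the schema fetch
# runs only once a record_id is available for logging.
from typing import Any, Callable, Optional


class DeferredSchemaFetch:
    def __init__(self, fetch: Callable[..., str], **params: Any):
        self._fetch = fetch
        self._pending: Optional[dict] = dict(params)

    def run(self, record_id: int) -> Optional[str]:
        if not self._pending:
            return None
        params, self._pending = self._pending, None
        # record_id is only known here, so the table-selection log entry
        # can reference the freshly created chat record.
        return self._fetch(record_id=record_id, **params)
```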
40 changes: 34 additions & 6 deletions backend/apps/datasource/crud/datasource.py
@@ -7,8 +7,10 @@
from sqlbot_xpack.permissions.models.ds_rules import DsRules
from sqlmodel import select

from apps.ai_model.model_factory import LLMConfig
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
from apps.datasource.embedding.table_embedding import calc_table_embedding
from apps.datasource.llm_select.table_selection import calc_table_llm_selection
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
@@ -416,7 +418,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core


def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True) -> str:
embedding: bool = True, history_questions: List[str] = None,
config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
@@ -425,7 +428,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
all_tables = [] # temp save all tables

# Build a table_name -> table_obj mapping for LLM table selection
table_name_to_obj = {}
for obj in table_objs:
table_name_to_obj[obj.table.table_name] = obj

schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
@@ -453,16 +461,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables.append(t_obj)
all_tables.append(t_obj)

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = calc_table_embedding(tables, question)
# do table selection
used_llm_selection = False  # flag whether LLM table selection was used
if embedding and tables:
if settings.TABLE_LLM_SELECTION_ENABLED and config:
# Use LLM-based table selection
selected_table_names = calc_table_llm_selection(
config=config,
table_objs=table_objs,
question=question,
ds_table_relation=ds.table_relation,
history_questions=history_questions,
lang=lang,
session=session,
record_id=record_id
)
if selected_table_names:
# Filter tables by the selected table names
selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj]
tables = [t for t in tables if t.get('id') in selected_table_ids]
used_llm_selection = True  # the LLM successfully selected tables
elif settings.TABLE_EMBEDDING_ENABLED:
# Use RAG (embedding-based) table selection
tables = calc_table_embedding(tables, question, history_questions)
# splice schema
if tables:
for s in tables:
schema_str += s.get('schema_table')

# field relation
if tables and ds.table_relation:
# field relation - in LLM table-selection mode, do not auto-complete related tables; trust the LLM's selection entirely
if tables and ds.table_relation and not used_llm_selection:
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
if relations:
# Complete the missing table
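calc_table_llm_selection itself is not shown in this diff; from the call site it is expected to return a list of selected table names, and note that an empty result leaves the candidate list unfiltered (it does not fall back to embedding, since the elif branches on the settings flags rather than on the result). A hypothetical skeleton consistent with that contract — the real implementation lives in apps/datasource/llm_select/table_selection.py and is not part of this hunk:

```python
# Hypothetical skeleton of the selector contract inferred from the call site
# above; not the real implementation.
from typing import List, Optional


def llm_select_tables_stub(question: str,
                           candidate_tables: List[str],
                           history_questions: Optional[List[str]] = None) -> List[str]:
    """Return the subset of candidate_tables judged relevant to the question.

    An empty list leaves get_table_schema's candidate list unfiltered and
    keeps used_llm_selection at False.
    """
    if not candidate_tables:
        return []
    # A real implementation would prompt the LLM with the table schemas, the
    # current question, and the recent history, parse the returned table names,
    # and persist the raw answer via save_table_select_answer(record_id, ...).
    return []  # safe stub default: select nothing, keep all candidates
```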
55 changes: 51 additions & 4 deletions backend/apps/datasource/embedding/table_embedding.py
@@ -3,14 +3,46 @@
import json
import time
import traceback
from typing import List

from apps.ai_model.embedding import EmbeddingModelCache
from apps.datasource.embedding.utils import cosine_similarity
from common.core.config import settings
from common.utils.utils import SQLBotLogUtil


def get_table_embedding(tables: list[dict], question: str):
def build_context_query(current_question: str, history_questions: List[str] = None) -> str:
"""
Build the query text with conversational context.
Args:
current_question: the current question
history_questions: list of history questions (chronological order, oldest first)
Returns:
The concatenated query text.
"""
if not settings.MULTI_TURN_EMBEDDING_ENABLED or not history_questions:
return current_question

max_history = settings.MULTI_TURN_HISTORY_COUNT
recent_history = history_questions[-max_history:] if history_questions else []

if not recent_history:
return current_question

# Concatenate: history questions + current question
context_parts = recent_history + [current_question]

# Join with a separator to keep the semantics coherent
context_query = " | ".join(context_parts)

SQLBotLogUtil.info(f"Context query for embedding: {context_query}")

return context_query


def get_table_embedding(tables: list[dict], question: str, history_questions: List[str] = None):
_list = []
for table in tables:
_list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})
@@ -25,7 +57,9 @@ def get_table_embedding(tables: list[dict], question: str):
end_time = time.time()
SQLBotLogUtil.info(str(end_time - start_time))

q_embedding = model.embed_query(question)
# Build the context-aware query
context_query = build_context_query(question, history_questions)
q_embedding = model.embed_query(context_query)
for index in range(len(results)):
item = results[index]
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
@@ -40,7 +74,18 @@ def calc_table_embedding(tables: list[dict], question: str):
return _list


def calc_table_embedding(tables: list[dict], question: str):
def calc_table_embedding(tables: list[dict], question: str, history_questions: List[str] = None):
"""
Compute embedding similarity between table schemas and the question.
Args:
tables: list of table schemas
question: the current question
history_questions: list of history questions (optional, for multi-turn conversations)
Returns:
Tables sorted by similarity.
"""
_list = []
for table in tables:
_list.append(
@@ -58,7 +103,9 @@ def calc_table_embedding(tables: list[dict], question: str):
# SQLBotLogUtil.info(str(end_time - start_time))
results = [item.get('embedding') for item in _list]

q_embedding = model.embed_query(question)
# Build the context-aware query
context_query = build_context_query(question, history_questions)
q_embedding = model.embed_query(context_query)
for index in range(len(results)):
item = results[index]
if item:
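For reference, assuming MULTI_TURN_EMBEDDING_ENABLED is on and MULTI_TURN_HISTORY_COUNT is 3, build_context_query is expected to behave roughly like this (illustrative values only):

```python
# Expected behaviour of build_context_query with the multi-turn settings on
# and MULTI_TURN_HISTORY_COUNT == 3 (made-up questions).
history = [
    "How many orders were placed last month?",  # oldest
    "Break that down by region",
    "Only show the top 5 regions",
]
current = "And what about this month?"

# Only the most recent MULTI_TURN_HISTORY_COUNT questions are kept, the current
# question is appended, and everything is joined with " | ":
expected = ("How many orders were placed last month? | Break that down by region"
            " | Only show the top 5 regions | And what about this month?")
assert " | ".join(history[-3:] + [current]) == expected
```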
2 changes: 2 additions & 0 deletions backend/apps/datasource/llm_select/__init__.py
@@ -0,0 +1,2 @@
# Author: SQLBot
# Date: 2025/12/23