def get_chat_history_questions(session: SessionDep, chat_id: int, limit: int = 3) -> List[str]:
    """Return the most recent questions asked in a chat, oldest first.

    Only records that have a non-empty question and no recorded error are
    considered.

    Args:
        session: Database session used to run the query.
        chat_id: ID of the chat whose history is read.
        limit: Maximum number of records to fetch. Zero or a negative value
            yields an empty list without touching the database; ``None`` is
            handed to the query unchanged.

    Returns:
        Up to ``limit`` questions, ordered by ``create_time`` ascending.
    """
    # A non-positive limit can never produce history: zero would be a wasted
    # round trip and a negative LIMIT is typically rejected by the database.
    if limit is not None and limit <= 0:
        return []

    stmt = (
        select(ChatRecord.question)
        .where(
            and_(
                ChatRecord.chat_id == chat_id,
                ChatRecord.question.isnot(None),
                ChatRecord.question != '',
                ChatRecord.error.is_(None)
            )
        )
        # Newest first, so that LIMIT keeps the most recent records.
        .order_by(ChatRecord.create_time.desc())
        .limit(limit)
    )

    rows = session.execute(stmt)
    # The SQL filter only excludes NULL and the exact empty string, so
    # whitespace-only questions still have to be dropped here.
    questions = [row.question for row in rows if row.question and row.question.strip()]

    # The query returned newest first; flip so the oldest question leads.
    questions.reverse()
    return questions
获取历史问题(用于多轮对话embedding) + history_questions = [] + if settings.MULTI_TURN_EMBEDDING_ENABLED: + history_questions = get_chat_history_questions(session, chat_id, settings.MULTI_TURN_HISTORY_COUNT) + ds: CoreDatasource | AssistantOutDsSchema | None = None if chat.datasource: # Get available datasource @@ -117,7 +123,8 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C raise SingleMessageError("No available datasource configuration found") chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds, - question=chat_question.question, embedding=embedding) + question=chat_question.question, embedding=embedding, + history_questions=history_questions) self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id) self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id) diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index 153e5088..31d953cd 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -416,7 +416,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, - embedding: bool = True) -> str: + embedding: bool = True, history_questions: List[str] = None) -> str: schema_str = "" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: @@ -455,7 +455,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat # do table embedding if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: - tables = calc_table_embedding(tables, question) + tables = calc_table_embedding(tables, question, history_questions) # splice schema if tables: for s in tables: 
def build_context_query(current_question: str, history_questions: List[str] = None) -> str:
    """Build the query text that is embedded for a (possibly multi-turn) question.

    When multi-turn embedding is enabled, the most recent history questions are
    joined with the current question so a follow-up question keeps its context.

    Args:
        current_question: The question being asked now.
        history_questions: Earlier questions of the same chat, oldest first.
            ``None`` or an empty list means there is no history.

    Returns:
        ``current_question`` unchanged when there is no usable history, the
        feature is disabled, or ``MULTI_TURN_HISTORY_COUNT`` is not positive.
        Otherwise the retained history questions followed by the current
        question, joined with ``" | "``.
    """
    # Cheapest check first: without history there is nothing to merge.
    if not history_questions or not settings.MULTI_TURN_EMBEDDING_ENABLED:
        return current_question

    max_history = settings.MULTI_TURN_HISTORY_COUNT
    # A non-positive window must disable history. Slicing with ``[-0:]`` would
    # return the whole list, and a negative count would drop the oldest
    # entries instead of limiting to the newest ones.
    if max_history <= 0:
        return current_question

    # Skip blank or non-string entries so the joined text has no empty
    # segments and ``str.join`` cannot fail on ``None``.
    usable_history = [q for q in history_questions if isinstance(q, str) and q.strip()]
    recent_history = usable_history[-max_history:]
    if not recent_history:
        return current_question

    # History first, current question last, to keep chronological order.
    context_query = " | ".join([*recent_history, current_question])

    SQLBotLogUtil.info(f"Context query for embedding: {context_query}")

    return context_query
cosine_similarity(q_embedding, item) @@ -40,7 +74,18 @@ def get_table_embedding(tables: list[dict], question: str): return _list -def calc_table_embedding(tables: list[dict], question: str): +def calc_table_embedding(tables: list[dict], question: str, history_questions: List[str] = None): + """ + 计算表结构与问题的embedding相似度 + + Args: + tables: 表结构列表 + question: 当前问题 + history_questions: 历史问题列表(可选,用于多轮对话) + + Returns: + 按相似度排序的表列表 + """ _list = [] for table in tables: _list.append( @@ -58,7 +103,9 @@ def calc_table_embedding(tables: list[dict], question: str): # SQLBotLogUtil.info(str(end_time - start_time)) results = [item.get('embedding') for item in _list] - q_embedding = model.embed_query(question) + # 构建包含上下文的查询 + context_query = build_context_query(question, history_questions) + q_embedding = model.embed_query(context_query) for index in range(len(results)): item = results[index] if item: diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 4e09c201..69d17dfd 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -115,6 +115,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: TABLE_EMBEDDING_COUNT: int = 10 DS_EMBEDDING_COUNT: int = 10 + # Multi-turn embedding settings + MULTI_TURN_EMBEDDING_ENABLED: bool = True + MULTI_TURN_HISTORY_COUNT: int = 3 + ORACLE_CLIENT_PATH: str = '/opt/sqlbot/db_client/oracle_instant_client' @field_validator('SQL_DEBUG', @@ -123,6 +127,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: 'PARSE_REASONING_BLOCK_ENABLED', 'PG_POOL_PRE_PING', 'TABLE_EMBEDDING_ENABLED', + 'MULTI_TURN_EMBEDDING_ENABLED', mode='before') @classmethod def lowercase_bool(cls, v: Any) -> Any: