diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 3b1ce2fc..4c960171 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -197,6 +197,7 @@ class APIADDRequest(BaseRequest):
     source: str | None = Field(None, description="Source of the memory")
     chat_history: list[MessageDict] | None = Field(None, description="Chat history")
     session_id: str | None = Field(None, description="Session id")
+    customized_prompt: str | None = Field("", description="Customized prompt")
     operation: list[PermissionDict] | None = Field(
         None, description="operation ids for multi cubes"
     )
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 7d9f141d..6a13412f 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -580,6 +580,7 @@ def _process_text_mem() -> list[dict[str, str]]:
             info={
                 "user_id": add_req.user_id,
                 "session_id": target_session_id,
+                "customized_prompt": add_req.customized_prompt,
             },
             mode="fast" if sync_mode == "async" else "fine",
         )
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 3845f37d..60e27599 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -31,7 +31,6 @@
 )
 from memos.utils import timed
 
-
 logger = log.get_logger(__name__)
 PROMPT_DICT = {
     "chat": {
@@ -201,17 +200,36 @@ def _make_memory_item(
             ),
         )
 
-    def _get_llm_response(self, mem_str: str) -> dict:
-        lang = detect_lang(mem_str)
-        template = PROMPT_DICT["chat"][lang]
-        examples = PROMPT_DICT["chat"][f"{lang}_example"]
-        prompt = template.replace("${conversation}", mem_str)
-        if self.config.remove_prompt_example:
-            prompt = prompt.replace(examples, "")
+    def _get_llm_response(self, mem_str: str, info: dict | None = None) -> dict:
+        if info and info.get("customized_prompt"):
+            template = info["customized_prompt"]
+            prompt = template.replace("${conversation}", mem_str)
+        else:
+            lang = detect_lang(mem_str)
+            template = PROMPT_DICT["chat"][lang]
+            examples = PROMPT_DICT["chat"][f"{lang}_example"]
+            prompt = template.replace("${conversation}", mem_str)
+            if self.config.remove_prompt_example:
+                prompt = prompt.replace(examples, "")
         messages = [{"role": "user", "content": prompt}]
+
         try:
-            response_text = self.llm.generate(messages)
-            response_json = self.parse_json_result(response_text)
+            if info and info.get("customized_prompt"):
+                response_text = self.llm.generate(messages)
+                response_json = {
+                    "memory list": [
+                        {
+                            "key": response_text[:10],
+                            "memory_type": "UserMemory",
+                            "value": response_text,
+                            "tags": [],
+                        }
+                    ],
+                    "summary": "",
+                }
+            else:
+                response_text = self.llm.generate(messages)
+                response_json = self.parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
             response_json = {
@@ -284,7 +302,14 @@ def _build_fast_node(w):
             mem_type = "UserMemory" if roles == {"user"} else "LongTermMemory"
             tags = ["mode:fast"]
             return self._make_memory_item(
-                value=text, info=info, memory_type=mem_type, tags=tags, sources=w["sources"]
+                value=text,
+                info=info,
+                memory_type=mem_type,
+                tags=tags,
+                sources=w["sources"],
+                background=""
+                if not info.get("customized_prompt", "")
+                else info["customized_prompt"],
             )
 
         with ContextThreadPoolExecutor(max_workers=8) as ex:
@@ -304,7 +329,7 @@ def _build_fast_node(w):
         logger.debug("Using unified Fine Mode")
         chat_read_nodes = []
         for w in windows:
-            resp = self._get_llm_response(w["text"])
+            resp = self._get_llm_response(w["text"], info)
             for m in resp.get("memory list", []):
                 try:
                     memory_type = (
@@ -328,7 +353,8 @@ def _build_fast_node(w):
 
     def _process_transfer_chat_data(self, raw_node: TextualMemoryItem):
         raw_memory = raw_node.memory
-        response_json = self._get_llm_response(raw_memory)
+        info = {"customized_prompt": raw_node.metadata.background}
+        response_json = self._get_llm_response(raw_memory, info)
         chat_read_nodes = []
         for memory_i_raw in response_json.get("memory list", []):
             try:
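
Usage note (not part of the patch): the change threads an optional customized_prompt from APIADDRequest through the router's info dict into the mem reader, so callers can swap the built-in chat extraction template for their own. The sketch below shows a client call under stated assumptions: the base URL and endpoint path are invented, and MessageDict is assumed to be {role, content} pairs; only user_id, session_id, chat_history, and customized_prompt are visible in this diff.

import requests

# Hypothetical client sketch -- the URL and route are assumptions.
payload = {
    "user_id": "u-123",
    "session_id": "s-456",
    "chat_history": [
        {"role": "user", "content": "I moved to Berlin last month."},
        {"role": "assistant", "content": "Nice! How are you settling in?"},
    ],
    # The template must contain the ${conversation} placeholder;
    # _get_llm_response substitutes it with the serialized chat window.
    "customized_prompt": (
        "Extract one concise fact about the user from the conversation below.\n"
        "${conversation}"
    ),
}

resp = requests.post("http://localhost:8000/product/add", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())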
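
With a customized prompt, _get_llm_response skips parse_json_result entirely: the raw completion is wrapped in the standard response shape, using the first ten characters of the text as the key and a fixed memory_type of "UserMemory". A minimal self-contained sketch of that wrapping (wrap_custom_response is an illustrative name, not part of the patch):

def wrap_custom_response(response_text: str) -> dict:
    # Mirrors the custom-prompt branch: the model output is treated as opaque
    # text, so no JSON parsing is attempted and no tags are inferred.
    return {
        "memory list": [
            {
                "key": response_text[:10],
                "memory_type": "UserMemory",
                "value": response_text,
                "tags": [],
            }
        ],
        "summary": "",
    }

print(wrap_custom_response("User moved to Berlin in 2024."))
# {'memory list': [{'key': 'User moved', 'memory_type': 'UserMemory', ...}], 'summary': ''}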
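
Fast mode stashes the customized prompt on the new item's background metadata, and _process_transfer_chat_data later rebuilds info from that field, so the asynchronous fine re-read applies the same template the caller supplied at add time. A round-trip sketch of that handoff, with SimpleNamespace standing in for TextualMemoryItem (which has many more fields):

from types import SimpleNamespace

info = {"customized_prompt": "Summarize the user's goal.\n${conversation}"}

# Fast path (_build_fast_node): the prompt rides along as background metadata.
fast_node = SimpleNamespace(
    memory="user: I want to learn Rust.",
    metadata=SimpleNamespace(background=info.get("customized_prompt") or ""),
)

# Fine path (_process_transfer_chat_data): the prompt is recovered from the
# node, so the re-read uses the template originally supplied by the caller.
recovered = {"customized_prompt": fast_node.metadata.background}
assert recovered == info

One side effect worth flagging in review: because _get_llm_response only checks truthiness, any non-empty metadata.background on a transferred node is treated as a prompt template, whether or not it was set by this path.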