From d308dd84386aa19590d1fb09c4d29c1227d6c56f Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Fri, 5 Dec 2025 10:17:34 -0500 Subject: [PATCH 1/4] Continue refactor --- dspy/adapters/base.py | 128 ++++++++--- dspy/adapters/types/history.py | 153 +++++++++++-- dspy/utils/inspect_history.py | 100 +++++--- tests/adapters/test_baml_adapter.py | 10 +- tests/adapters/test_chat_adapter.py | 340 +++++++++++++++++++++++++++- tests/adapters/test_json_adapter.py | 10 +- 6 files changed, 643 insertions(+), 98 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 8696697d3a..08b0dcc06b 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -9,6 +9,7 @@ from dspy.adapters.types.reasoning import Reasoning from dspy.adapters.types.tool import Tool, ToolCalls from dspy.experimental import Citations +from dspy.signatures.field import InputField, OutputField from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks @@ -452,13 +453,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]]) return messages - def _get_history_field_name(self, signature: type[Signature]) -> bool: + def _get_history_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): if field.annotation == History: return name return None - def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.input_fields.items(): # Look for annotation `list[dspy.Tool]` or `dspy.Tool` origin = get_origin(field.annotation) @@ -468,54 +469,119 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool: return name return None - def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: + def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None: for name, field in signature.output_fields.items(): if field.annotation == ToolCalls: return name return None + def _serialize_kv_value(self, v: Any) -> Any: + """Safely serialize values for kv-mode formatting.""" + if isinstance(v, (str, int, float, bool)) or v is None: + return v + try: + return str(v) + except Exception: + return f"" + + def _make_dynamic_signature_for_inputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with input fields only (no instructions).""" + return Signature({k: InputField() for k in keys}, instructions="") + + def _make_dynamic_signature_for_outputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with output fields only (no instructions).""" + return Signature({k: OutputField() for k in keys}, instructions="") + def format_conversation_history( self, signature: type[Signature], history_field_name: str, inputs: dict[str, Any], ) -> list[dict[str, Any]]: - """Format the conversation history. + """Format the conversation history as multiturn messages. - This method formats the conversation history and the current input as multiturn messages. - - Args: - signature: The DSPy signature for which to format the conversation history. - history_field_name: The name of the history field in the signature. - inputs: The input arguments to the DSPy module. + Supports four modes: + - raw: Direct LM messages → passed through as-is + - demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs + - flat: Arbitrary kv pairs → single user message per dict (default) + - signature: Dict keys match signature fields → user/assistant pairs - Returns: - A list of multiturn messages. + For backward compatibility, flat-mode histories whose message keys are subsets of the + signature fields (and overlap output fields) are treated as signature-mode. """ - conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None - - if conversation_history is None: + history = inputs.get(history_field_name) + if history is None: return [] - messages = [] - for message in conversation_history: - messages.append( - { + del inputs[history_field_name] + + if history.mode == "raw": + return [dict(msg) for msg in history.messages] + if history.mode == "demo": + return self._format_demo_history(history.messages) + if history.mode == "signature": + return self._format_signature_history(signature, history.messages) + + # Backward-compat shim: treat flat-mode as signature-mode if messages look like + # signature-style conversation history (keys subset of signature fields, overlapping outputs) + if history.mode == "flat" and history.messages: + sig_keys = set(signature.fields.keys()) + output_keys = set(signature.output_fields.keys()) + msg_key_sets = [set(m.keys()) for m in history.messages] + + if all(ks <= sig_keys for ks in msg_key_sets): + if any(ks & output_keys for ks in msg_key_sets): + return self._format_signature_history(signature, history.messages) + + return self._format_flat_history(history.messages) + + def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format demo-mode history (input_fields/output_fields → user/assistant).""" + result = [] + for msg in messages: + if "input_fields" in msg: + input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()} + sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys())) + result.append({ "role": "user", - "content": self.format_user_message_content(signature, message), - } - ) - messages.append( - { + "content": self.format_user_message_content(sig, input_dict), + }) + if "output_fields" in msg: + output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()} + sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys())) + result.append({ "role": "assistant", - "content": self.format_assistant_message_content(signature, message), - } - ) - - # Remove the history field from the inputs - del inputs[history_field_name] + "content": self.format_assistant_message_content(sig, output_dict), + }) + return result - return messages + def _format_signature_history( + self, signature: type[Signature], messages: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + """Format signature-mode history (signature fields → user/assistant pairs).""" + result = [] + for msg in messages: + result.append({ + "role": "user", + "content": self.format_user_message_content(signature, msg), + }) + result.append({ + "role": "assistant", + "content": self.format_assistant_message_content(signature, msg), + }) + return result + + def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Format flat-mode history (all kv pairs in single user message).""" + result = [] + for msg in messages: + serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()} + sig = self._make_dynamic_signature_for_inputs(list(serialized.keys())) + result.append({ + "role": "user", + "content": self.format_user_message_content(sig, serialized), + }) + return result def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]: """Parse the LM output into a dictionary of the output fields. diff --git a/dspy/adapters/types/history.py b/dspy/adapters/types/history.py index 2c39d5c4ab..b177396089 100644 --- a/dspy/adapters/types/history.py +++ b/dspy/adapters/types/history.py @@ -1,25 +1,47 @@ -from typing import Any +import warnings +from typing import Any, Literal import pydantic class History(pydantic.BaseModel): - """Class representing the conversation history. - - The conversation history is a list of messages, each message entity should have keys from the associated signature. - For example, if you have the following signature: - - ``` - class MySignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - answer: str = dspy.OutputField() - ``` - - Then the history should be a list of dictionaries with keys "question" and "answer". + """Class representing conversation history. + + History supports four message formats via the `mode` parameter: + + 1. **Raw mode**: Direct LM messages with `{"role": "...", "content": "..."}`. + Used for ReAct trajectories and native tool calling. + ```python + history = dspy.History(messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], mode="raw") + ``` + + 2. **Demo mode**: Nested `{"input_fields": {...}, "output_fields": {...}}` pairs. + Used for few-shot demonstrations with explicit input/output separation. + ```python + history = dspy.History(messages=[ + {"input_fields": {"question": "2+2?"}, "output_fields": {"answer": "4"}}, + ], mode="demo") + ``` + + 3. **Flat mode** (default): Arbitrary key-value pairs in a single user message. + ```python + history = dspy.History(messages=[ + {"thought": "I need to search", "tool_name": "search", "observation": "Found it"}, + ]) + ``` + + 4. **Signature mode**: Dict keys match signature fields → user/assistant pairs. + ```python + history = dspy.History(messages=[ + {"question": "What is 2+2?", "answer": "4"}, + ], mode="signature") + ``` Example: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -29,19 +51,16 @@ class MySignature(dspy.Signature): history: dspy.History = dspy.InputField() answer: str = dspy.OutputField() - history = dspy.History( - messages=[ - {"question": "What is the capital of France?", "answer": "Paris"}, - {"question": "What is the capital of Germany?", "answer": "Berlin"}, - ] - ) + history = dspy.History(messages=[ + {"question": "What is the capital of France?", "answer": "Paris"}, + ], mode="signature") predict = dspy.Predict(MySignature) outputs = predict(question="What is the capital of France?", history=history) ``` Example of capturing the conversation history: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -53,12 +72,19 @@ class MySignature(dspy.Signature): predict = dspy.Predict(MySignature) outputs = predict(question="What is the capital of France?") - history = dspy.History(messages=[{"question": "What is the capital of France?", **outputs}]) + history = dspy.History(messages=[{"question": "What is the capital of France?", **outputs}], mode="signature") outputs_with_history = predict(question="Are you sure?", history=history) ``` """ messages: list[dict[str, Any]] + mode: Literal["signature", "demo", "flat", "raw"] = "flat" + """The message format mode for this history. + + Note: For backward compatibility, some adapters (e.g., ChatAdapter) may treat + flat-mode histories whose keys match a signature's fields as signature-mode, + formatting them as user/assistant pairs rather than single user messages. + """ model_config = pydantic.ConfigDict( frozen=True, @@ -66,3 +92,84 @@ class MySignature(dspy.Signature): validate_assignment=True, extra="forbid", ) + + @staticmethod + def _infer_mode_from_msg(msg: dict) -> str: + """Infer the mode from a message's structure. + + Detection rules (conservative): + - Raw: has "role" key and ONLY LM-like keys (role, content, tool_calls, tool_call_id, name) + - Demo: keys are ONLY "input_fields" and/or "output_fields" + - Flat: everything else (signature mode must be explicit) + """ + keys = set(msg.keys()) + lm_keys = {"role", "content", "tool_calls", "tool_call_id", "name"} + + if "role" in keys and keys <= lm_keys: + return "raw" + + if keys <= {"input_fields", "output_fields"} and keys: + return "demo" + + return "flat" + + def _validate_msg_for_mode(self, msg: dict, mode: str) -> None: + """Validate a message conforms to the expected mode structure.""" + if mode == "raw": + if not isinstance(msg.get("role"), str): + raise ValueError(f"Raw mode: 'role' must be a string: {msg}") + content = msg.get("content") + if content is not None and not isinstance(content, (str, list)): + raise ValueError(f"Raw mode: 'content' must be a string, list, or None: {msg}") + + elif mode == "demo": + if "input_fields" in msg and not isinstance(msg["input_fields"], dict): + raise ValueError(f"Demo mode: 'input_fields' must be a dict: {msg}") + if "output_fields" in msg and not isinstance(msg["output_fields"], dict): + raise ValueError(f"Demo mode: 'output_fields' must be a dict: {msg}") + + elif mode == "signature": + if not isinstance(msg, dict) or not msg: + raise ValueError(f"Signature mode: messages must be non-empty dicts: {msg}") + + def _warn_if_likely_wrong_mode(self, msg: dict, stacklevel: int = 2) -> None: + """Warn if a flat-mode message looks like it was intended for another mode.""" + keys = set(msg.keys()) + + if "role" in keys: + warnings.warn( + f"History message has 'role' key but is in flat mode. " + f"Did you mean to use mode='raw'? Message keys: {sorted(keys)}", + UserWarning, + stacklevel=stacklevel, + ) + elif keys & {"input_fields", "output_fields"}: + warnings.warn( + f"History message has 'input_fields'/'output_fields' but is in flat mode. " + f"Did you mean to use mode='demo'? Message keys: {sorted(keys)}", + UserWarning, + stacklevel=stacklevel, + ) + + @pydantic.model_validator(mode="after") + def _validate_messages(self) -> "History": + if not self.messages: + return self + + # Only infer if mode is the default "flat" and messages clearly match another mode + if self.mode == "flat": + inferred = self._infer_mode_from_msg(self.messages[0]) + if inferred in {"raw", "demo"}: + object.__setattr__(self, "mode", inferred) + + for msg in self.messages: + self._validate_msg_for_mode(msg, self.mode) + if self.mode == "flat": + # stacklevel=6: warn -> _warn_if_likely_wrong_mode -> _validate_messages -> validator -> __init__ -> caller + self._warn_if_likely_wrong_mode(msg, stacklevel=6) + + return self + + def with_messages(self, messages: list[dict[str, Any]]) -> "History": + """Return a new History with additional messages appended.""" + return History(messages=[*self.messages, *messages], mode=self.mode) diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 07934157fd..65a32fab25 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -10,6 +10,14 @@ def _blue(text: str, end: str = "\n"): return "\x1b[34m" + str(text) + "\x1b[0m" + end +def _yellow(text: str, end: str = "\n"): + return "\x1b[33m" + str(text) + "\x1b[0m" + end + + +def _cyan(text: str, end: str = "\n"): + return "\x1b[36m" + str(text) + "\x1b[0m" + end + + def pretty_print_history(history, n: int = 1): """Prints the last n prompts and their completions.""" @@ -22,37 +30,67 @@ def pretty_print_history(history, n: int = 1): print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") for msg in messages: - print(_red(f"{msg['role'].capitalize()} message:")) - if isinstance(msg["content"], str): - print(msg["content"].strip()) - else: - if isinstance(msg["content"], list): - for c in msg["content"]: - if c["type"] == "text": - print(c["text"].strip()) - elif c["type"] == "image_url": - image_str = "" - if "base64" in c["image_url"].get("url", ""): - len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - image_str = ( - f"<{c['image_url']['url'].split('base64,')[0]}base64," - f"" - ) - else: - image_str = f"" - print(_blue(image_str.strip())) - elif c["type"] == "input_audio": - audio_format = c["input_audio"]["format"] - len_audio = len(c["input_audio"]["data"]) - audio_str = f"