From 0294b13a80cf31df128c743e8e0ab77e1332a830 Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Wed, 3 Dec 2025 16:35:22 -0500 Subject: [PATCH 1/2] Initial refactor. To Clean --- dspy/adapters/base.py | 76 +++++++-- dspy/adapters/types/history.py | 134 +++++++++++++-- dspy/predict/react.py | 207 +++++++++++++++++----- dspy/utils/inspect_history.py | 100 +++++++---- scripts/test_coding_agent.py | 18 ++ tests/adapters/test_baml_adapter.py | 3 +- tests/adapters/test_chat_adapter.py | 200 +++++++++++++++++++++- tests/adapters/test_json_adapter.py | 3 +- tests/predict/test_react.py | 255 ++++++++++++++++++---------- 9 files changed, 799 insertions(+), 197 deletions(-) create mode 100644 scripts/test_coding_agent.py diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 8696697d3a..9ac3146671 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -9,6 +9,7 @@ from dspy.adapters.types.reasoning import Reasoning from dspy.adapters.types.tool import Tool, ToolCalls from dspy.experimental import Citations +from dspy.signatures.field import InputField, OutputField from dspy.signatures.signature import Signature from dspy.utils.callback import BaseCallback, with_callbacks @@ -474,6 +475,23 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool: return name return None + def _serialize_kv_value(self, v: Any) -> Any: + """Safely serialize values for kv-mode formatting.""" + if isinstance(v, (str, int, float, bool)) or v is None: + return v + try: + return str(v) + except Exception: + return f"" + + def _make_dynamic_signature_for_inputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with input fields only (no instructions).""" + return Signature({k: InputField() for k in keys}, instructions="") + + def _make_dynamic_signature_for_outputs(self, keys: list[str]) -> type[Signature]: + """Create a dynamic signature with output fields only (no instructions).""" + return Signature({k: OutputField() for k in keys}, instructions="") + def format_conversation_history( self, signature: type[Signature], @@ -483,6 +501,11 @@ def format_conversation_history( """Format the conversation history. This method formats the conversation history and the current input as multiturn messages. + Supports four modes: + - signature: Dict keys match signature input/output fields → user/assistant pairs + - kv: Nested {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs + - dict: Arbitrary serializable kv pairs → all in single user message (default) + - raw: Direct LM messages with {"role": "user", "content": "..."} → passed through Args: signature: The DSPy signature for which to format the conversation history. @@ -492,25 +515,50 @@ def format_conversation_history( Returns: A list of multiturn messages. """ - conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None - - if conversation_history is None: + history = inputs.get(history_field_name) + if history is None: return [] messages = [] - for message in conversation_history: - messages.append( - { + for msg in history.messages: + mode = history._detect_mode(msg) + + if mode == "raw": + messages.append(dict(msg)) + + elif mode == "kv": + if "input_fields" in msg: + input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()} + sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys())) + messages.append({ + "role": "user", + "content": self.format_user_message_content(sig, input_dict), + }) + if "output_fields" in msg: + output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()} + sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys())) + messages.append({ + "role": "assistant", + "content": self.format_assistant_message_content(sig, output_dict), + }) + + elif mode == "signature": + messages.append({ "role": "user", - "content": self.format_user_message_content(signature, message), - } - ) - messages.append( - { + "content": self.format_user_message_content(signature, msg), + }) + messages.append({ "role": "assistant", - "content": self.format_assistant_message_content(signature, message), - } - ) + "content": self.format_assistant_message_content(signature, msg), + }) + + else: # dict mode (default) - all kv pairs go into single user message + serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()} + sig = self._make_dynamic_signature_for_inputs(list(serialized.keys())) + messages.append({ + "role": "user", + "content": self.format_user_message_content(sig, serialized), + }) # Remove the history field from the inputs del inputs[history_field_name] diff --git a/dspy/adapters/types/history.py b/dspy/adapters/types/history.py index 2c39d5c4ab..9659a06c30 100644 --- a/dspy/adapters/types/history.py +++ b/dspy/adapters/types/history.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal import pydantic @@ -6,20 +6,42 @@ class History(pydantic.BaseModel): """Class representing the conversation history. - The conversation history is a list of messages, each message entity should have keys from the associated signature. - For example, if you have the following signature: - - ``` - class MySignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - answer: str = dspy.OutputField() - ``` - - Then the history should be a list of dictionaries with keys "question" and "answer". + History supports four message formats: + + 1. **Signature mode**: Dict keys match signature input/output fields → user/assistant pairs. + Must be explicitly set via mode="signature". + ```python + history = dspy.History(messages=[ + {"question": "What is 2+2?", "answer": "4"}, + ], mode="signature") + ``` + + 2. **KV mode**: Nested `{"input_fields": {...}, "output_fields": {...}}` → user/assistant pairs. + ```python + history = dspy.History.from_kv([ + {"input_fields": {"thought": "...", "tool_name": "search"}, "output_fields": {"observation": "..."}}, + ]) + ``` + + 3. **Dict mode** (default): Arbitrary serializable key-value pairs → all in single user message. + ```python + history = dspy.History(messages=[ + {"thought": "I need to search", "tool_name": "search", "observation": "Results found"}, + ]) + ``` + + 4. **Raw mode**: Direct LM messages with `{"role": "user", "content": "..."}` → passed through. + ```python + history = dspy.History.from_raw([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ]) + ``` + + The mode is auto-detected from the first message if not explicitly provided. Example: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -41,7 +63,7 @@ class MySignature(dspy.Signature): ``` Example of capturing the conversation history: - ``` + ```python import dspy dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) @@ -59,6 +81,7 @@ class MySignature(dspy.Signature): """ messages: list[dict[str, Any]] + mode: Literal["signature", "kv", "dict", "raw"] | None = None model_config = pydantic.ConfigDict( frozen=True, @@ -66,3 +89,86 @@ class MySignature(dspy.Signature): validate_assignment=True, extra="forbid", ) + + def _detect_mode(self, msg: dict) -> str: + """Detect the mode for a message based on its structure. + + Detection rules: + - Raw: has "role" and "content" keys, but NOT "input_fields"/"output_fields" + - KV: keys are ONLY "input_fields" and/or "output_fields" + - Signature: must be explicitly set (requires matching against signature fields) + - Dict: everything else (default) - arbitrary kv pairs go into user message + """ + if self.mode: + return self.mode + + keys = set(msg.keys()) + + if {"role", "content"} <= keys and not ({"input_fields", "output_fields"} & keys): + return "raw" + + if keys <= {"input_fields", "output_fields"} and keys: + return "kv" + + return "dict" + + @pydantic.model_validator(mode="after") + def _validate_messages(self) -> "History": + for msg in self.messages: + detected = self._detect_mode(msg) + + if detected == "raw": + if not isinstance(msg.get("role"), str): + raise ValueError(f"'role' must be a string: {msg}") + # content can be None for tool call messages, or string otherwise + content = msg.get("content") + if content is not None and not isinstance(content, str): + raise ValueError(f"'content' must be a string or None: {msg}") + + elif detected == "kv": + if "input_fields" in msg and not isinstance(msg["input_fields"], dict): + raise ValueError(f"'input_fields' must be a dict: {msg}") + if "output_fields" in msg and not isinstance(msg["output_fields"], dict): + raise ValueError(f"'output_fields' must be a dict: {msg}") + + return self + + def with_messages(self, messages: list[dict[str, Any]]) -> "History": + """Return a new History with additional messages appended. + + Args: + messages: List of messages to append. + + Returns: + A new History instance with the messages appended. + """ + return History(messages=[*self.messages, *messages], mode=self.mode) + + @classmethod + def from_kv(cls, messages: list[dict[str, Any]]) -> "History": + """Create a History instance with KV mode. + + KV mode expects messages with "input_fields" and/or "output_fields" keys, + each containing a dict of field names to values. + + Args: + messages: List of dicts with "input_fields" and/or "output_fields" keys. + + Returns: + A History instance with mode="kv". + """ + return cls(messages=messages, mode="kv") + + @classmethod + def from_raw(cls, messages: list[dict[str, Any]]) -> "History": + """Create a History instance with raw mode. + + Raw mode expects direct LM messages with "role" and "content" keys. + + Args: + messages: List of dicts with "role" and "content" keys. + + Returns: + A History instance with mode="raw". + """ + return cls(messages=messages, mode="raw") diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 5f87879f80..b7e1d076cd 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -1,9 +1,12 @@ +import json import logging +import uuid from typing import TYPE_CHECKING, Any, Callable, Literal from litellm import ContextWindowExceededError import dspy +from dspy.adapters.types.history import History from dspy.adapters.types.tool import Tool from dspy.primitives.module import Module from dspy.signatures.signature import ensure_signature @@ -73,115 +76,231 @@ def get_weather(city: str) -> str: react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) - .append("trajectory", dspy.InputField(), type_=str) + .append("trajectory", dspy.InputField(), type_=History) .append("next_thought", dspy.OutputField(), type_=str) .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) ) + extract_instructions = ( + "You are an extraction Agent whose job it is to extract the fields: {outputs} from the given trajectory." + + "The original task was:\n" + + signature.instructions + + "\nIn trying to solve this task, an executor agent with has used tools to generate the conversation below." + + "\nGiven this trajectory, your only job is to extract the fields: {outputs}." + ) fallback_signature = dspy.Signature( {**signature.input_fields, **signature.output_fields}, - signature.instructions, - ).append("trajectory", dspy.InputField(), type_=str) + extract_instructions, + ).append("trajectory", dspy.InputField(desc="The history of the conversation. There is enough context to produce the final output"), type_=History) self.tools = tools self.react = dspy.Predict(react_signature) self.extract = dspy.ChainOfThought(fallback_signature) - def _format_trajectory(self, trajectory: dict[str, Any]): - adapter = dspy.settings.adapter or dspy.ChatAdapter() - trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") - return adapter.format_user_message_content(trajectory_signature, trajectory) - def forward(self, **input_args): - trajectory = {} max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): + + # Check for existing history in input_args, otherwise start empty + trajectory = input_args.pop("trajectory", None) + if trajectory is None: + trajectory = History(messages=[], mode="raw") + + for _ in range(max_iters): try: - pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) + pred, trajectory = self._call_with_potential_truncation(self.react, trajectory, **input_args) except ValueError as err: logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args + # Add the agent's action to trajectory + trajectory, tool_call_id = self._append_action( + trajectory, + thought=pred.next_thought, + tool_name=pred.next_tool_name, + tool_args=pred.next_tool_args, + ) + # Execute tool and get observation try: - trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) + observation = self.tools[pred.next_tool_name](**pred.next_tool_args) except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + observation = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + + # Add observation to trajectory + trajectory = self._append_observation(trajectory, observation, tool_call_id) if pred.next_tool_name == "finish": break - extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) + extract, trajectory = self._call_with_potential_truncation(self.extract, trajectory, **input_args) + + # Add the extract step to the trajectory + trajectory = self._append_extract(trajectory, extract) + return dspy.Prediction(trajectory=trajectory, **extract) async def aforward(self, **input_args): - trajectory = {} max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): + + # Check for existing history in input_args, otherwise start empty + trajectory = input_args.pop("trajectory", None) + if trajectory is None: + trajectory = History(messages=[], mode="raw") + + for _ in range(max_iters): try: - pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args) + pred, trajectory = await self._async_call_with_potential_truncation(self.react, trajectory, **input_args) except ValueError as err: logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}") break - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args + # Add the agent's action to trajectory + trajectory, tool_call_id = self._append_action( + trajectory, + thought=pred.next_thought, + tool_name=pred.next_tool_name, + tool_args=pred.next_tool_args, + ) + # Execute tool and get observation try: - trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) + observation = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + observation = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + + # Add observation to trajectory + trajectory = self._append_observation(trajectory, observation, tool_call_id) if pred.next_tool_name == "finish": break - extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) + extract, trajectory = await self._async_call_with_potential_truncation(self.extract, trajectory, **input_args) + + # Add the extract step to the trajectory + trajectory = self._append_extract(trajectory, extract) + return dspy.Prediction(trajectory=trajectory, **extract) - def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _append_action(self, trajectory: History, thought: str, tool_name: str, tool_args: dict) -> tuple[History, str]: + """Append an action (thought + tool call) to the trajectory. + + Returns: + Tuple of (updated trajectory, tool_call_id for matching with observation) + """ + tool_call_id = self._generate_tool_call_id() + new_msg = { + "role": "assistant", + "content": thought, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": json.dumps(tool_args), + }, + } + ], + } + return trajectory.with_messages([new_msg]), tool_call_id + + def _append_observation(self, trajectory: History, observation: Any, tool_call_id: str) -> History: + """Append a tool response to the trajectory.""" + if isinstance(observation, str): + content = observation + else: + try: + content = json.dumps(observation) + except (TypeError, ValueError): + content = str(observation) + + new_msg = { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + return trajectory.with_messages([new_msg]) + + def _append_extract(self, trajectory: History, extract) -> History: + """Append the extract step (final reasoning and outputs) to the trajectory.""" + extract_dict = dict(extract) + reasoning = extract_dict.pop("reasoning", None) + + content_parts = [] + if reasoning: + content_parts.append(f"Reasoning: {reasoning}") + for key, value in extract_dict.items(): + if isinstance(value, str): + content_parts.append(f"{key}: {value}") + else: + try: + content_parts.append(f"{key}: {json.dumps(value)}") + except (TypeError, ValueError): + content_parts.append(f"{key}: {value}") + + new_msg = { + "role": "assistant", + "content": "\n".join(content_parts), + } + return trajectory.with_messages([new_msg]) + + def _call_with_potential_truncation(self, module, trajectory: History, **input_args) -> tuple[Any, History]: + """Call module with trajectory, truncating if context window exceeded. + + Returns: + Tuple of (module result, potentially truncated trajectory) + """ for _ in range(3): try: - return module( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) + return module(**input_args, trajectory=trajectory), trajectory except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) + return None, trajectory - async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): + async def _async_call_with_potential_truncation(self, module, trajectory: History, **input_args) -> tuple[Any, History]: + """Call module with trajectory, truncating if context window exceeded. + + Returns: + Tuple of (module result, potentially truncated trajectory) + """ for _ in range(3): try: - return await module.acall( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) + return await module.acall(**input_args, trajectory=trajectory), trajectory except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) + return None, trajectory - def truncate_trajectory(self, trajectory): + def truncate_trajectory(self, trajectory: History) -> History: """Truncates the trajectory so that it fits in the context window. Users can override this method to implement their own truncation logic. + For tool call format, we remove pairs of messages (assistant + tool) together. """ - keys = list(trajectory.keys()) - if len(keys) < 4: - # Every tool call has 4 keys: thought, tool_name, tool_args, and observation. + if len(trajectory.messages) < 2: raise ValueError( "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be " "truncated because it only has one tool call." ) - for key in keys[:4]: - trajectory.pop(key) - - return trajectory + # Remove the oldest pair (assistant message with tool_calls + tool response) + messages = list(trajectory.messages) + if messages and messages[0].get("role") == "assistant" and messages[0].get("tool_calls"): + # Remove assistant + following tool message(s) + messages = messages[1:] + while messages and messages[0].get("role") == "tool": + messages = messages[1:] + else: + # Fallback: just remove the first message + messages = messages[1:] + + return History(messages=messages, mode="raw") def _fmt_exc(err: BaseException, *, limit: int = 5) -> str: diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py index 07934157fd..65a32fab25 100644 --- a/dspy/utils/inspect_history.py +++ b/dspy/utils/inspect_history.py @@ -10,6 +10,14 @@ def _blue(text: str, end: str = "\n"): return "\x1b[34m" + str(text) + "\x1b[0m" + end +def _yellow(text: str, end: str = "\n"): + return "\x1b[33m" + str(text) + "\x1b[0m" + end + + +def _cyan(text: str, end: str = "\n"): + return "\x1b[36m" + str(text) + "\x1b[0m" + end + + def pretty_print_history(history, n: int = 1): """Prints the last n prompts and their completions.""" @@ -22,37 +30,67 @@ def pretty_print_history(history, n: int = 1): print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") for msg in messages: - print(_red(f"{msg['role'].capitalize()} message:")) - if isinstance(msg["content"], str): - print(msg["content"].strip()) - else: - if isinstance(msg["content"], list): - for c in msg["content"]: - if c["type"] == "text": - print(c["text"].strip()) - elif c["type"] == "image_url": - image_str = "" - if "base64" in c["image_url"].get("url", ""): - len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - image_str = ( - f"<{c['image_url']['url'].split('base64,')[0]}base64," - f"" - ) - else: - image_str = f"" - print(_blue(image_str.strip())) - elif c["type"] == "input_audio": - audio_format = c["input_audio"]["format"] - len_audio = len(c["input_audio"]["data"]) - audio_str = f"