Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 100 additions & 31 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from typing import TYPE_CHECKING, Any, get_origin

Expand All @@ -8,7 +9,9 @@
from dspy.adapters.types.base_type import split_message_content_for_custom_types
from dspy.adapters.types.reasoning import Reasoning
from dspy.adapters.types.tool import Tool, ToolCalls
from dspy.adapters.utils import serialize_for_json
from dspy.experimental import Citations
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback, with_callbacks

Expand Down Expand Up @@ -452,13 +455,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])

return messages

def _get_history_field_name(self, signature: type[Signature]) -> bool:
def _get_history_field_name(self, signature: type[Signature]) -> str | None:
for name, field in signature.input_fields.items():
if field.annotation == History:
return name
return None

def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None:
for name, field in signature.input_fields.items():
# Look for annotation `list[dspy.Tool]` or `dspy.Tool`
origin = get_origin(field.annotation)
Expand All @@ -468,54 +471,120 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
return name
return None

def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None:
for name, field in signature.output_fields.items():
if field.annotation == ToolCalls:
return name
return None

def _serialize_kv_value(self, v: Any) -> str:
"""Serialize a value to string for flat-mode history formatting.

Uses the same pattern as format_field_value in adapters/utils.py.
"""
jsonable = serialize_for_json(v)
if isinstance(jsonable, (dict, list)):
return json.dumps(jsonable, ensure_ascii=False)
return str(jsonable)

def _make_dynamic_signature_for_inputs(self, keys: list[str]) -> type[Signature]:
"""Create a dynamic signature with input fields only (no instructions)."""
return Signature({k: InputField() for k in keys}, instructions="")

def _make_dynamic_signature_for_outputs(self, keys: list[str]) -> type[Signature]:
"""Create a dynamic signature with output fields only (no instructions)."""
return Signature({k: OutputField() for k in keys}, instructions="")

def format_conversation_history(
self,
signature: type[Signature],
history_field_name: str,
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
"""Format the conversation history.
"""Format the conversation history as multiturn messages.

This method formats the conversation history and the current input as multiturn messages.
Supports four modes:
- raw: Direct LM messages → passed through as-is
- demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs
- flat: Arbitrary kv pairs → single user message per dict (default)
- signature: Dict keys match signature fields → user/assistant pairs

Args:
signature: The DSPy signature for which to format the conversation history.
history_field_name: The name of the history field in the signature.
inputs: The input arguments to the DSPy module.

Returns:
A list of multiturn messages.
For backward compatibility, flat-mode histories whose message keys are subsets of the
signature fields (and overlap output fields) are treated as signature-mode.
"""
conversation_history = inputs[history_field_name].messages if history_field_name in inputs else None

if conversation_history is None:
history = inputs.get(history_field_name)
if history is None:
return []

messages = []
for message in conversation_history:
messages.append(
{
del inputs[history_field_name]

if history.mode == "raw":
return [dict(msg) for msg in history.messages]
if history.mode == "demo":
return self._format_demo_history(history.messages)
if history.mode == "signature":
return self._format_signature_history(signature, history.messages)

# Backward-compat shim: treat flat-mode as signature-mode if messages look like
# signature-style conversation history (keys subset of signature fields, overlapping outputs)
if history.mode == "flat" and history.messages:
sig_keys = set(signature.fields.keys())
output_keys = set(signature.output_fields.keys())
msg_key_sets = [set(m.keys()) for m in history.messages]

if all(ks <= sig_keys for ks in msg_key_sets):
if any(ks & output_keys for ks in msg_key_sets):
return self._format_signature_history(signature, history.messages)

return self._format_flat_history(history.messages)

def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format demo-mode history (input_fields/output_fields → user/assistant)."""
result = []
for msg in messages:
if "input_fields" in msg:
input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()}
sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys()))
result.append({
"role": "user",
"content": self.format_user_message_content(signature, message),
}
)
messages.append(
{
"content": self.format_user_message_content(sig, input_dict),
})
if "output_fields" in msg:
output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()}
sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys()))
result.append({
"role": "assistant",
"content": self.format_assistant_message_content(signature, message),
}
)

# Remove the history field from the inputs
del inputs[history_field_name]
"content": self.format_assistant_message_content(sig, output_dict),
})
return result

return messages
def _format_signature_history(
self, signature: type[Signature], messages: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Format signature-mode history (signature fields → user/assistant pairs)."""
result = []
for msg in messages:
result.append({
"role": "user",
"content": self.format_user_message_content(signature, msg),
})
result.append({
"role": "assistant",
"content": self.format_assistant_message_content(signature, msg),
})
return result

def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format flat-mode history (all kv pairs in single user message)."""
result = []
for msg in messages:
serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()}
sig = self._make_dynamic_signature_for_inputs(list(serialized.keys()))
result.append({
"role": "user",
"content": self.format_user_message_content(sig, serialized),
})
return result

def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
"""Parse the LM output into a dictionary of the output fields.
Expand Down
Loading