Skip to content

Commit b24b448

Browse files
committed
Make mode explicit and consistent between messages
1 parent 0294b13 commit b24b448

File tree

6 files changed

+342
-377
lines changed

6 files changed

+342
-377
lines changed

dspy/adapters/base.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
453453

454454
return messages
455455

456-
def _get_history_field_name(self, signature: type[Signature]) -> bool:
456+
def _get_history_field_name(self, signature: type[Signature]) -> str | None:
457457
for name, field in signature.input_fields.items():
458458
if field.annotation == History:
459459
return name
460460
return None
461461

462-
def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
462+
def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None:
463463
for name, field in signature.input_fields.items():
464464
# Look for annotation `list[dspy.Tool]` or `dspy.Tool`
465465
origin = get_origin(field.annotation)
@@ -469,7 +469,7 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
469469
return name
470470
return None
471471

472-
def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
472+
def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None:
473473
for name, field in signature.output_fields.items():
474474
if field.annotation == ToolCalls:
475475
return name
@@ -498,72 +498,75 @@ def format_conversation_history(
498498
history_field_name: str,
499499
inputs: dict[str, Any],
500500
) -> list[dict[str, Any]]:
501-
"""Format the conversation history.
501+
"""Format the conversation history as multiturn messages.
502502
503-
This method formats the conversation history and the current input as multiturn messages.
504503
Supports four modes:
505-
- signature: Dict keys match signature input/output fields → user/assistant pairs
506-
- kv: Nested {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs
507-
- dict: Arbitrary serializable kv pairs → all in single user message (default)
508-
- raw: Direct LM messages with {"role": "user", "content": "..."} → passed through
509-
510-
Args:
511-
signature: The DSPy signature for which to format the conversation history.
512-
history_field_name: The name of the history field in the signature.
513-
inputs: The input arguments to the DSPy module.
514-
515-
Returns:
516-
A list of multiturn messages.
504+
- raw: Direct LM messages → passed through as-is
505+
- demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs
506+
- flat: Arbitrary kv pairs → single user message per dict (default)
507+
- signature: Dict keys match signature fields → user/assistant pairs
517508
"""
518509
history = inputs.get(history_field_name)
519510
if history is None:
520511
return []
521512

522-
messages = []
523-
for msg in history.messages:
524-
mode = history._detect_mode(msg)
525-
526-
if mode == "raw":
527-
messages.append(dict(msg))
528-
529-
elif mode == "kv":
530-
if "input_fields" in msg:
531-
input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()}
532-
sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys()))
533-
messages.append({
534-
"role": "user",
535-
"content": self.format_user_message_content(sig, input_dict),
536-
})
537-
if "output_fields" in msg:
538-
output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()}
539-
sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys()))
540-
messages.append({
541-
"role": "assistant",
542-
"content": self.format_assistant_message_content(sig, output_dict),
543-
})
544-
545-
elif mode == "signature":
546-
messages.append({
513+
del inputs[history_field_name]
514+
515+
if history.mode == "raw":
516+
return [dict(msg) for msg in history.messages]
517+
if history.mode == "demo":
518+
return self._format_demo_history(history.messages)
519+
if history.mode == "signature":
520+
return self._format_signature_history(signature, history.messages)
521+
return self._format_flat_history(history.messages)
522+
523+
def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
524+
"""Format demo-mode history (input_fields/output_fields → user/assistant)."""
525+
result = []
526+
for msg in messages:
527+
if "input_fields" in msg:
528+
input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()}
529+
sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys()))
530+
result.append({
547531
"role": "user",
548-
"content": self.format_user_message_content(signature, msg),
532+
"content": self.format_user_message_content(sig, input_dict),
549533
})
550-
messages.append({
534+
if "output_fields" in msg:
535+
output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()}
536+
sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys()))
537+
result.append({
551538
"role": "assistant",
552-
"content": self.format_assistant_message_content(signature, msg),
553-
})
554-
555-
else: # dict mode (default) - all kv pairs go into single user message
556-
serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()}
557-
sig = self._make_dynamic_signature_for_inputs(list(serialized.keys()))
558-
messages.append({
559-
"role": "user",
560-
"content": self.format_user_message_content(sig, serialized),
539+
"content": self.format_assistant_message_content(sig, output_dict),
561540
})
541+
return result
562542

563-
# Remove the history field from the inputs
564-
del inputs[history_field_name]
565-
566-
return messages
543+
def _format_signature_history(
544+
self, signature: type[Signature], messages: list[dict[str, Any]]
545+
) -> list[dict[str, Any]]:
546+
"""Format signature-mode history (signature fields → user/assistant pairs)."""
547+
result = []
548+
for msg in messages:
549+
result.append({
550+
"role": "user",
551+
"content": self.format_user_message_content(signature, msg),
552+
})
553+
result.append({
554+
"role": "assistant",
555+
"content": self.format_assistant_message_content(signature, msg),
556+
})
557+
return result
558+
559+
def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
560+
"""Format flat-mode history (all kv pairs in single user message)."""
561+
result = []
562+
for msg in messages:
563+
serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()}
564+
sig = self._make_dynamic_signature_for_inputs(list(serialized.keys()))
565+
result.append({
566+
"role": "user",
567+
"content": self.format_user_message_content(sig, serialized),
568+
})
569+
return result
567570

568571
def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
569572
"""Parse the LM output into a dictionary of the output fields.

dspy/adapters/types/history.py

Lines changed: 83 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,42 @@
44

55

66
class History(pydantic.BaseModel):
7-
"""Class representing the conversation history.
7+
"""Class representing conversation history.
88
9-
History supports four message formats:
10-
11-
1. **Signature mode**: Dict keys match signature input/output fields → user/assistant pairs.
12-
Must be explicitly set via mode="signature".
9+
History supports four message formats, with one mode per History instance:
10+
11+
1. **Raw mode**: Direct LM messages with `{"role": "...", "content": "..."}`.
12+
Used for ReAct trajectories and native tool calling.
1313
```python
14-
history = dspy.History(messages=[
15-
{"question": "What is 2+2?", "answer": "4"},
16-
], mode="signature")
14+
history = dspy.History.from_raw([
15+
{"role": "user", "content": "Hello"},
16+
{"role": "assistant", "content": "Hi there!"},
17+
])
1718
```
18-
19-
2. **KV mode**: Nested `{"input_fields": {...}, "output_fields": {...}}` → user/assistant pairs.
19+
20+
2. **Demo mode**: Nested `{"input_fields": {...}, "output_fields": {...}}` pairs.
21+
Used for few-shot demonstrations with explicit input/output separation.
2022
```python
21-
history = dspy.History.from_kv([
22-
{"input_fields": {"thought": "...", "tool_name": "search"}, "output_fields": {"observation": "..."}},
23+
history = dspy.History.from_demo([
24+
{"input_fields": {"question": "2+2?"}, "output_fields": {"answer": "4"}},
2325
])
2426
```
25-
26-
3. **Dict mode** (default): Arbitrary serializable key-value pairs → all in single user message.
27+
28+
3. **Flat mode** (default): Arbitrary key-value pairs in a single user message.
2729
```python
2830
history = dspy.History(messages=[
29-
{"thought": "I need to search", "tool_name": "search", "observation": "Results found"},
31+
{"thought": "I need to search", "tool_name": "search", "observation": "Found it"},
3032
])
3133
```
32-
33-
4. **Raw mode**: Direct LM messages with `{"role": "user", "content": "..."}` → passed through.
34+
35+
4. **Signature mode**: Dict keys match signature fields → user/assistant pairs.
36+
Must be explicitly set via `mode="signature"` or `History.from_signature(...)`; it is never inferred from the messages.
3437
```python
35-
history = dspy.History.from_raw([
36-
{"role": "user", "content": "Hello"},
37-
{"role": "assistant", "content": "Hi there!"},
38+
history = dspy.History.from_signature([
39+
{"question": "What is 2+2?", "answer": "4"},
3840
])
3941
```
4042
41-
The mode is auto-detected from the first message if not explicitly provided.
42-
4343
Example:
4444
```python
4545
import dspy
@@ -51,12 +51,9 @@ class MySignature(dspy.Signature):
5151
history: dspy.History = dspy.InputField()
5252
answer: str = dspy.OutputField()
5353
54-
history = dspy.History(
55-
messages=[
56-
{"question": "What is the capital of France?", "answer": "Paris"},
57-
{"question": "What is the capital of Germany?", "answer": "Berlin"},
58-
]
59-
)
54+
history = dspy.History.from_signature([
55+
{"question": "What is the capital of France?", "answer": "Paris"},
56+
])
6057
6158
predict = dspy.Predict(MySignature)
6259
outputs = predict(question="What is the capital of France?", history=history)
@@ -81,7 +78,7 @@ class MySignature(dspy.Signature):
8178
"""
8279

8380
messages: list[dict[str, Any]]
84-
mode: Literal["signature", "kv", "dict", "raw"] | None = None
81+
mode: Literal["signature", "demo", "flat", "raw"] = "flat"
8582

8683
model_config = pydantic.ConfigDict(
8784
frozen=True,
@@ -90,85 +87,85 @@ class MySignature(dspy.Signature):
9087
extra="forbid",
9188
)
9289

93-
def _detect_mode(self, msg: dict) -> str:
94-
"""Detect the mode for a message based on its structure.
95-
96-
Detection rules:
97-
- Raw: has "role" and "content" keys, but NOT "input_fields"/"output_fields"
98-
- KV: keys are ONLY "input_fields" and/or "output_fields"
99-
- Signature: must be explicitly set (requires matching against signature fields)
100-
- Dict: everything else (default) - arbitrary kv pairs go into user message
101-
"""
102-
if self.mode:
103-
return self.mode
90+
@staticmethod
91+
def _infer_mode_from_msg(msg: dict) -> str:
92+
"""Infer the mode from a message's structure.
10493
94+
Detection rules (conservative):
95+
- Raw: has "role" key and ONLY LM-like keys (role, content, tool_calls, tool_call_id, name)
96+
- Demo: keys are ONLY "input_fields" and/or "output_fields"
97+
- Flat: everything else (signature mode must be explicit)
98+
"""
10599
keys = set(msg.keys())
100+
lm_keys = {"role", "content", "tool_calls", "tool_call_id", "name"}
106101

107-
if {"role", "content"} <= keys and not ({"input_fields", "output_fields"} & keys):
102+
if "role" in keys and keys <= lm_keys:
108103
return "raw"
109104

110105
if keys <= {"input_fields", "output_fields"} and keys:
111-
return "kv"
106+
return "demo"
107+
108+
return "flat"
109+
110+
def _validate_msg_for_mode(self, msg: dict, mode: str) -> None:
111+
"""Validate a message conforms to the expected mode structure."""
112+
if mode == "raw":
113+
if not isinstance(msg.get("role"), str):
114+
raise ValueError(f"Raw mode: 'role' must be a string: {msg}")
115+
content = msg.get("content")
116+
if content is not None and not isinstance(content, str):
117+
raise ValueError(f"Raw mode: 'content' must be a string or None: {msg}")
112118

113-
return "dict"
119+
elif mode == "demo":
120+
if "input_fields" in msg and not isinstance(msg["input_fields"], dict):
121+
raise ValueError(f"Demo mode: 'input_fields' must be a dict: {msg}")
122+
if "output_fields" in msg and not isinstance(msg["output_fields"], dict):
123+
raise ValueError(f"Demo mode: 'output_fields' must be a dict: {msg}")
124+
125+
elif mode == "signature":
126+
if not isinstance(msg, dict) or not msg:
127+
raise ValueError(f"Signature mode: messages must be non-empty dicts: {msg}")
114128

115129
@pydantic.model_validator(mode="after")
116130
def _validate_messages(self) -> "History":
131+
if not self.messages:
132+
return self
133+
134+
# Only infer when mode is the default "flat": inspect the first message only, and switch mode only if it clearly looks like "raw" or "demo"
135+
if self.mode == "flat":
136+
inferred = self._infer_mode_from_msg(self.messages[0])
137+
if inferred in {"raw", "demo"}:
138+
object.__setattr__(self, "mode", inferred)
139+
117140
for msg in self.messages:
118-
detected = self._detect_mode(msg)
119-
120-
if detected == "raw":
121-
if not isinstance(msg.get("role"), str):
122-
raise ValueError(f"'role' must be a string: {msg}")
123-
# content can be None for tool call messages, or string otherwise
124-
content = msg.get("content")
125-
if content is not None and not isinstance(content, str):
126-
raise ValueError(f"'content' must be a string or None: {msg}")
127-
128-
elif detected == "kv":
129-
if "input_fields" in msg and not isinstance(msg["input_fields"], dict):
130-
raise ValueError(f"'input_fields' must be a dict: {msg}")
131-
if "output_fields" in msg and not isinstance(msg["output_fields"], dict):
132-
raise ValueError(f"'output_fields' must be a dict: {msg}")
141+
self._validate_msg_for_mode(msg, self.mode)
133142

134143
return self
135144

136145
def with_messages(self, messages: list[dict[str, Any]]) -> "History":
137-
"""Return a new History with additional messages appended.
138-
139-
Args:
140-
messages: List of messages to append.
141-
142-
Returns:
143-
A new History instance with the messages appended.
144-
"""
146+
"""Return a new History with additional messages appended."""
145147
return History(messages=[*self.messages, *messages], mode=self.mode)
146148

147149
@classmethod
148-
def from_kv(cls, messages: list[dict[str, Any]]) -> "History":
149-
"""Create a History instance with KV mode.
150-
151-
KV mode expects messages with "input_fields" and/or "output_fields" keys,
152-
each containing a dict of field names to values.
153-
154-
Args:
155-
messages: List of dicts with "input_fields" and/or "output_fields" keys.
156-
157-
Returns:
158-
A History instance with mode="kv".
150+
def from_demo(cls, messages: list[dict[str, Any]]) -> "History":
151+
"""Create a History with demo mode.
152+
153+
Demo mode expects messages with "input_fields" and/or "output_fields" keys.
159154
"""
160-
return cls(messages=messages, mode="kv")
155+
return cls(messages=messages, mode="demo")
161156

162157
@classmethod
163158
def from_raw(cls, messages: list[dict[str, Any]]) -> "History":
164-
"""Create a History instance with raw mode.
165-
159+
"""Create a History with raw mode.
160+
166161
Raw mode expects direct LM messages with a string "role"; "content" must be a string or None (e.g. for tool-call messages).
167-
168-
Args:
169-
messages: List of dicts with "role" and "content" keys.
170-
171-
Returns:
172-
A History instance with mode="raw".
173162
"""
174163
return cls(messages=messages, mode="raw")
164+
165+
@classmethod
166+
def from_signature(cls, messages: list[dict[str, Any]]) -> "History":
167+
"""Create a History with signature mode.
168+
169+
Signature mode expects dicts with keys matching the signature's fields.
170+
"""
171+
return cls(messages=messages, mode="signature")

0 commit comments

Comments
 (0)