Skip to content

Commit 393e4ba

Browse files
committed
Make mode explicit and consistent between messages
1 parent f87ceff commit 393e4ba

File tree

6 files changed

+345
-399
lines changed

6 files changed

+345
-399
lines changed

dspy/adapters/base.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -449,13 +449,13 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
449449

450450
return messages
451451

452-
def _get_history_field_name(self, signature: type[Signature]) -> bool:
452+
def _get_history_field_name(self, signature: type[Signature]) -> str | None:
453453
for name, field in signature.input_fields.items():
454454
if field.annotation == History:
455455
return name
456456
return None
457457

458-
def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
458+
def _get_tool_call_input_field_name(self, signature: type[Signature]) -> str | None:
459459
for name, field in signature.input_fields.items():
460460
# Look for annotation `list[dspy.Tool]` or `dspy.Tool`
461461
origin = get_origin(field.annotation)
@@ -465,7 +465,7 @@ def _get_tool_call_input_field_name(self, signature: type[Signature]) -> bool:
465465
return name
466466
return None
467467

468-
def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
468+
def _get_tool_call_output_field_name(self, signature: type[Signature]) -> str | None:
469469
for name, field in signature.output_fields.items():
470470
if field.annotation == ToolCalls:
471471
return name
@@ -494,72 +494,75 @@ def format_conversation_history(
494494
history_field_name: str,
495495
inputs: dict[str, Any],
496496
) -> list[dict[str, Any]]:
497-
"""Format the conversation history.
497+
"""Format the conversation history as multiturn messages.
498498
499-
This method formats the conversation history and the current input as multiturn messages.
500499
Supports four modes:
501-
- signature: Dict keys match signature input/output fields → user/assistant pairs
502-
- kv: Nested {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs
503-
- dict: Arbitrary serializable kv pairs → all in single user message (default)
504-
- raw: Direct LM messages with {"role": "user", "content": "..."} → passed through
505-
506-
Args:
507-
signature: The DSPy signature for which to format the conversation history.
508-
history_field_name: The name of the history field in the signature.
509-
inputs: The input arguments to the DSPy module.
510-
511-
Returns:
512-
A list of multiturn messages.
500+
- raw: Direct LM messages → passed through as-is
501+
- demo: {"input_fields": {...}, "output_fields": {...}} → user/assistant pairs
502+
- flat: Arbitrary kv pairs → single user message per dict (default)
503+
- signature: Dict keys match signature fields → user/assistant pairs
513504
"""
514505
history = inputs.get(history_field_name)
515506
if history is None:
516507
return []
517508

518-
messages = []
519-
for msg in history.messages:
520-
mode = history._detect_mode(msg)
521-
522-
if mode == "raw":
523-
messages.append(dict(msg))
524-
525-
elif mode == "kv":
526-
if "input_fields" in msg:
527-
input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()}
528-
sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys()))
529-
messages.append({
530-
"role": "user",
531-
"content": self.format_user_message_content(sig, input_dict),
532-
})
533-
if "output_fields" in msg:
534-
output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()}
535-
sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys()))
536-
messages.append({
537-
"role": "assistant",
538-
"content": self.format_assistant_message_content(sig, output_dict),
539-
})
540-
541-
elif mode == "signature":
542-
messages.append({
509+
del inputs[history_field_name]
510+
511+
if history.mode == "raw":
512+
return [dict(msg) for msg in history.messages]
513+
if history.mode == "demo":
514+
return self._format_demo_history(history.messages)
515+
if history.mode == "signature":
516+
return self._format_signature_history(signature, history.messages)
517+
return self._format_flat_history(history.messages)
518+
519+
def _format_demo_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
520+
"""Format demo-mode history (input_fields/output_fields → user/assistant)."""
521+
result = []
522+
for msg in messages:
523+
if "input_fields" in msg:
524+
input_dict = {k: self._serialize_kv_value(v) for k, v in msg["input_fields"].items()}
525+
sig = self._make_dynamic_signature_for_inputs(list(input_dict.keys()))
526+
result.append({
543527
"role": "user",
544-
"content": self.format_user_message_content(signature, msg),
528+
"content": self.format_user_message_content(sig, input_dict),
545529
})
546-
messages.append({
530+
if "output_fields" in msg:
531+
output_dict = {k: self._serialize_kv_value(v) for k, v in msg["output_fields"].items()}
532+
sig = self._make_dynamic_signature_for_outputs(list(output_dict.keys()))
533+
result.append({
547534
"role": "assistant",
548-
"content": self.format_assistant_message_content(signature, msg),
549-
})
550-
551-
else: # dict mode (default) - all kv pairs go into single user message
552-
serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()}
553-
sig = self._make_dynamic_signature_for_inputs(list(serialized.keys()))
554-
messages.append({
555-
"role": "user",
556-
"content": self.format_user_message_content(sig, serialized),
535+
"content": self.format_assistant_message_content(sig, output_dict),
557536
})
537+
return result
558538

559-
# Remove the history field from the inputs
560-
del inputs[history_field_name]
561-
562-
return messages
539+
def _format_signature_history(
540+
self, signature: type[Signature], messages: list[dict[str, Any]]
541+
) -> list[dict[str, Any]]:
542+
"""Format signature-mode history (signature fields → user/assistant pairs)."""
543+
result = []
544+
for msg in messages:
545+
result.append({
546+
"role": "user",
547+
"content": self.format_user_message_content(signature, msg),
548+
})
549+
result.append({
550+
"role": "assistant",
551+
"content": self.format_assistant_message_content(signature, msg),
552+
})
553+
return result
554+
555+
def _format_flat_history(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
556+
"""Format flat-mode history (all kv pairs in single user message)."""
557+
result = []
558+
for msg in messages:
559+
serialized = {k: self._serialize_kv_value(v) for k, v in msg.items()}
560+
sig = self._make_dynamic_signature_for_inputs(list(serialized.keys()))
561+
result.append({
562+
"role": "user",
563+
"content": self.format_user_message_content(sig, serialized),
564+
})
565+
return result
563566

564567
def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
565568
"""Parse the LM output into a dictionary of the output fields.

0 commit comments

Comments
 (0)