Skip to content

Commit a635768

Browse files
committed
refactor(gepa): generalize proposer to support both ReAct and tool modules
- Rename ReActModuleProposer → ToolModuleProposer
- Rename signature to GenerateImprovedToolModuleDescriptionsFromFeedback
- Make base signature generic (current_predictor_instruction)
- Dynamically add extract fields only for ReAct modules
- Use prefix checks (REACT_MODULE_PREFIX) for reliable type detection
- Support both 1-predictor (tool) and 2-predictor (ReAct) modules
- Update routing to handle both TOOL_MODULE_PREFIX and REACT_MODULE_PREFIX
- Clean variable names: primary_predictor_key, extract_predictor_key
- Update all docstrings to reflect tool-using modules (not just ReAct)
1 parent 0a6016d commit a635768

File tree

2 files changed

+83
-88
lines changed

2 files changed

+83
-88
lines changed

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,12 @@ def default_instruction_proposer(
138138

139139
instruction_proposer = default_instruction_proposer
140140

141-
# Init ReAct module proposer if tool optimization is enabled
142-
react_module_proposer = None
141+
# Init tool module proposer if tool optimization is enabled
142+
tool_module_proposer = None
143143
if self.enable_tool_optimization:
144-
from .instruction_proposal import ReActModuleProposer
144+
from .instruction_proposal import ToolModuleProposer
145145

146-
react_module_proposer = ReActModuleProposer()
146+
tool_module_proposer = ToolModuleProposer()
147147

148148
def propose_component_texts(
149149
candidate: dict[str, str],
@@ -160,9 +160,15 @@ def propose_component_texts(
160160
)
161161

162162
# Otherwise, route to appropriate proposers
163-
# Separate react_module components from regular instruction components
164-
react_module_components = [c for c in components_to_update if c.startswith(REACT_MODULE_PREFIX)]
165-
instruction_components = [c for c in components_to_update if not c.startswith(REACT_MODULE_PREFIX)]
163+
# Separate into two categories: components with tools vs regular instructions
164+
tool_module_components = []
165+
instruction_components = []
166+
167+
for c in components_to_update:
168+
if c.startswith(REACT_MODULE_PREFIX) or c.startswith(TOOL_MODULE_PREFIX):
169+
tool_module_components.append(c)
170+
else:
171+
instruction_components.append(c)
166172

167173
results: dict[str, str] = {}
168174

@@ -178,14 +184,14 @@ def propose_component_texts(
178184
)
179185
)
180186

181-
# Handle ReAct module components
182-
if react_module_components:
183-
logger.debug(f"Routing {len(react_module_components)} react_module components to react_module_proposer")
187+
# Handle components with tools (ReAct and Tool modules)
188+
if tool_module_components:
189+
logger.debug(f"Routing {len(tool_module_components)} tool_module components to tool_module_proposer")
184190
results.update(
185-
react_module_proposer(
191+
tool_module_proposer(
186192
candidate=candidate,
187193
reflective_dataset=reflective_dataset,
188-
components_to_update=react_module_components,
194+
components_to_update=tool_module_components,
189195
)
190196
)
191197

dspy/teleprompt/gepa/instruction_proposal.py

Lines changed: 65 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66

77
import dspy
88
from dspy.adapters.types.base_type import Type
9-
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample
9+
from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX, ReflectiveExample
1010

1111
logger = logging.getLogger(__name__)
1212

13-
# Constants for ReAct module optimization
14-
REACT_MODULE_PREFIX = "react_module"
15-
1613

1714
class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature):
1815
"""I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below.
@@ -318,20 +315,17 @@ def __call__(
318315

319316
return updated_components
320317

321-
class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature):
322-
"""Improve a ReAct agent based on execution examples and feedback.
318+
class GenerateImprovedToolModuleDescriptionsFromFeedback(dspy.Signature):
319+
"""Improve a tool-using module based on execution examples and feedback.
323320
324321
These components are progressively optimized - refine what needs improvement.
325-
Analyze the trajectories to identify successful patterns and failure causes.
326-
Generate improved texts to help the agent succeed on similar tasks.
322+
Analyze the examples_with_feedback to identify successful patterns and failure causes.
323+
Generate improved texts to help the module succeed on similar tasks.
327324
Place improved texts at their appropriate level of abstraction and/or specificity.
328325
"""
329326

330-
current_react_instruction = dspy.InputField(
331-
desc="Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection"
332-
)
333-
current_extract_instruction = dspy.InputField(
334-
desc="Current Extract module instruction for extracting final answers from trajectories"
327+
current_predictor_instruction = dspy.InputField(
328+
desc="Current instruction guiding the predictor"
335329
)
336330
current_tools = dspy.InputField(
337331
annotation=list[dspy.Tool],
@@ -341,88 +335,75 @@ class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature):
341335
desc="Execution examples with feedback showing successes and failures"
342336
)
343337

344-
improved_react_instruction: str | None = dspy.OutputField(
345-
desc="ReAct instruction for reasoning and tool selection",
346-
default=None
347-
)
348-
improved_extract_instruction: str | None = dspy.OutputField(
349-
desc="Extract instruction for answer extraction",
338+
improved_predictor_instruction: str | None = dspy.OutputField(
339+
desc="Improved instruction for the predictor",
350340
default=None
351341
)
352342

353343

354344

355345

356346

357-
class ReActModuleProposer(ProposalFn):
358-
"""Proposer for optimizing ReAct module configurations.
347+
class ToolModuleProposer(ProposalFn):
348+
"""Proposer for optimizing tool-using module configurations.
349+
350+
Supports two types of modules:
351+
- Tool modules (1 predictor): Optimizes predictor instruction and tool descriptions
352+
- ReAct modules (2 predictors): Jointly optimizes react instruction, extract instruction, and tool descriptions
359353
360-
Jointly optimizes three components of a ReAct module: the react instruction that guides
361-
reasoning and tool selection, the extract instruction for answer extraction from trajectories,
362-
and tool descriptions with their parameters. Uses dynamic signature generation to create
363-
output fields for each tool and parameter, enabling the reflection LM to optimize all parts
364-
cohesively based on execution feedback.
354+
Uses dynamic signature generation to create output fields for each tool and parameter,
355+
enabling the reflection LM to optimize all components cohesively based on execution feedback.
365356
366357
This joint optimization approach allows the LM to see how instructions and tool descriptions
367358
work together, leading to more coherent improvements than optimizing each component separately.
368359
"""
369360

370-
def __init__(self):
371-
"""Initialize the ReAct module proposer."""
372-
pass
373-
374361
def __call__(
375362
self,
376363
candidate: dict[str, str],
377364
reflective_dataset: dict[str, list[ReflectiveExample]],
378365
components_to_update: list[str],
379366
) -> dict[str, str]:
380-
"""Optimize ReAct module components.
367+
"""Optimize tool-using module components.
381368
382369
Args:
383370
candidate: Current component name -> JSON config mapping
384371
reflective_dataset: Component name -> list of reflective examples
385-
components_to_update: List of react_module component names to update
372+
components_to_update: List of tool-using module component names to update
386373
387374
Returns:
388375
dict: Mapping of component names to improved JSON configs
389376
"""
390377

391-
logger.info("\n=== ReActModuleProposer Called ===")
392-
logger.info(f"components_to_update: {components_to_update}")
393-
logger.info(f"candidate keys: {list(candidate.keys())}")
394-
logger.info(f"reflective_dataset keys: {list(reflective_dataset.keys())}")
395-
396378
updated_components = {}
397379

398380
for module_key in components_to_update:
399-
# Only handle react_module components
400-
if not module_key.startswith(REACT_MODULE_PREFIX):
401-
logger.debug(f"Skipping non-react_module component: {module_key}")
402-
continue
403-
404381
if module_key not in candidate or module_key not in reflective_dataset:
405382
logger.warning(f"Skipping {module_key}: not in candidate={module_key not in candidate}, not in reflective_dataset={module_key not in reflective_dataset}")
406383
continue
407384

408-
logger.info(f"\nProcessing react_module: {module_key}")
409-
410-
# Deserialize react module config
385+
# Deserialize module config
411386
try:
412-
current_react_config = json.loads(candidate[module_key])
413-
logger.debug(f"Deserialized config keys: {list(current_react_config.keys())}")
387+
current_module_config = json.loads(candidate[module_key])
414388
except json.JSONDecodeError as e:
415389
logger.error(f"Failed to deserialize config for {module_key}: {e}")
416390
continue
417391

392+
# Extract predictor keys (all keys except "tools")
393+
# Predictor keys are expected to be 1 for tool modules and 2 for ReAct modules (extra extract predictor)
394+
predictor_keys = [k for k in current_module_config if k != "tools"]
395+
logger.debug(f"Predictor keys: {predictor_keys}")
396+
primary_predictor_key = predictor_keys[0]
397+
extract_predictor_key = predictor_keys[1] if module_key.startswith(REACT_MODULE_PREFIX) else None
398+
418399
# Reconstruct Tool objects from JSON metadata so the adapter can format them for the reflection LM.
419400
# Tool.func cannot be serialized in JSON, so we use a placeholder (never executed).
420-
current_tools_dict = current_react_config.get("tools", {})
401+
current_tools_dict = current_module_config.get("tools", {})
421402
logger.info(f"Found {len(current_tools_dict)} tools: {list(current_tools_dict.keys())}")
422403
tools_list = []
423404
for tool_name, tool_info in current_tools_dict.items():
424405
tool = dspy.Tool(
425-
func=lambda: None, # Placeholder - Tool requires Callable, but only schema is used
406+
func=lambda *args, **kwargs: None, # Placeholder - Tool requires Callable, but only schema is used
426407
name=tool_name,
427408
desc=tool_info.get("desc", ""),
428409
)
@@ -431,7 +412,7 @@ def __call__(
431412
tools_list.append(tool)
432413

433414
# Build dynamic signature by extending base signature
434-
signature = GenerateImprovedReActDescriptionsFromFeedback
415+
signature = GenerateImprovedToolModuleDescriptionsFromFeedback
435416

436417
logger.debug(f"Building dynamic signature with {len(tools_list)} tools...")
437418

@@ -458,41 +439,49 @@ def __call__(
458439
)
459440
)
460441

461-
# Format examples
462-
formatted_examples = self._format_examples(reflective_dataset[module_key])
463-
logger.info(f"Formatted {len(reflective_dataset[module_key])} reflective examples")
464-
logger.debug(f"Examples preview: {formatted_examples[:200]}...")
465442

466-
logger.info("Calling reflection LM with dynamic signature...")
443+
kwargs = {
444+
"current_predictor_instruction": current_module_config[primary_predictor_key],
445+
"current_tools": tools_list,
446+
"examples_with_feedback": self._format_examples(reflective_dataset[module_key]),
447+
}
448+
# If module has extract predictor, add extract fields
449+
if extract_predictor_key is not None:
450+
signature = signature.append(
451+
"current_extract_instruction",
452+
dspy.InputField(desc="Current instruction for extraction predictor")
453+
)
454+
signature = signature.append(
455+
"improved_extract_instruction",
456+
dspy.OutputField(desc="Improved instruction for extraction", default=None)
457+
)
458+
kwargs["current_extract_instruction"] = current_module_config[extract_predictor_key]
459+
467460
propose_descriptions = dspy.Predict(signature)
468-
result = propose_descriptions(
469-
current_react_instruction=current_react_config.get("react", ""),
470-
current_extract_instruction=current_react_config.get("extract", ""),
471-
current_tools=tools_list, # List of Tool objects for adapter formatting
472-
examples_with_feedback=formatted_examples,
473-
)
461+
462+
result = propose_descriptions(**kwargs)
474463

475464
# Build improved config from reflection LM suggestions
476465
# Reflection LM returns None for components it doesn't want to change, or text for improvements
477466
logger.info("Building improved config from reflection LM response...")
478-
improved_react_config = {}
467+
improved_module_config = {}
479468

480-
# Update react instruction if reflection LM suggested improvement
481-
if result.improved_react_instruction is not None:
482-
improved_react_config["react"] = result.improved_react_instruction
483-
logger.debug(f"React instruction: {len(result.improved_react_instruction)} chars")
469+
# Update primary predictor instruction if reflection LM suggested improvement
470+
if result.improved_predictor_instruction is not None:
471+
improved_module_config[primary_predictor_key] = result.improved_predictor_instruction
472+
logger.debug(f"{primary_predictor_key}: {len(result.improved_predictor_instruction)} chars")
484473
else:
485-
logger.debug("React instruction: reflection LM suggests keeping original")
474+
logger.debug(f"{primary_predictor_key}: reflection LM suggests keeping original")
486475

487-
# Update extract instruction if reflection LM suggested improvement
488-
if result.improved_extract_instruction is not None:
489-
improved_react_config["extract"] = result.improved_extract_instruction
490-
logger.debug(f"Extract instruction: {len(result.improved_extract_instruction)} chars")
476+
# Update extract instruction if exists and reflection LM suggested improvement
477+
if extract_predictor_key is not None and result.improved_extract_instruction is not None:
478+
improved_module_config[extract_predictor_key] = result.improved_extract_instruction
479+
logger.debug(f"{extract_predictor_key}: {len(result.improved_extract_instruction)} chars")
491480
else:
492-
logger.debug("Extract instruction: reflection LM suggests keeping original)")
481+
logger.debug(f"{extract_predictor_key}: reflection LM suggests keeping original")
493482

494483
# Update tool descriptions if reflection LM suggested improvements
495-
improved_react_config["tools"] = {}
484+
improved_module_config["tools"] = {}
496485
for tool_name, tool_info in current_tools_dict.items():
497486
# Check if reflection LM suggested improving this tool's description
498487
improved_desc = getattr(result, f"improved_tool_{tool_name}_desc", None)
@@ -515,15 +504,15 @@ def __call__(
515504
if arg_desc is not None: # Reflection LM suggested improvement
516505
improved_tool_info["arg_desc"][arg_name] = arg_desc
517506

518-
improved_react_config["tools"][tool_name] = improved_tool_info
507+
improved_module_config["tools"][tool_name] = improved_tool_info
519508
logger.debug(f" Tool '{tool_name}': desc={len(improved_desc)} chars, params={len(improved_tool_info['arg_desc'])}")
520509

521510
# Serialize back to JSON
522-
updated_components[module_key] = json.dumps(improved_react_config, indent=2)
511+
updated_components[module_key] = json.dumps(improved_module_config, indent=2)
523512
logger.info(f"Successfully optimized {module_key}")
524513
logger.debug(f"Serialized config length: {len(updated_components[module_key])} chars")
525514

526-
logger.info(f"\nReActModuleProposer returning {len(updated_components)} components: {list(updated_components.keys())}")
515+
logger.info(f"\nToolModuleProposer returning {len(updated_components)} components: {list(updated_components.keys())}")
527516
return updated_components
528517

529518
def _format_examples(self, reflective_dataset: list[ReflectiveExample]) -> str:

0 commit comments

Comments
 (0)