66
77import dspy
88from dspy .adapters .types .base_type import Type
9- from dspy .teleprompt .gepa .gepa_utils import ReflectiveExample
9+ from dspy .teleprompt .gepa .gepa_utils import REACT_MODULE_PREFIX , ReflectiveExample
1010
1111logger = logging .getLogger (__name__ )
1212
13- # Constants for ReAct module optimization
14- REACT_MODULE_PREFIX = "react_module"
15-
1613
1714class GenerateEnhancedMultimodalInstructionFromFeedback (dspy .Signature ):
1815 """I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below.
@@ -318,20 +315,17 @@ def __call__(
318315
319316 return updated_components
320317
321- class GenerateImprovedReActDescriptionsFromFeedback (dspy .Signature ):
322- """Improve a ReAct agent based on execution examples and feedback.
318+ class GenerateImprovedToolModuleDescriptionsFromFeedback (dspy .Signature ):
319+ """Improve a tool-using module based on execution examples and feedback.
323320
324321 These components are progressively optimized - refine what needs improvement.
325- Analyze the trajectories to identify successful patterns and failure causes.
326- Generate improved texts to help the agent succeed on similar tasks.
322+ Analyze the examples_with_feedback to identify successful patterns and failure causes.
323+ Generate improved texts to help the module succeed on similar tasks.
327324 Place improved texts at their appropriate level of abstraction and/or specificity.
328325 """
329326
330- current_react_instruction = dspy .InputField (
331- desc = "Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection"
332- )
333- current_extract_instruction = dspy .InputField (
334- desc = "Current Extract module instruction for extracting final answers from trajectories"
327+ current_predictor_instruction = dspy .InputField (
328+ desc = "Current instruction guiding the predictor"
335329 )
336330 current_tools = dspy .InputField (
337331 annotation = list [dspy .Tool ],
@@ -341,88 +335,75 @@ class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature):
341335 desc = "Execution examples with feedback showing successes and failures"
342336 )
343337
344- improved_react_instruction : str | None = dspy .OutputField (
345- desc = "ReAct instruction for reasoning and tool selection" ,
346- default = None
347- )
348- improved_extract_instruction : str | None = dspy .OutputField (
349- desc = "Extract instruction for answer extraction" ,
338+ improved_predictor_instruction : str | None = dspy .OutputField (
339+ desc = "Improved instruction for the predictor" ,
350340 default = None
351341 )
352342
353343
354344
355345
356346
357- class ReActModuleProposer (ProposalFn ):
358- """Proposer for optimizing ReAct module configurations.
347+ class ToolModuleProposer (ProposalFn ):
348+ """Proposer for optimizing tool-using module configurations.
349+
350+ Supports two types of modules:
351+ - Tool modules (1 predictor): Optimizes predictor instruction and tool descriptions
352+ - ReAct modules (2 predictors): Jointly optimizes react instruction, extract instruction, and tool descriptions
359353
360- Jointly optimizes three components of a ReAct module: the react instruction that guides
361- reasoning and tool selection, the extract instruction for answer extraction from trajectories,
362- and tool descriptions with their parameters. Uses dynamic signature generation to create
363- output fields for each tool and parameter, enabling the reflection LM to optimize all parts
364- cohesively based on execution feedback.
354+ Uses dynamic signature generation to create output fields for each tool and parameter,
355+ enabling the reflection LM to optimize all components cohesively based on execution feedback.
365356
366357 This joint optimization approach allows the LM to see how instructions and tool descriptions
367358 work together, leading to more coherent improvements than optimizing each component separately.
368359 """
369360
370- def __init__ (self ):
371- """Initialize the ReAct module proposer."""
372- pass
373-
374361 def __call__ (
375362 self ,
376363 candidate : dict [str , str ],
377364 reflective_dataset : dict [str , list [ReflectiveExample ]],
378365 components_to_update : list [str ],
379366 ) -> dict [str , str ]:
380- """Optimize ReAct module components.
367+ """Optimize tool-using module components.
381368
382369 Args:
383370 candidate: Current component name -> JSON config mapping
384371 reflective_dataset: Component name -> list of reflective examples
385- components_to_update: List of react_module component names to update
372+ components_to_update: List of tool-using module component names to update
386373
387374 Returns:
388375 dict: Mapping of component names to improved JSON configs
389376 """
390377
391- logger .info ("\n === ReActModuleProposer Called ===" )
392- logger .info (f"components_to_update: { components_to_update } " )
393- logger .info (f"candidate keys: { list (candidate .keys ())} " )
394- logger .info (f"reflective_dataset keys: { list (reflective_dataset .keys ())} " )
395-
396378 updated_components = {}
397379
398380 for module_key in components_to_update :
399- # Only handle react_module components
400- if not module_key .startswith (REACT_MODULE_PREFIX ):
401- logger .debug (f"Skipping non-react_module component: { module_key } " )
402- continue
403-
404381 if module_key not in candidate or module_key not in reflective_dataset :
405382 logger .warning (f"Skipping { module_key } : not in candidate={ module_key not in candidate } , not in reflective_dataset={ module_key not in reflective_dataset } " )
406383 continue
407384
408- logger .info (f"\n Processing react_module: { module_key } " )
409-
410- # Deserialize react module config
385+ # Deserialize module config
411386 try :
412- current_react_config = json .loads (candidate [module_key ])
413- logger .debug (f"Deserialized config keys: { list (current_react_config .keys ())} " )
387+ current_module_config = json .loads (candidate [module_key ])
414388 except json .JSONDecodeError as e :
415389 logger .error (f"Failed to deserialize config for { module_key } : { e } " )
416390 continue
417391
392+ # Collect predictor keys: every key in the module config except "tools".
393+ # Exactly one key is expected for tool modules and two for ReAct modules (the second is the extract predictor); keys are picked by position in the config.
394+ predictor_keys = [k for k in current_module_config if k != "tools" ]
395+ logger .debug (f"Predictor keys: { predictor_keys } " )
396+ primary_predictor_key = predictor_keys [0 ]
397+ extract_predictor_key = predictor_keys [1 ] if module_key .startswith (REACT_MODULE_PREFIX ) else None
398+
418399 # Reconstruct Tool objects from JSON metadata so the adapter can format them for the reflection LM.
419400 # Tool.func cannot be serialized in JSON, so we use a placeholder (never executed).
420- current_tools_dict = current_react_config .get ("tools" , {})
401+ current_tools_dict = current_module_config .get ("tools" , {})
421402 logger .info (f"Found { len (current_tools_dict )} tools: { list (current_tools_dict .keys ())} " )
422403 tools_list = []
423404 for tool_name , tool_info in current_tools_dict .items ():
424405 tool = dspy .Tool (
425- func = lambda : None , # Placeholder - Tool requires Callable, but only schema is used
406+ func = lambda * args , ** kwargs : None , # Placeholder - Tool requires Callable, but only schema is used
426407 name = tool_name ,
427408 desc = tool_info .get ("desc" , "" ),
428409 )
@@ -431,7 +412,7 @@ def __call__(
431412 tools_list .append (tool )
432413
433414 # Build dynamic signature by extending base signature
434- signature = GenerateImprovedReActDescriptionsFromFeedback
415+ signature = GenerateImprovedToolModuleDescriptionsFromFeedback
435416
436417 logger .debug (f"Building dynamic signature with { len (tools_list )} tools..." )
437418
@@ -458,41 +439,49 @@ def __call__(
458439 )
459440 )
460441
461- # Format examples
462- formatted_examples = self ._format_examples (reflective_dataset [module_key ])
463- logger .info (f"Formatted { len (reflective_dataset [module_key ])} reflective examples" )
464- logger .debug (f"Examples preview: { formatted_examples [:200 ]} ..." )
465442
466- logger .info ("Calling reflection LM with dynamic signature..." )
443+ kwargs = {
444+ "current_predictor_instruction" : current_module_config [primary_predictor_key ],
445+ "current_tools" : tools_list ,
446+ "examples_with_feedback" : self ._format_examples (reflective_dataset [module_key ]),
447+ }
448+ # If module has extract predictor, add extract fields
449+ if extract_predictor_key is not None :
450+ signature = signature .append (
451+ "current_extract_instruction" ,
452+ dspy .InputField (desc = "Current instruction for extraction predictor" )
453+ )
454+ signature = signature .append (
455+ "improved_extract_instruction" ,
456+ dspy .OutputField (desc = "Improved instruction for extraction" , default = None )
457+ )
458+ kwargs ["current_extract_instruction" ] = current_module_config [extract_predictor_key ]
459+
467460 propose_descriptions = dspy .Predict (signature )
468- result = propose_descriptions (
469- current_react_instruction = current_react_config .get ("react" , "" ),
470- current_extract_instruction = current_react_config .get ("extract" , "" ),
471- current_tools = tools_list , # List of Tool objects for adapter formatting
472- examples_with_feedback = formatted_examples ,
473- )
461+
462+ result = propose_descriptions (** kwargs )
474463
475464 # Build improved config from reflection LM suggestions
476465 # Reflection LM returns None for components it doesn't want to change, or text for improvements
477466 logger .info ("Building improved config from reflection LM response..." )
478- improved_react_config = {}
467+ improved_module_config = {}
479468
480- # Update react instruction if reflection LM suggested improvement
481- if result .improved_react_instruction is not None :
482- improved_react_config [ "react" ] = result .improved_react_instruction
483- logger .debug (f"React instruction : { len (result .improved_react_instruction )} chars" )
469+ # Update primary predictor instruction if reflection LM suggested improvement
470+ if result .improved_predictor_instruction is not None :
471+ improved_module_config [ primary_predictor_key ] = result .improved_predictor_instruction
472+ logger .debug (f"{ primary_predictor_key } : { len (result .improved_predictor_instruction )} chars" )
484473 else :
485- logger .debug ("React instruction : reflection LM suggests keeping original" )
474+ logger .debug (f" { primary_predictor_key } : reflection LM suggests keeping original" )
486475
487- # Update extract instruction if reflection LM suggested improvement
488- if result .improved_extract_instruction is not None :
489- improved_react_config [ "extract" ] = result .improved_extract_instruction
490- logger .debug (f"Extract instruction : { len (result .improved_extract_instruction )} chars" )
476+ # Update the extract instruction if the module has an extract predictor and the reflection LM suggested an improvement
477+ if extract_predictor_key is not None and result .improved_extract_instruction is not None :
478+ improved_module_config [ extract_predictor_key ] = result .improved_extract_instruction
479+ logger .debug (f"{ extract_predictor_key } : { len (result .improved_extract_instruction )} chars" )
491480 else :
492- logger .debug ("Extract instruction : reflection LM suggests keeping original) " )
481+ logger .debug (f" { extract_predictor_key } : reflection LM suggests keeping original" )
493482
494483 # Update tool descriptions if reflection LM suggested improvements
495- improved_react_config ["tools" ] = {}
484+ improved_module_config ["tools" ] = {}
496485 for tool_name , tool_info in current_tools_dict .items ():
497486 # Check if reflection LM suggested improving this tool's description
498487 improved_desc = getattr (result , f"improved_tool_{ tool_name } _desc" , None )
@@ -515,15 +504,15 @@ def __call__(
515504 if arg_desc is not None : # Reflection LM suggested improvement
516505 improved_tool_info ["arg_desc" ][arg_name ] = arg_desc
517506
518- improved_react_config ["tools" ][tool_name ] = improved_tool_info
507+ improved_module_config ["tools" ][tool_name ] = improved_tool_info
519508 logger .debug (f" Tool '{ tool_name } ': desc={ len (improved_desc )} chars, params={ len (improved_tool_info ['arg_desc' ])} " )
520509
521510 # Serialize back to JSON
522- updated_components [module_key ] = json .dumps (improved_react_config , indent = 2 )
511+ updated_components [module_key ] = json .dumps (improved_module_config , indent = 2 )
523512 logger .info (f"Successfully optimized { module_key } " )
524513 logger .debug (f"Serialized config length: { len (updated_components [module_key ])} chars" )
525514
526- logger .info (f"\n ReActModuleProposer returning { len (updated_components )} components: { list (updated_components .keys ())} " )
515+ logger .info (f"\n ToolModuleProposer returning { len (updated_components )} components: { list (updated_components .keys ())} " )
527516 return updated_components
528517
529518 def _format_examples (self , reflective_dataset : list [ReflectiveExample ]) -> str :
0 commit comments