Skip to content

Commit e35603a

Browse files
committed
refactor(gepa): eliminate create-delete pattern in base_program build
- Process ReAct modules first, then individual predictors - Skip predictors already part of module configs (check inside JSON) - Remove redundant base_program.pop() calls - No duplicate enable_tool_optimization checks
1 parent a635768 commit e35603a

File tree

1 file changed

+56
-42
lines changed

1 file changed

+56
-42
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -541,46 +541,18 @@ def feedback_fn(
541541

542542
# Instantiate GEPA with the simpler adapter-based API
543543
base_program = {}
544-
for name, pred in student.named_predictors():
545-
# Detect tool-using predictors via type checking
546-
def is_tool_field(annotation) -> bool:
547-
"""Check if a field annotation is Tool or contains Tool."""
548-
if annotation is Tool:
549-
return True
550-
origin = get_origin(annotation)
551-
if origin is not None:
552-
args = get_args(annotation)
553-
for arg in args:
554-
if is_tool_field(arg): # Recursive for nested types
555-
return True
556-
return False
557-
558-
# Detect tool-using predictors
559-
if self.enable_tool_optimization and any(is_tool_field(field.annotation) for field in pred.signature.input_fields.values()):
560-
# Use prefixed key for tool modules
561-
module_key = f"{TOOL_MODULE_PREFIX}:{name}"
562-
base_program[module_key] = json.dumps({
563-
name: pred.signature.instructions, # Use actual predictor name as key
564-
"tools": {} # Populated from traces
565-
}, indent=2)
566-
else:
567-
# Regular string instruction, no tools
568-
base_program[name] = pred.signature.instructions
569544

570-
# Always traverse to detect ReAct modules
571-
for module_path, module in student.named_sub_modules():
572-
# Only process ReAct modules
573-
if not isinstance(module, ReAct):
574-
continue
545+
# First, process ReAct modules to claim their predictors
546+
if self.enable_tool_optimization:
547+
for module_path, module in student.named_sub_modules():
548+
if not isinstance(module, ReAct):
549+
continue
575550

576-
if self.enable_tool_optimization:
577551
# Get predictor names via object identity
578552
extract_predictor = module.extract.predict
579553
react_predictor = module.react
580-
581554
extract_predictor_name = None
582555
react_predictor_name = None
583-
584556
for name, pred in student.named_predictors():
585557
if pred is extract_predictor:
586558
extract_predictor_name = name
@@ -605,16 +577,58 @@ def is_tool_field(annotation) -> bool:
605577
}
606578
}
607579

608-
# Remove the individual predictor keys (they're now part of ReAct module config)
609-
base_program.pop(react_predictor_name, None)
610-
base_program.pop(extract_predictor_name, None)
611580
base_program[module_key] = json.dumps(config, indent=2)
612-
else:
613-
logger.warning(
614-
f"Detected ReAct module at '{module_path}'. Consider using "
615-
"`enable_tool_optimization=True` to jointly optimize react instructions, "
616-
"extract instructions, tool descriptions, and tool argument descriptions."
617-
)
581+
else:
582+
# Warn if ReAct modules found but tool optimization disabled
583+
for module_path, module in student.named_sub_modules():
584+
if isinstance(module, ReAct):
585+
logger.warning(
586+
f"Detected ReAct module at '{module_path}'. Consider using "
587+
"`enable_tool_optimization=True` to jointly optimize react instructions, "
588+
"extract instructions, tool descriptions, and tool argument descriptions."
589+
)
590+
591+
# Then, process individual predictors (skip if already part of a module config)
592+
for name, pred in student.named_predictors():
593+
if self.enable_tool_optimization:
594+
# Skip if predictor is part of a module config (e.g., ReAct)
595+
found = False
596+
for val in base_program.values():
597+
try:
598+
config = json.loads(val)
599+
if name in config:
600+
found = True
601+
break
602+
except (json.JSONDecodeError, TypeError, ValueError):
603+
pass
604+
605+
if found:
606+
continue
607+
608+
# Detect tool-using predictors via type checking
609+
def is_tool_field(annotation) -> bool:
610+
"""Check if a field annotation is Tool or contains Tool."""
611+
if annotation is Tool:
612+
return True
613+
origin = get_origin(annotation)
614+
if origin is not None:
615+
args = get_args(annotation)
616+
for arg in args:
617+
if is_tool_field(arg): # Recursive for nested types
618+
return True
619+
return False
620+
621+
# Add tool module if predictor uses tools
622+
if any(is_tool_field(field.annotation) for field in pred.signature.input_fields.values()):
623+
module_key = f"{TOOL_MODULE_PREFIX}:{name}"
624+
base_program[module_key] = json.dumps({
625+
name: pred.signature.instructions,
626+
"tools": {} # Populated from traces
627+
}, indent=2)
628+
continue
629+
630+
# Add regular predictor (no tool optimization or no tools detected)
631+
base_program[name] = pred.signature.instructions
618632

619633
# Log base_program keys for debugging
620634
logger.info(f"Initialized base_program with {len(base_program)} components:")

0 commit comments

Comments
 (0)