@@ -61,35 +61,52 @@ class run:
6161
6262
6363
64- def __init__ (self , ml_grid_object : pipe , local_param_dict : Dict [str , Any ]):
64+ def __init__ (self , local_param_dict : Dict [str , Any ], ** kwargs ):
6565 """Initializes the run class.
6666
6767 This class takes the main data pipeline object and a dictionary of local
6868 parameters to set up and prepare for executing a series of hyperparameter
6969 searches across multiple machine learning models.
7070
71+ For hyperopt, this constructor can also accept keyword arguments to
72+ create the `pipe` object internally.
73+
7174 Args:
72- ml_grid_object (pipe): The main data pipeline object, which contains
73- the data (X_train, y_train, etc.) and a list of model classes
74- to be evaluated.
7575 local_param_dict (Dict[str, Any]): A dictionary of parameters for the
7676 current experimental run, such as `param_space_size`.
77+ **kwargs: Keyword arguments to be passed to the `pipe` constructor.
78+ Expected keys include `file_name`, `drop_term_list`, `model_class_dict`,
79+ `base_project_dir`, `experiment_dir`, and `outcome_var`.
7780 """
7881 self .global_params = global_parameters
7982
8083 self .logger = logging .getLogger ('ml_grid' )
8184
8285 self .verbose = self .global_params .verbose
8386
84- self .error_raise = self .global_params .error_raise
87+ if 'ml_grid_object' in kwargs :
88+ self .ml_grid_object = kwargs ['ml_grid_object' ]
89+ else :
90+ # Create the pipe object from the provided kwargs
91+ pipe_kwargs = {
92+ 'file_name' : kwargs .get ('file_name' ),
93+ 'drop_term_list' : kwargs .get ('drop_term_list' ),
94+ 'model_class_dict' : kwargs .get ('model_class_dict' ),
95+ 'local_param_dict' : local_param_dict ,
96+ 'base_project_dir' : kwargs .get ('base_project_dir' ),
97+ 'experiment_dir' : kwargs .get ('experiment_dir' ),
98+ 'outcome_var' : kwargs .get ('outcome_var' ),
99+ 'param_space_index' : kwargs .get ('param_space_index' , 0 )
100+ }
101+ self .ml_grid_object = pipe (** pipe_kwargs )
85102
86- self .ml_grid_object = ml_grid_object
103+ self .error_raise = self . global_params . error_raise
87104
88105 self .sub_sample_param_space_pct = self .global_params .sub_sample_param_space_pct
89106
90107 self .parameter_space_size = local_param_dict .get ("param_space_size" )
91108
92- self .model_class_list = ml_grid_object .model_class_list
109+ self .model_class_list = self . ml_grid_object .model_class_list
93110
94111 if self .verbose >= 2 :
95112 self .logger .info (f"{ len (self .model_class_list )} models loaded" )
@@ -104,6 +121,7 @@ def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
104121 pg = ParameterGrid (elem .parameter_space )
105122 pg = len (pg )
106123 else :
124+
107125 pg = calculate_combinations (elem .parameter_space , steps = 10 )
108126
109127 #pg = ParameterGrid(elem.parameter_space)
@@ -179,6 +197,44 @@ def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
179197 if self .verbose >= 2 :
180198 self .logger .info (f"Passed main init, len(arg_list): { len (self .arg_list )} " )
181199
200+ def _prepare_run (self , model_class ):
201+ """Prepares a single model run by creating the necessary arguments."""
202+ return (
203+ model_class .algorithm_implementation ,
204+ model_class .parameter_space ,
205+ model_class .method_name ,
206+ self .ml_grid_object ,
207+ self .sub_sample_parameter_val ,
208+ self .project_score_save_class_instance ,
209+ )
210+
211+ def execute_single_model (self , args : Tuple ) -> float :
212+ """
213+ Executes the grid search for a single model and returns its score.
214+ This method is designed to be called within a hyperopt objective function.
215+ """
216+ try :
217+ self .logger .info (f"Starting grid search for { args [2 ]} ..." )
218+ gscv_instance = grid_search_cross_validate .grid_search_crossvalidate (
219+ * args
220+ )
221+ score = gscv_instance .grid_search_cross_validate_score_result
222+ self .logger .info (f"Score for { args [2 ]} : { score :.4f} " )
223+ return score
224+
225+ except Exception as e :
226+ self .logger .error (
227+ f"An exception occurred during grid search for { args [2 ]} : { e } " ,
228+ exc_info = True ,
229+ )
230+ self .model_error_list .append ([args [0 ], e , traceback .format_exc ()])
231+ if self .error_raise :
232+ self .logger .critical ("Halting due to 'error_raise' flag." )
233+ raise
234+ else :
235+ self .logger .warning ("Continuing as 'error_raise' is False." )
236+ return 0.0 # Return a poor score on failure
237+
182238 def execute (self ) -> Tuple [List [List [Any ]], float ]:
183239 """Executes the grid search for each model in the list.
184240
@@ -223,7 +279,7 @@ def multi_run_wrapper(args: Tuple) -> Any:
223279 self .highest_score = max (self .highest_score , gscv_instance .grid_search_cross_validate_score_result )
224280 self .logger .info (f"Current highest score: { self .highest_score :.4f} " )
225281
226- except Exception as e : # Catches any exception from grid_search_crossvalidate
282+ except Exception as e : # Catches any exception from grid_search_crossvalidate
227283 self .logger .error (f"An exception occurred during grid search for { self .arg_list [k ][2 ]} : { e } " , exc_info = True )
228284
229285 self .model_error_list .append (
0 commit comments