
Commit 3dfaf42

refactor and modularisation

1 parent ed22b2b

File tree

1 file changed

ml_grid/pipeline/main.py

Lines changed: 64 additions & 8 deletions
@@ -61,35 +61,52 @@ class run:
 
 
 
-    def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
+    def __init__(self, local_param_dict: Dict[str, Any], **kwargs):
         """Initializes the run class.
 
         This class takes the main data pipeline object and a dictionary of local
         parameters to set up and prepare for executing a series of hyperparameter
         searches across multiple machine learning models.
 
+        For hyperopt, this constructor can also accept keyword arguments to
+        create the `pipe` object internally.
+
         Args:
-            ml_grid_object (pipe): The main data pipeline object, which contains
-                the data (X_train, y_train, etc.) and a list of model classes
-                to be evaluated.
             local_param_dict (Dict[str, Any]): A dictionary of parameters for the
                 current experimental run, such as `param_space_size`.
+            **kwargs: Keyword arguments to be passed to the `pipe` constructor.
+                Expected keys include `file_name`, `drop_term_list`, `model_class_dict`,
+                `base_project_dir`, `experiment_dir`, and `outcome_var`.
         """
         self.global_params = global_parameters
 
         self.logger = logging.getLogger('ml_grid')
 
         self.verbose = self.global_params.verbose
 
-        self.error_raise = self.global_params.error_raise
+        if 'ml_grid_object' in kwargs:
+            self.ml_grid_object = kwargs['ml_grid_object']
+        else:
+            # Create the pipe object from the provided kwargs
+            pipe_kwargs = {
+                'file_name': kwargs.get('file_name'),
+                'drop_term_list': kwargs.get('drop_term_list'),
+                'model_class_dict': kwargs.get('model_class_dict'),
+                'local_param_dict': local_param_dict,
+                'base_project_dir': kwargs.get('base_project_dir'),
+                'experiment_dir': kwargs.get('experiment_dir'),
+                'outcome_var': kwargs.get('outcome_var'),
+                'param_space_index': kwargs.get('param_space_index', 0)
+            }
+            self.ml_grid_object = pipe(**pipe_kwargs)
 
-        self.ml_grid_object = ml_grid_object
+        self.error_raise = self.global_params.error_raise
 
         self.sub_sample_param_space_pct = self.global_params.sub_sample_param_space_pct
 
         self.parameter_space_size = local_param_dict.get("param_space_size")
 
-        self.model_class_list = ml_grid_object.model_class_list
+        self.model_class_list = self.ml_grid_object.model_class_list
 
         if self.verbose >= 2:
             self.logger.info(f"{len(self.model_class_list)} models loaded")
@@ -104,6 +121,7 @@ def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
             pg = ParameterGrid(elem.parameter_space)
             pg = len(pg)
         else:
+
             pg = calculate_combinations(elem.parameter_space, steps=10)
 
         #pg = ParameterGrid(elem.parameter_space)
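
The diff doesn't include `calculate_combinations` itself; given the `steps=10` argument, a plausible reading is that it estimates the size of a parameter space that `ParameterGrid` can't enumerate, by treating each non-list value as `steps` discretized points. A hypothetical sketch of that idea, not the ml_grid implementation:

```python
# Hypothetical reading of calculate_combinations(parameter_space, steps=10):
# count grid points, approximating any non-list value (e.g. a sampled
# distribution) as `steps` discretized points.
from sklearn.model_selection import ParameterGrid

def calculate_combinations(parameter_space: dict, steps: int = 10) -> int:
    total = 1
    for values in parameter_space.values():
        total *= len(values) if isinstance(values, (list, tuple)) else steps
    return total

# For a purely list-valued space this agrees with len(ParameterGrid(...)):
space = {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]}
assert calculate_combinations(space) == len(ParameterGrid(space)) == 6
```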
@@ -179,6 +197,44 @@ def __init__(self, ml_grid_object: pipe, local_param_dict: Dict[str, Any]):
         if self.verbose >= 2:
             self.logger.info(f"Passed main init, len(arg_list): {len(self.arg_list)}")
 
+    def _prepare_run(self, model_class):
+        """Prepares a single model run by creating the necessary arguments."""
+        return (
+            model_class.algorithm_implementation,
+            model_class.parameter_space,
+            model_class.method_name,
+            self.ml_grid_object,
+            self.sub_sample_parameter_val,
+            self.project_score_save_class_instance,
+        )
+
+    def execute_single_model(self, args: Tuple) -> float:
+        """
+        Executes the grid search for a single model and returns its score.
+        This method is designed to be called within a hyperopt objective function.
+        """
+        try:
+            self.logger.info(f"Starting grid search for {args[2]}...")
+            gscv_instance = grid_search_cross_validate.grid_search_crossvalidate(
+                *args
+            )
+            score = gscv_instance.grid_search_cross_validate_score_result
+            self.logger.info(f"Score for {args[2]}: {score:.4f}")
+            return score
+
+        except Exception as e:
+            self.logger.error(
+                f"An exception occurred during grid search for {args[2]}: {e}",
+                exc_info=True,
+            )
+            self.model_error_list.append([args[0], e, traceback.format_exc()])
+            if self.error_raise:
+                self.logger.critical("Halting due to 'error_raise' flag.")
+                raise
+            else:
+                self.logger.warning("Continuing as 'error_raise' is False.")
+                return 0.0  # Return a poor score on failure
+
     def execute(self) -> Tuple[List[List[Any]], float]:
         """Executes the grid search for each model in the list.
 
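The new `_prepare_run` and `execute_single_model` methods split argument building from scoring so an optimizer can evaluate one model per trial, as the docstring's hyperopt note suggests. A minimal sketch of such an objective, assuming hyperopt is installed and `runner` is a constructed `run` instance; the search space is illustrative, not from the commit:

```python
# Sketch of a hyperopt objective built on the two new methods.
from hyperopt import STATUS_OK, fmin, hp, tpe

def objective(params):
    # Pick one model class per trial and score it via the new helpers.
    model_class = runner.model_class_list[int(params["model_idx"])]
    args = runner._prepare_run(model_class)
    score = runner.execute_single_model(args)  # 0.0 on handled failure
    # hyperopt minimises the loss, so negate the score.
    return {"loss": -score, "status": STATUS_OK}

space = {"model_idx": hp.randint("model_idx", len(runner.model_class_list))}
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=20)
```

Because `execute_single_model` returns 0.0 instead of raising when `error_raise` is False, a failed trial simply scores poorly rather than aborting the whole search.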
@@ -223,7 +279,7 @@ def multi_run_wrapper(args: Tuple) -> Any:
                     self.highest_score = max(self.highest_score, gscv_instance.grid_search_cross_validate_score_result)
                     self.logger.info(f"Current highest score: {self.highest_score:.4f}")
 
-                except Exception as e: # Catches any exception from grid_search_crossvalidate
+                except Exception as e:  # Catches any exception from grid_search_crossvalidate
                     self.logger.error(f"An exception occurred during grid search for {self.arg_list[k][2]}: {e}", exc_info=True)
 
                     self.model_error_list.append(
