1313from doubleml .utils ._estimation import (
1414 _dml_cv_predict ,
1515 _dml_tune ,
16+ _double_dml_cv_predict ,
1617)
1718
1819
@@ -104,10 +105,6 @@ def __init__(
104105
105106 ml_m_is_classifier = self ._check_learner (ml_m , "ml_m" , regressor = True , classifier = True )
106107 self ._learner = {"ml_m" : ml_m , "ml_t" : ml_t , "ml_M" : ml_M }
107- # replace aggregated inner names with per-inner-fold names
108- inner_M_names = [f"ml_M_inner_{ i } " for i in range (self .n_folds_inner )]
109- inner_a_names = [f"ml_a_inner_{ i } " for i in range (self .n_folds_inner )]
110- self ._predictions_names = ["ml_r" , "ml_m" , "ml_a" , "ml_t" , "ml_M" ] + inner_M_names + inner_a_names
111108
112109 if ml_a is not None :
113110 ml_a_is_classifier = self ._check_learner (ml_a , "ml_a" , regressor = True , classifier = True )
@@ -162,56 +159,15 @@ def __init__(
162159 self ._sensitivity_implemented = False
163160
164161 def _initialize_ml_nuisance_params (self ):
165- self ._params = {learner : {key : [None ] * self .n_rep for key in self ._dml_data .d_cols } for learner in self ._learner }
162+ inner_M_names = [f"ml_M_inner_{ i } " for i in range (self .n_folds )]
163+ inner_a_names = [f"ml_a_inner_{ i } " for i in range (self .n_folds )]
164+ params_names = ["ml_m" , "ml_a" , "ml_t" , "ml_M" ] + inner_M_names + inner_a_names
165+ self ._params = {learner : {key : [None ] * self .n_rep for key in self ._dml_data .d_cols } for learner in params_names }
166166
167167 def _check_data (self , obj_dml_data ):
168168 if not np .array_equal (np .unique (obj_dml_data .y ), [0 , 1 ]):
169169 raise TypeError ("The outcome variable y must be binary with values 0 and 1." )
170170
171- def _double_dml_cv_predict (
172- self ,
173- estimator ,
174- estimator_name ,
175- x ,
176- y ,
177- smpls = None ,
178- smpls_inner = None ,
179- n_jobs = None ,
180- est_params = None ,
181- method = "predict" ,
182- sample_weights = None ,
183- ):
184- res = {}
185- res ["preds" ] = np .zeros (y .shape , dtype = float )
186- res ["preds_inner" ] = []
187- res ["targets_inner" ] = []
188- res ["models" ] = []
189- for smpls_single_split , smpls_double_split in zip (smpls , smpls_inner ):
190- res_inner = _dml_cv_predict (
191- estimator ,
192- x ,
193- y ,
194- smpls = smpls_double_split ,
195- n_jobs = n_jobs ,
196- est_params = est_params ,
197- method = method ,
198- return_models = True ,
199- sample_weights = sample_weights ,
200- )
201- _check_finite_predictions (res_inner ["preds" ], estimator , estimator_name , smpls_double_split )
202-
203- res ["preds_inner" ].append (res_inner ["preds" ])
204- res ["targets_inner" ].append (res_inner ["targets" ])
205- for model in res_inner ["models" ]:
206- res ["models" ].append (model )
207- if method == "predict_proba" :
208- res ["preds" ][smpls_single_split [1 ]] += model .predict_proba (x [smpls_single_split [1 ]])[:, 1 ]
209- else :
210- res ["preds" ][smpls_single_split [1 ]] += model .predict (x [smpls_single_split [1 ]])
211- res ["preds" ] /= len (smpls )
212- res ["targets" ] = np .copy (y )
213- return res
214-
215171 def _nuisance_est (self , smpls , n_jobs_cv , external_predictions , return_models = False ):
216172 x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
217173 x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
@@ -234,9 +190,14 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
234190 f"have to be provided (missing: { ', ' .join ([str (i ) for i in missing ])} )."
235191 )
236192 M_hat_inner = [external_predictions [f"ml_M_inner_{ i } " ] for i in range (self .n_folds_inner )]
237- M_hat = {"preds" : external_predictions ["ml_M" ], "preds_inner" : M_hat_inner , "targets" : None , "models" : None }
193+ M_hat = {
194+ "preds" : external_predictions ["ml_M" ],
195+ "preds_inner" : M_hat_inner ,
196+ "targets" : self ._dml_data .y ,
197+ "models" : None ,
198+ }
238199 else :
239- M_hat = self . _double_dml_cv_predict (
200+ M_hat = _double_dml_cv_predict (
240201 self ._learner ["ml_M" ],
241202 "ml_M" ,
242203 x_d_concat ,
@@ -250,7 +211,7 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
250211
251212 # nuisance m
252213 if m_external :
253- m_hat = {"preds" : external_predictions ["ml_m" ], "targets" : None , "models" : None }
214+ m_hat = {"preds" : external_predictions ["ml_m" ], "targets" : self . _dml_data . d , "models" : None }
254215 else :
255216 if self .score == "instrument" :
256217 weights = M_hat ["preds" ] * (1 - M_hat ["preds" ])
@@ -303,9 +264,14 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
303264 f"have to be provided (missing: { ', ' .join ([str (i ) for i in missing ])} )."
304265 )
305266 a_hat_inner = [external_predictions [f"ml_a_inner_{ i } " ] for i in range (self .n_folds_inner )]
306- a_hat = {"preds" : external_predictions ["ml_a" ], "preds_inner" : a_hat_inner , "targets" : None , "models" : None }
267+ a_hat = {
268+ "preds" : external_predictions ["ml_a" ],
269+ "preds_inner" : a_hat_inner ,
270+ "targets" : self ._dml_data .d ,
271+ "models" : None ,
272+ }
307273 else :
308- a_hat = self . _double_dml_cv_predict (
274+ a_hat = _double_dml_cv_predict (
309275 self ._learner ["ml_a" ],
310276 "ml_a" ,
311277 x ,
@@ -404,13 +370,6 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa
404370
405371 return psi_elements , preds
406372
407- @property
408- def predictions_names (self ):
409- """
410- The names of predictions for the nuisance functions.
411- """
412- return self ._predictions_names
413-
414373 def _score_elements (self , y , d , r_hat , m_hat ):
415374 # compute residual
416375 d_tilde = d - m_hat
@@ -438,8 +397,6 @@ def _sensitivity_element_est(self, preds):
438397 def _nuisance_tuning (
439398 self , smpls , param_grids , scoring_methods , n_folds_tune , n_jobs_cv , search_mode , n_iter_randomized_search
440399 ):
441- if self ._i_rep is None :
442- raise ValueError ("tune_on_folds must be True as targets have to be created for ml_t on folds." )
443400 x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
444401 x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
445402 x_d_concat = np .hstack ((d .reshape (- 1 , 1 ), x ))
@@ -500,34 +457,16 @@ def _nuisance_tuning(
500457 a_best_params = [xx .best_params_ for xx in a_tune_res ]
501458
502459 # Create targets for tuning ml_t
503- M_hat = self ._double_dml_cv_predict (
504- self ._learner ["ml_M" ],
505- "ml_M" ,
506- x_d_concat ,
507- y ,
508- smpls = smpls ,
509- smpls_inner = self ._DoubleML__smpls__inner ,
510- n_jobs = n_jobs_cv ,
511- est_params = M_best_params ,
512- method = self ._predict_method ["ml_M" ],
513- )
514460
515- W_inner = []
516- for i , (train , _ ) in enumerate (smpls ):
517- M_iteration = M_hat ["preds_inner" ][i ][train ]
518- M_iteration = np .clip (M_iteration , 1e-8 , 1 - 1e-8 )
519- w = scipy .special .logit (M_iteration )
520- W_inner .append (w )
461+ M_hat = np .full_like (y , np .nan )
462+ for idx , (train_index , _ ) in enumerate (smpls ):
463+ M_hat [train_index ] = M_tune_res [idx ].predict_proba (x_d_concat [train_index , :])[:, 1 ]
521464
522- # Reshape W_inner into full-length arrays per fold: fill train indices, others are NaN
523- W_targets = []
524- for i , train in enumerate (train_inds ):
525- wt = np .full (x .shape [0 ], np .nan , dtype = float )
526- wt [train ] = W_inner [i ]
527- W_targets .append (wt )
465+ M_hat = np .clip (M_hat , 1e-8 , 1 - 1e-8 )
466+ W_hat = scipy .special .logit (M_hat )
528467
529468 t_tune_res = _dml_tune (
530- W_inner ,
469+ W_hat ,
531470 x ,
532471 train_inds ,
533472 self ._learner ["ml_t" ],
@@ -537,7 +476,6 @@ def _nuisance_tuning(
537476 n_jobs_cv ,
538477 search_mode ,
539478 n_iter_randomized_search ,
540- fold_specific_target = True ,
541479 )
542480 t_best_params = [xx .best_params_ for xx in t_tune_res ]
543481
0 commit comments