Skip to content

Commit fffbbf7

Browse files
committed
Merge branch 'main' of https://github.com/DoubleML/doubleml-for-py into 0.9.X
2 parents 0f0f280 + 60f4a43 commit fffbbf7

37 files changed

+2467
-92
lines changed

doubleml/double_ml_data.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class DoubleMLData(DoubleMLBaseData):
110110
Default is ``None``.
111111
112112
s_col : None or str
113-
The selection variable (only relevant/used for SSM Estimatiors).
113+
The score or selection variable (only relevant/used for RDD or SSM estimators).
114114
Default is ``None``.
115115
116116
use_other_treat_as_covariate : bool
@@ -182,7 +182,7 @@ def _data_summary_str(self):
182182
if self.t_col is not None:
183183
data_summary += f'Time variable: {self.t_col}\n'
184184
if self.s_col is not None:
185-
data_summary += f'Selection variable: {self.s_col}\n'
185+
data_summary += f'Score/Selection variable: {self.s_col}\n'
186186
data_summary += f'No. Observations: {self.n_obs}\n'
187187
return data_summary
188188

@@ -212,7 +212,7 @@ def from_arrays(cls, x, y, d, z=None, t=None, s=None, use_other_treat_as_covaria
212212
Default is ``None``.
213213
214214
s : :class:`numpy.ndarray`
215-
Array of the selection variable (only relevant/used for SSM models).
215+
Array of the score or selection variable (only relevant/used for RDD or SSM models).
216216
Default is ``None``.
217217
218218
use_other_treat_as_covariate : bool
@@ -351,7 +351,7 @@ def t(self):
351351
@property
352352
def s(self):
353353
"""
354-
Array of selection variable.
354+
Array of score or selection variable.
355355
"""
356356
if self.s_col is not None:
357357
return self._s.values
@@ -538,7 +538,7 @@ def t_col(self, value):
538538
@property
539539
def s_col(self):
540540
"""
541-
The selection variable.
541+
The score or selection variable.
542542
"""
543543
return self._s_col
544544

@@ -547,10 +547,10 @@ def s_col(self, value):
547547
reset_value = hasattr(self, '_s_col')
548548
if value is not None:
549549
if not isinstance(value, str):
550-
raise TypeError('The selection variable s_col must be of str type (or None). '
550+
raise TypeError('The score or selection variable s_col must be of str type (or None). '
551551
f'{str(value)} of type {str(type(value))} was passed.')
552552
if value not in self.all_variables:
553-
raise ValueError('Invalid selection variable s_col. '
553+
raise ValueError('Invalid score or selection variable s_col. '
554554
f'{value} is no data column.')
555555
self._s_col = value
556556
else:
@@ -725,24 +725,24 @@ def _check_disjoint_sets_t_s(self):
725725
if self.s_col is not None:
726726
s_col_set = {self.s_col}
727727
if not s_col_set.isdisjoint(x_cols_set):
728-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and covariate in '
728+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and covariate in '
729729
'``x_cols``.')
730730
if not s_col_set.isdisjoint(d_cols_set):
731-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and treatment variable in '
732-
'``d_cols``.')
731+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and treatment '
732+
'variable in ``d_cols``.')
733733
if not s_col_set.isdisjoint(y_col_set):
734-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and outcome variable '
735-
'``y_col``.')
734+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and outcome '
735+
'variable ``y_col``.')
736736
if self.z_cols is not None:
737737
z_cols_set = set(self.z_cols)
738738
if not s_col_set.isdisjoint(z_cols_set):
739-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and instrumental '
740-
'variable in ``z_cols``.')
739+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and '
740+
'instrumental variable in ``z_cols``.')
741741
if self.t_col is not None:
742742
t_col_set = {self.t_col}
743743
if not s_col_set.isdisjoint(t_col_set):
744-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and time variable '
745-
'``t_col``.')
744+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and time '
745+
'variable ``t_col``.')
746746

747747

748748
class DoubleMLClusterData(DoubleMLData):
@@ -780,7 +780,7 @@ class DoubleMLClusterData(DoubleMLData):
780780
Default is ``None``.
781781
782782
s_col : None or str
783-
The selection variable (only relevant/used for SSM Estimatiors).
783+
The score or selection variable (only relevant/used for RDD or SSM estimators).
784784
Default is ``None``.
785785
786786
use_other_treat_as_covariate : bool
@@ -854,7 +854,7 @@ def _data_summary_str(self):
854854
if self.t_col is not None:
855855
data_summary += f'Time variable: {self.t_col}\n'
856856
if self.s_col is not None:
857-
data_summary += f'Selection variable: {self.s_col}\n'
857+
data_summary += f'Score/Selection variable: {self.s_col}\n'
858858

859859
data_summary += f'No. Observations: {self.n_obs}\n'
860860
return data_summary
@@ -888,7 +888,7 @@ def from_arrays(cls, x, y, d, cluster_vars, z=None, t=None, s=None, use_other_tr
888888
Default is ``None``.
889889
890890
s : :class:`numpy.ndarray`
891-
Array of the selection variable (only relevant/used for SSM models).
891+
Array of the score or selection variable (only relevant/used for RDD or SSM models).
892892
Default is ``None``.
893893
894894
use_other_treat_as_covariate : bool
@@ -1039,7 +1039,7 @@ def _check_disjoint_sets_cluster_cols(self):
10391039
'cluster variable in ``cluster_cols``.')
10401040
if self.s_col is not None:
10411041
if not s_col_set.isdisjoint(cluster_cols_set):
1042-
raise ValueError(f'{str(self.s_col)} cannot be set as selection variable ``s_col`` and '
1042+
raise ValueError(f'{str(self.s_col)} cannot be set as score or selection variable ``s_col`` and '
10431043
'cluster variable in ``cluster_cols``.')
10441044

10451045
def _set_cluster_vars(self):

doubleml/irm/apo.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _check_data(self, obj_dml_data):
389389

390390
return
391391

392-
def capo(self, basis, is_gate=False):
392+
def capo(self, basis, is_gate=False, **kwargs):
393393
"""
394394
Calculate conditional average potential outcomes (CAPO) for a given basis.
395395
@@ -398,10 +398,14 @@ def capo(self, basis, is_gate=False):
398398
basis : :class:`pandas.DataFrame`
399399
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
400400
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
401+
401402
is_gate : bool
402403
Indicates whether the basis is constructed for GATE/GAPOs (dummy-basis).
403404
Default is ``False``.
404405
406+
**kwargs: dict
407+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
408+
405409
Returns
406410
-------
407411
model : :class:`doubleML.DoubleMLBLP`
@@ -420,10 +424,10 @@ def capo(self, basis, is_gate=False):
420424
orth_signal = self.psi_elements['psi_b'].reshape(-1)
421425
# fit the best linear predictor
422426
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
423-
model.fit()
427+
model.fit(**kwargs)
424428
return model
425429

426-
def gapo(self, groups):
430+
def gapo(self, groups, **kwargs):
427431
"""
428432
Calculate group average potential outcomes (GAPO) for groups.
429433
@@ -434,6 +438,9 @@ def gapo(self, groups):
434438
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
435439
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
436440
441+
**kwargs: dict
442+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
443+
437444
Returns
438445
-------
439446
model : :class:`doubleML.DoubleMLBLP`
@@ -453,5 +460,5 @@ def gapo(self, groups):
453460
if any(groups.sum(0) <= 5):
454461
warnings.warn('At least one group effect is estimated with less than 6 observations.')
455462

456-
model = self.capo(groups, is_gate=True)
463+
model = self.capo(groups, is_gate=True, **kwargs)
457464
return model

doubleml/irm/irm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
431431

432432
return res
433433

434-
def cate(self, basis, is_gate=False):
434+
def cate(self, basis, is_gate=False, **kwargs):
435435
"""
436436
Calculate conditional average treatment effects (CATE) for a given basis.
437437
@@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False):
440440
basis : :class:`pandas.DataFrame`
441441
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
442442
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
443+
443444
is_gate : bool
444445
Indicates whether the basis is constructed for GATEs (dummy-basis).
445446
Default is ``False``.
446447
448+
**kwargs: dict
449+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
450+
447451
Returns
448452
-------
449453
model : :class:`doubleML.DoubleMLBLP`
@@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False):
462466
orth_signal = self.psi_elements['psi_b'].reshape(-1)
463467
# fit the best linear predictor
464468
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
465-
model.fit()
469+
model.fit(**kwargs)
466470
return model
467471

468-
def gate(self, groups):
472+
def gate(self, groups, **kwargs):
469473
"""
470474
Calculate group average treatment effects (GATE) for groups.
471475
@@ -476,6 +480,9 @@ def gate(self, groups):
476480
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
477481
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
478482
483+
**kwargs: dict
484+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
485+
479486
Returns
480487
-------
481488
model : :class:`doubleML.DoubleMLBLP`
@@ -495,7 +502,7 @@ def gate(self, groups):
495502
if any(groups.sum(0) <= 5):
496503
warnings.warn('At least one group effect is estimated with less than 6 observations.')
497504

498-
model = self.cate(groups, is_gate=True)
505+
model = self.cate(groups, is_gate=True, **kwargs)
499506
return model
500507

501508
def policy_tree(self, features, depth=2, **tree_params):

doubleml/irm/tests/test_apo.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,14 @@ def test_dml_apo_sensitivity(dml_apo_fixture):
200200
rtol=1e-9, atol=1e-4)
201201

202202

203+
@pytest.fixture(scope='module',
204+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
205+
def cov_type(request):
206+
return request.param
207+
208+
203209
@pytest.mark.ci
204-
def test_dml_apo_capo_gapo(treatment_level):
210+
def test_dml_apo_capo_gapo(treatment_level, cov_type):
205211
n = 20
206212
# collect data
207213
np.random.seed(42)
@@ -221,25 +227,28 @@ def test_dml_apo_capo_gapo(treatment_level):
221227
dml_obj.fit()
222228
# create a random basis
223229
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
224-
capo = dml_obj.capo(random_basis)
230+
capo = dml_obj.capo(random_basis, cov_type=cov_type)
225231
assert isinstance(capo, dml.utils.blp.DoubleMLBLP)
226232
assert isinstance(capo.confint(), pd.DataFrame)
233+
assert capo.blp_model.cov_type == cov_type
227234

228235
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= -1.0,
229236
obj_dml_data.data['X1'] > 0.2]),
230237
columns=['Group 1', 'Group 2'])
231238
msg = ('At least one group effect is estimated with less than 6 observations.')
232239
with pytest.warns(UserWarning, match=msg):
233-
gapo_1 = dml_obj.gapo(groups_1)
240+
gapo_1 = dml_obj.gapo(groups_1, cov_type=cov_type)
234241
assert isinstance(gapo_1, dml.utils.blp.DoubleMLBLP)
235242
assert isinstance(gapo_1.confint(), pd.DataFrame)
236243
assert all(gapo_1.confint().index == groups_1.columns.to_list())
244+
assert gapo_1.blp_model.cov_type == cov_type
237245

238246
np.random.seed(42)
239247
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n, p=[0.1, 0.9]))
240248
msg = ('At least one group effect is estimated with less than 6 observations.')
241249
with pytest.warns(UserWarning, match=msg):
242-
gapo_2 = dml_obj.gapo(groups_2)
250+
gapo_2 = dml_obj.gapo(groups_2, cov_type=cov_type)
243251
assert isinstance(gapo_2, dml.utils.blp.DoubleMLBLP)
244252
assert isinstance(gapo_2.confint(), pd.DataFrame)
245253
assert all(gapo_2.confint().index == ["Group_1", "Group_2"])
254+
assert gapo_2.blp_model.cov_type == cov_type

doubleml/irm/tests/test_irm.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ def test_dml_irm_sensitivity_rho0(dml_irm_fixture):
187187
rtol=1e-9, atol=1e-4)
188188

189189

190+
@pytest.fixture(scope='module',
191+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
192+
def cov_type(request):
193+
return request.param
194+
195+
190196
@pytest.mark.ci
191-
def test_dml_irm_cate_gate():
197+
def test_dml_irm_cate_gate(cov_type):
192198
n = 9
193199
# collect data
194200
np.random.seed(42)
@@ -207,28 +213,31 @@ def test_dml_irm_cate_gate():
207213
dml_irm_obj.fit()
208214
# create a random basis
209215
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
210-
cate = dml_irm_obj.cate(random_basis)
216+
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
211217
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
212218
assert isinstance(cate.confint(), pd.DataFrame)
219+
assert cate.blp_model.cov_type == cov_type
213220

214221
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= 0,
215222
obj_dml_data.data['X1'] > 0.2]),
216223
columns=['Group 1', 'Group 2'])
217224
msg = ('At least one group effect is estimated with less than 6 observations.')
218225
with pytest.warns(UserWarning, match=msg):
219-
gate_1 = dml_irm_obj.gate(groups_1)
226+
gate_1 = dml_irm_obj.gate(groups_1, cov_type=cov_type)
220227
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
221228
assert isinstance(gate_1.confint(), pd.DataFrame)
222229
assert all(gate_1.confint().index == groups_1.columns.to_list())
230+
assert gate_1.blp_model.cov_type == cov_type
223231

224232
np.random.seed(42)
225233
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
226234
msg = ('At least one group effect is estimated with less than 6 observations.')
227235
with pytest.warns(UserWarning, match=msg):
228-
gate_2 = dml_irm_obj.gate(groups_2)
236+
gate_2 = dml_irm_obj.gate(groups_2, cov_type=cov_type)
229237
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
230238
assert isinstance(gate_2.confint(), pd.DataFrame)
231239
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
240+
assert gate_2.blp_model.cov_type == cov_type
232241

233242

234243
@pytest.fixture(scope='module',

doubleml/plm/plr.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
341341

342342
return res
343343

344-
def cate(self, basis, is_gate=False):
344+
def cate(self, basis, is_gate=False, **kwargs):
345345
"""
346346
Calculate conditional average treatment effects (CATE) for a given basis.
347347
@@ -350,10 +350,14 @@ def cate(self, basis, is_gate=False):
350350
basis : :class:`pandas.DataFrame`
351351
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
352352
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
353+
353354
is_gate : bool
354355
Indicates whether the basis is constructed for GATEs (dummy-basis).
355356
Default is ``False``.
356357
358+
**kwargs: dict
359+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
360+
357361
Returns
358362
-------
359363
model : :class:`doubleML.DoubleMLBLP`
@@ -374,10 +378,10 @@ def cate(self, basis, is_gate=False):
374378
basis=D_basis,
375379
is_gate=is_gate,
376380
)
377-
model.fit()
381+
model.fit(**kwargs)
378382
return model
379383

380-
def gate(self, groups):
384+
def gate(self, groups, **kwargs):
381385
"""
382386
Calculate group average treatment effects (GATE) for groups.
383387
@@ -388,6 +392,9 @@ def gate(self, groups):
388392
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
389393
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
390394
395+
**kwargs: dict
396+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`, e.g. ``cov_type``.
397+
391398
Returns
392399
-------
393400
model : :class:`doubleML.DoubleMLBLP`
@@ -407,7 +414,7 @@ def gate(self, groups):
407414
if any(groups.sum(0) <= 5):
408415
warnings.warn('At least one group effect is estimated with less than 6 observations.')
409416

410-
model = self.cate(groups, is_gate=True)
417+
model = self.cate(groups, is_gate=True, **kwargs)
411418
return model
412419

413420
def _partial_out(self):

0 commit comments

Comments
 (0)