Skip to content

Commit 8f7125f

Browse files
committed
Exceptions fixed
1 parent 99e78bf commit 8f7125f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

doubleml/plm/tests/test_lplr_exceptions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_lplr_exception_resampling():
7070

7171
@pytest.mark.ci
7272
def test_lplr_exception_get_params():
73-
msg = "Invalid nuisance learner ml_x. Valid nuisance learner ml_m or ml_t or ml_M or ml_a."
73+
msg = r"Invalid nuisance learner ml_x. Valid nuisance learner ml_m or ml_a or ml_t or ml_M.*"
7474
with pytest.raises(ValueError, match=msg):
7575
dml_lplr.get_params("ml_x")
7676

@@ -147,7 +147,7 @@ def test_lplr_exception_confint():
147147
@pytest.mark.ci
148148
def test_lplr_exception_set_ml_nuisance_params():
149149
# invalid learner name
150-
msg = "Invalid nuisance learner g. Valid nuisance learner ml_m or ml_t or ml_M or ml_a."
150+
msg = "Invalid nuisance learner g. Valid nuisance learner ml_m or ml_a or ml_t or ml_M.*"
151151
with pytest.raises(ValueError, match=msg):
152152
dml_lplr.set_ml_nuisance_params("g", "d", {"alpha": 0.1})
153153
# invalid treatment variable
@@ -246,13 +246,13 @@ def test_lplr_exception_and_warning_learner():
246246
with pytest.raises(TypeError, match=msg):
247247
_ = DoubleMLLPLR(dml_data, Lasso(), ml_t, ml_m)
248248
msg = (
249-
r"The ml_m learner RandomForestRegressor\(\) was identified as regressor but at least one treatment "
249+
r"The ml_m learner RandomForestRegressor\(.*\) was identified as regressor but at least one treatment "
250250
r"variable is binary with values 0 and 1."
251251
)
252252
with pytest.warns(match=msg):
253253
_ = DoubleMLLPLR(dml_data_binary, ml_M, ml_t, ml_m)
254254
msg = (
255-
r"The ml_a learner RandomForestRegressor\(\) was identified as regressor but at least one treatment "
255+
r"The ml_a learner RandomForestRegressor\(.*\) was identified as regressor but at least one treatment "
256256
r"variable is binary with values 0 and 1."
257257
)
258258
with pytest.warns(match=msg):
@@ -314,7 +314,8 @@ def test_double_ml_exception_evaluate_learner():
314314
dml_lplr_obj.evaluate_learners(metric="mse")
315315

316316
msg = (
317-
r"The learners have to be a subset of \['ml_m', 'ml_t', 'ml_M', 'ml_a'\]\. " r"Learners \['ml_mu', 'ml_p'\] provided."
317+
r"The learners have to be a subset of \['ml_m', 'ml_a', 'ml_t', 'ml_M'.*\]\. "
318+
r"Learners \['ml_mu', 'ml_p'\] provided."
318319
)
319320
with pytest.raises(ValueError, match=msg):
320321
dml_lplr_obj.evaluate_learners(learners=["ml_mu", "ml_p"])

0 commit comments

Comments
 (0)