Skip to content

Commit 96f33ae

Browse files
committed
Enhance learner evaluation checks and handle NaN targets in DoubleML class
1 parent 33a86d0 commit 96f33ae

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

doubleml/double_ml.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,12 +1274,19 @@ def evaluate_learners(self, learners=None, metric=_rmse):
12741274
for learner in learners:
12751275
for rep in range(self.n_rep):
12761276
for coef_idx in range(self._dml_data.n_coefs):
1277-
res = metric(
1278-
y_pred=self.predictions[learner][:, rep, coef_idx].reshape(1, -1),
1279-
y_true=self.nuisance_targets[learner][:, rep, coef_idx].reshape(1, -1),
1280-
)
1281-
if not np.isfinite(res):
1282-
raise ValueError(f"Evaluation from learner {str(learner)} is not finite.")
1277+
targets = self.nuisance_targets[learner][:, rep, coef_idx].reshape(1, -1)
1278+
1279+
if np.all(np.isnan(targets)):
1280+
res = np.nan
1281+
else:
1282+
predictions = self.predictions[learner][:, rep, coef_idx].reshape(1, -1)
1283+
res = metric(
1284+
y_pred=predictions,
1285+
y_true=targets,
1286+
)
1287+
if not np.isfinite(res):
1288+
raise ValueError(f"Evaluation from learner {str(learner)} is not finite.")
1289+
12831290
dist[learner][rep, coef_idx] = res
12841291
return dist
12851292
else:

doubleml/utils/_check_return_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def check_basic_predictions_and_targets(dml_obj, n_obs, n_treat, n_rep):
113113
assert isinstance(dml_obj.nuisance_loss[key], np.ndarray)
114114
assert dml_obj.nuisance_loss[key].shape == (n_rep, n_treat)
115115

116+
learner_eval = dml_obj.evaluate_learners()
117+
assert isinstance(learner_eval, dict)
118+
for key in expected_keys:
119+
assert key in learner_eval
120+
assert isinstance(learner_eval[key], np.ndarray)
121+
assert learner_eval[key].shape == (n_rep, n_treat)
116122
return
117123

118124

0 commit comments

Comments
 (0)