Skip to content

Commit 0e7e7c3

Browse files
add docstring for class DoubleMLAPOS
1 parent 477fdb5 commit 0e7e7c3

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

doubleml/irm/apos.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,69 @@
2121

2222

2323
class DoubleMLAPOS(SampleSplittingMixin):
24-
"""Double machine learning for interactive regression models with multiple discrete treatments."""
24+
"""Double machine learning for interactive regression models with multiple discrete
25+
treatments.
26+
27+
Parameters
28+
----------
29+
obj_dml_data : :class:`DoubleMLData` object
30+
The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model.
31+
32+
ml_g : estimator implementing ``fit()`` and ``predict()``
33+
A machine learner implementing ``fit()`` and ``predict()`` methods (e.g.
34+
:py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(D, X) = E[Y | X, D]`.
35+
For a binary outcome variable :math:`Y` (with values 0 and 1), a classifier implementing ``fit()`` and
36+
``predict_proba()`` can also be specified. If :py:func:`sklearn.base.is_classifier` returns ``True``,
37+
``predict_proba()`` is used otherwise ``predict()``.
38+
39+
ml_m : classifier implementing ``fit()`` and ``predict_proba()``
40+
A machine learner implementing ``fit()`` and ``predict_proba()`` methods (e.g.
41+
:py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D | X]`.
42+
43+
treatment_levels : iterable of int or float
44+
The treatment levels for which average potential outcomes are evaluated. Each element must be present in the
45+
treatment variable ``d`` of ``obj_dml_data``.
46+
47+
n_folds : int
48+
Number of folds.
49+
Default is ``5``.
50+
51+
n_rep : int
52+
Number of repetitions for the sample splitting.
53+
Default is ``1``.
54+
55+
score : str
56+
A str (``'APO'``) specifying the score function.
57+
Default is ``'APO'``.
58+
59+
weights : array, dict or None
60+
A numpy array of weights for each individual observation. If ``None``, then the ``'APO'`` score
61+
is applied (corresponds to weights equal to 1).
62+
An array has to be of shape ``(n,)``, where ``n`` is the number of observations.
63+
A dictionary can be used to specify weights which depend on the treatment variable.
64+
In this case, the dictionary has to contain two keys ``weights`` and ``weights_bar``, where the values
65+
have to be arrays of shape ``(n,)`` and ``(n, n_rep)``.
66+
Default is ``None``.
67+
68+
normalize_ipw : bool
69+
Indicates whether the inverse probability weights are normalized.
70+
Default is ``False``.
71+
72+
trimming_rule : str, optional, deprecated
73+
(DEPRECATED) A str (``'truncate'`` is the only choice) specifying the trimming approach.
74+
Use ``ps_processor_config`` instead. Will be removed in a future version.
75+
76+
trimming_threshold : float, optional, deprecated
77+
(DEPRECATED) The threshold used for trimming.
78+
Use ``ps_processor_config`` instead. Will be removed in a future version.
79+
80+
ps_processor_config : PSProcessorConfig, optional
81+
Configuration for propensity score processing (clipping, calibration, etc.).
82+
83+
draw_sample_splitting : bool
84+
Indicates whether the sample splitting should be drawn during initialization of the object.
85+
Default is ``True``.
86+
"""
2587

2688
def __init__(
2789
self,

0 commit comments

Comments
 (0)