@@ -57,9 +57,15 @@ MCBoost = R6::R6Class("MCBoost",
5757 # ' Currently only supports "simple", even split along probabilities.
5858 # ' Only relevant for `num_buckets` > 1.
5959 bucket_strategy = NULL ,
60+
6061 # ' @field rebucket [`logical`] \cr
6162 # ' Should buckets be re-calculated at each iteration?
6263 rebucket = NULL ,
64+
65+ # ' @field eval_fulldata [`logical`] \cr
66+ # ' Should auditor be evaluated on the full data?
67+ eval_fulldata = NULL ,
68+
6369 # ' @field partition [`logical`] \cr
6470 # ' True/False flag for whether to split up predictions by their "partition"
6571 # ' (e.g., predictions less than 0.5 and predictions greater than 0.5).
@@ -118,6 +124,12 @@ MCBoost = R6::R6Class("MCBoost",
118124 # ' Only taken into account for `num_buckets` > 1.
119125 # ' @param rebucket [`logical`] \cr
120126 # ' Should buckets be re-done at each iteration? Default `FALSE`.
127+ # ' @param eval_fulldata [`logical`] \cr
128+ # ' Should the auditor be evaluated on the full data or on the respective bucket for determining
129+ # ' the stopping criterion? Default `FALSE`, auditor is only evaluated on the bucket.
130+ # ' This setting keeps the implementation closer to the Algorithm proposed in the corresponding
131+ # ' multi-accuracy paper (Kim et al., 2019) where auditor effects are computed across the full
132+ # ' sample (i.e. eval_fulldata = TRUE).
121133 # ' @param multiplicative [`logical`] \cr
122134 # ' Specifies the strategy for updating the weights (multiplicative weight vs additive).
123135 # ' Defaults to `TRUE` (multi-accuracy boosting). Set to `FALSE` for multi-calibration.
@@ -141,6 +153,7 @@ MCBoost = R6::R6Class("MCBoost",
141153 # ' "split" splits the data into `max_iter` parts and validates on each sample in each iteration.\cr
142154 # ' "bootstrap" uses a new bootstrap sample in each iteration.\cr
143155 # ' "none" uses the same dataset in each iteration.
156+
144157 initialize = function (
145158 max_iter = 5 ,
146159 alpha = 1e-4 ,
@@ -149,6 +162,7 @@ MCBoost = R6::R6Class("MCBoost",
149162 num_buckets = 2 ,
150163 bucket_strategy = " simple" ,
151164 rebucket = FALSE ,
165+ eval_fulldata = FALSE ,
152166 multiplicative = TRUE ,
153167 auditor_fitter = NULL ,
154168 subpops = NULL ,
@@ -162,6 +176,7 @@ MCBoost = R6::R6Class("MCBoost",
162176 self $ num_buckets = assert_int(num_buckets )
163177 self $ bucket_strategy = assert_choice(bucket_strategy , choices = c(" simple" ))
164178 self $ rebucket = assert_flag(rebucket )
179+ self $ eval_fulldata = assert_flag(eval_fulldata )
165180 self $ partition = assert_flag(partition )
166181 self $ multiplicative = assert_flag(multiplicative )
167182 self $ iter_sampling = assert_choice(iter_sampling , choices = c(" none" , " bootstrap" , " split" ))
@@ -283,6 +298,13 @@ MCBoost = R6::R6Class("MCBoost",
283298 models [[j ]] = out [[2 ]]
284299 }
285300
301+ if (self $ eval_fulldata ) {
302+ corrs = map_dbl(models , function (m ) {
303+ if (is.null(m )) return (0 )
304+ mean(m $ predict(data [idx ,]) * resid [idx ])
305+ })
306+ }
307+
286308 self $ iter_corr = c(self $ iter_corr , list (corrs ))
287309 if (abs(max(corrs )) < self $ alpha ) {
288310 break
0 commit comments