Skip to content

Commit 4479460

Browse files
committed
prepare new cran version
1 parent aad49b3 commit 4479460

File tree

5 files changed

+38
-4
lines changed

5 files changed

+38
-4
lines changed

DESCRIPTION

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ Depends:
3939
Imports:
4040
backports,
4141
checkmate (>= 2.0.0),
42-
lifecycle,
4342
data.table (>= 1.13.6),
4443
mlr3 (>= 0.10),
4544
mlr3misc (>= 0.8.0),

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Fixed a bug for additive weight updates, were updates went
66
in the wrong direction.
7+
* Added new parameter `eval_fulldata` that allows to compute
8+
auditor effect across the full sample (as opposed to the bucket).
79

810
# mcboost 0.3.0
911

R/MCBoost.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,15 @@ MCBoost = R6::R6Class("MCBoost",
5757
#' Currently only supports "simple", even split along probabilities.
5858
#' Only relevant for `num_buckets` > 1.
5959
bucket_strategy = NULL,
60+
6061
#' @field rebucket [`logical`] \cr
6162
#' Should buckets be re-calculated at each iteration?
6263
rebucket = NULL,
64+
65+
#' @field eval_fulldata [`logical`] \cr
66+
#' Should auditor be evaluated on the full data?
67+
eval_fulldata = NULL,
68+
6369
#' @field partition [`logical`] \cr
6470
#' True/False flag for whether to split up predictions by their "partition"
6571
#' (e.g., predictions less than 0.5 and predictions greater than 0.5).
@@ -118,6 +124,12 @@ MCBoost = R6::R6Class("MCBoost",
118124
#' Only taken into account for `num_buckets` > 1.
119125
#' @param rebucket [`logical`] \cr
120126
#' Should buckets be re-done at each iteration? Default `FALSE`.
127+
#' @param eval_fulldata [`logical`] \cr
128+
#' Should the auditor be evaluated on the full data or on the respective bucket for determining
129+
#' the stopping criterion? Default `FALSE`, auditor is only evaluated on the bucket.
130+
#' This setting keeps the implementation closer to the Algorithm proposed in the corresponding
131+
#' multi-accuracy paper (Kim et al., 2019) where auditor effects are computed across the full
132+
#' sample (i.e. eval_fulldata = TRUE).
121133
#' @param multiplicative [`logical`] \cr
122134
#' Specifies the strategy for updating the weights (multiplicative weight vs additive).
123135
#' Defaults to `TRUE` (multi-accuracy boosting). Set to `FALSE` for multi-calibration.
@@ -141,6 +153,7 @@ MCBoost = R6::R6Class("MCBoost",
141153
#' "split" splits the data into `max_iter` parts and validates on each sample in each iteration.\cr
142154
#' "bootstrap" uses a new bootstrap sample in each iteration.\cr
143155
#' "none" uses the same dataset in each iteration.
156+
144157
initialize = function(
145158
max_iter=5,
146159
alpha=1e-4,
@@ -149,6 +162,7 @@ MCBoost = R6::R6Class("MCBoost",
149162
num_buckets=2,
150163
bucket_strategy="simple",
151164
rebucket=FALSE,
165+
eval_fulldata=FALSE,
152166
multiplicative=TRUE,
153167
auditor_fitter=NULL,
154168
subpops=NULL,
@@ -162,6 +176,7 @@ MCBoost = R6::R6Class("MCBoost",
162176
self$num_buckets = assert_int(num_buckets)
163177
self$bucket_strategy = assert_choice(bucket_strategy, choices = c("simple"))
164178
self$rebucket = assert_flag(rebucket)
179+
self$eval_fulldata = assert_flag(eval_fulldata)
165180
self$partition = assert_flag(partition)
166181
self$multiplicative = assert_flag(multiplicative)
167182
self$iter_sampling = assert_choice(iter_sampling, choices = c("none", "bootstrap", "split"))
@@ -283,6 +298,13 @@ MCBoost = R6::R6Class("MCBoost",
283298
models[[j]] = out[[2]]
284299
}
285300

301+
if (self$eval_fulldata) {
302+
corrs = map_dbl(models, function(m) {
303+
if (is.null(m)) return(0)
304+
mean(m$predict(data[idx,]) * resid[idx])
305+
})
306+
}
307+
286308
self$iter_corr = c(self$iter_corr, list(corrs))
287309
if (abs(max(corrs)) < self$alpha) {
288310
break

man/MCBoost.Rd

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mcboost.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ test_that("MCBoost various settings", {
165165
# Check a list of settings
166166
mcs = list(
167167
MCBoost$new(auditor_fitter = NULL),
168-
MCBoost$new(alpha = 0.05)
168+
MCBoost$new(alpha = 0.05),
169+
MCBoost$new(eval_fulldata = TRUE),
170+
MCBoost$new(eval_fulldata = TRUE, multiplicative = FALSE)
169171
)
170172
for (mc in mcs) {
171173
mc$multicalibrate(data, labels)
@@ -335,5 +337,3 @@ test_that("mcboost on training data sanity checks", {
335337
df = do.call("rbind", mc$iter_corr)
336338
expect_true(all(diff(df) <= 0))
337339
})
338-
339-

0 commit comments

Comments
 (0)