Skip to content

Commit aad49b3

Browse files
committed
Fix additve
1 parent 26470f7 commit aad49b3

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: mcboost
22
Type: Package
33
Title: Multi-Calibration Boosting
4-
Version: 0.3.0.9000
4+
Version: 0.3.1
55
Authors@R:
66
c(person(given = "Florian",
77
family = "Pfisterer",

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# mcboost (development version)
22

3+
# mcboost 0.3.1
4+
5+
* Fixed a bug for additive weight updates, were updates went
6+
in the wrong direction.
7+
38
# mcboost 0.3.0
49

510
* First CRAN-ready version of the package.

R/MCBoost.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ MCBoost = R6::R6Class("MCBoost",
390390
new_preds = update_weights * pmax(orig_preds, 1e-4)
391391
} else {
392392
update_weights = (self$eta * deltas)
393-
new_preds = orig_preds + update_weights
393+
new_preds = orig_preds - update_weights
394394
}
395395
if (audit) {
396396
self$auditor_effects = c(self$auditor_effects, list(abs(deltas)))

tests/testthat/test_mcboost.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,30 @@ test_that("init predictor wrapper works", {
310310
expect_warning(mc$predict_probs(d), "multicalibrate was not run!")
311311

312312
})
313+
314+
test_that("mcboost on training data sanity checks", {
315+
tsk = tsk("sonar")
316+
d = tsk$data(cols = tsk$feature_names)
317+
l = tsk$data(cols = tsk$target_names)[[1]]
318+
mc = MCBoost$new(auditor_fitter = "TreeAuditorFitter")
319+
mc$multicalibrate(d[1:200,], l[1:200])
320+
df = do.call("rbind", mc$iter_corr)
321+
expect_true(all(diff(df) <= 0))
322+
323+
mc = MCBoost$new(auditor_fitter = "TreeAuditorFitter", multiplicative = FALSE)
324+
mc$multicalibrate(d[1:200,], l[1:200])
325+
df = do.call("rbind", mc$iter_corr)
326+
expect_true(all(diff(df) <= 0))
327+
328+
mc = MCBoost$new(auditor_fitter = "RidgeAuditorFitter", partition = TRUE, num_buckets = 5)
329+
mc$multicalibrate(d[1:200,], l[1:200])
330+
df = do.call("rbind", mc$iter_corr)
331+
expect_true(all(diff(df) <= 0))
332+
333+
mc = MCBoost$new(auditor_fitter = "RidgeAuditorFitter", partition = TRUE, num_buckets = 5, multiplicative = FALSE)
334+
mc$multicalibrate(d[1:200,], l[1:200])
335+
df = do.call("rbind", mc$iter_corr)
336+
expect_true(all(diff(df) <= 0))
337+
})
338+
339+

0 commit comments

Comments
 (0)