Skip to content

Commit 9fbc227

Browse files
committed
update survshap for new version of treeshap
1 parent 19365d1 commit 9fbc227

File tree

1 file changed

+13
-32
lines changed

1 file changed

+13
-32
lines changed

R/surv_shap.R

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ surv_shap <- function(explainer,
104104
res$result <- switch(calculation_method,
105105
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, N, ...),
106106
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, N, ...),
107-
"treeshap" = use_treeshap(explainer, new_observation, ...),
107+
"treeshap" = use_treeshap(explainer, new_observation, output_type, ...),
108108
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
109109
# quality-check here
110110
stopifnot(
@@ -282,7 +282,7 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
282282
return(shap_values)
283283
}
284284

285-
use_treeshap <- function(explainer, new_observation, ...){
285+
use_treeshap <- function(explainer, new_observation, output_type, ...){
286286

287287
stopifnot(
288288
"new_observation must be a data.frame" = inherits(
@@ -292,45 +292,26 @@ use_treeshap <- function(explainer, new_observation, ...){
292292
# init unify_append_args
293293
unify_append_args <- list()
294294

295-
if (inherits(explainer$model, "ranger")) {
296-
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
297-
# that are supported by treeshap
298-
UNIFY_FUN <- treeshap::ranger_surv.unify
299-
unify_append_args <- list(type = "survival", times = explainer$times)
300-
} else {
295+
if (!inherits(explainer$model, "ranger")) {
301296
stop("Support for `treeshap` is currently only implemented for `ranger`.")
302297
}
303298

304-
unify_args <- list(
305-
rf_model = explainer$model,
306-
data = explainer$data
307-
)
308-
309-
if (length(unify_append_args) > 0) {
310-
unify_args <- c(unify_args, unify_append_args)
311-
}
312-
313-
tmp_unified <- do.call(UNIFY_FUN, unify_args)
299+
tmp_unified <- treeshap::unify(explainer$model,
300+
explainer$data,
301+
type = output_type,
302+
times = explainer$times)
314303

315304
shap_values <- sapply(
316305
X = as.character(seq_len(nrow(new_observation))),
317306
FUN = function(i) {
307+
# ensure that matrix has expected dimensions; as.integer is
308+
# necessary for valid comparison with "identical"
309+
new_obs_mat <- new_observation[as.integer(i), ]
310+
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
311+
318312
tmp_res <- do.call(
319313
rbind,
320-
lapply(
321-
tmp_unified,
322-
function(m) {
323-
new_obs_mat <- new_observation[as.integer(i), ]
324-
# ensure that matrix has expected dimensions; as.integer is
325-
# necessary for valid comparison with "identical"
326-
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
327-
treeshap::treeshap(
328-
unified_model = m,
329-
x = new_obs_mat,
330-
...
331-
)$shaps
332-
}
333-
)
314+
lapply(treeshap::treeshap(tmp_unified, x = new_obs_mat, ...), function(x) x$shaps)
334315
)
335316

336317
tmp_shap_values <- data.frame(tmp_res)

0 commit comments

Comments
 (0)