@@ -104,7 +104,7 @@ surv_shap <- function(explainer,
104104 res $ result <- switch (calculation_method ,
105105 " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , N , ... ),
106106 " kernelshap" = use_kernelshap(explainer , new_observation , output_type , N , ... ),
107- " treeshap" = use_treeshap(explainer , new_observation , ... ),
107+ " treeshap" = use_treeshap(explainer , new_observation , output_type , ... ),
108108 stop(" Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented" ))
109109 # quality-check here
110110 stopifnot(
@@ -282,7 +282,7 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
282282 return (shap_values )
283283}
284284
285- use_treeshap <- function (explainer , new_observation , ... ){
285+ use_treeshap <- function (explainer , new_observation , output_type , ... ){
286286
287287 stopifnot(
288288 " new_observation must be a data.frame" = inherits(
@@ -292,45 +292,26 @@ use_treeshap <- function(explainer, new_observation, ...){
292292 # init unify_append_args
293293 unify_append_args <- list ()
294294
295- if (inherits(explainer $ model , " ranger" )) {
296- # UNIFY_FUN to prepare code for easy Integration of other ml algorithms
297- # that are supported by treeshap
298- UNIFY_FUN <- treeshap :: ranger_surv.unify
299- unify_append_args <- list (type = " survival" , times = explainer $ times )
300- } else {
295+ if (! inherits(explainer $ model , " ranger" )) {
301296 stop(" Support for `treeshap` is currently only implemented for `ranger`." )
302297 }
303298
304- unify_args <- list (
305- rf_model = explainer $ model ,
306- data = explainer $ data
307- )
308-
309- if (length(unify_append_args ) > 0 ) {
310- unify_args <- c(unify_args , unify_append_args )
311- }
312-
313- tmp_unified <- do.call(UNIFY_FUN , unify_args )
299+ tmp_unified <- treeshap :: unify(explainer $ model ,
300+ explainer $ data ,
301+ type = output_type ,
302+ times = explainer $ times )
314303
315304 shap_values <- sapply(
316305 X = as.character(seq_len(nrow(new_observation ))),
317306 FUN = function (i ) {
307+ # ensure that matrix has expected dimensions; as.integer is
308+ # necessary for valid comparison with "identical"
309+ new_obs_mat <- new_observation [as.integer(i ), ]
310+ stopifnot(identical(dim(new_obs_mat ), as.integer(c(1L , ncol(new_observation )))))
311+
318312 tmp_res <- do.call(
319313 rbind ,
320- lapply(
321- tmp_unified ,
322- function (m ) {
323- new_obs_mat <- new_observation [as.integer(i ), ]
324- # ensure that matrix has expected dimensions; as.integer is
325- # necessary for valid comparison with "identical"
326- stopifnot(identical(dim(new_obs_mat ), as.integer(c(1L , ncol(new_observation )))))
327- treeshap :: treeshap(
328- unified_model = m ,
329- x = new_obs_mat ,
330- ...
331- )$ shaps
332- }
333- )
314+ lapply(treeshap :: treeshap(tmp_unified , x = new_obs_mat , ... ), function (x ) x $ shaps )
334315 )
335316
336317 tmp_shap_values <- data.frame (tmp_res )
0 commit comments