Skip to content

Commit 2c20ae5

Browse files
authored
Merge pull request #12 from pythonhealthdatascience/dev
Dev
2 parents 7c2c18b + 85dcb6a commit 2c20ae5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1367
-1798
lines changed

NAMESPACE

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,18 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
export(DistributionRegistry)
34
export(add_patient_generator)
4-
export(check_all_param_names)
5-
export(check_all_positive)
6-
export(check_nonneg_integer)
7-
export(check_param_names)
8-
export(check_param_values)
9-
export(check_positive_integer)
10-
export(check_prob_vector)
5+
export(check_log_file_path)
116
export(check_run_number)
12-
export(create_asu_arrivals)
13-
export(create_asu_los)
14-
export(create_asu_routing)
157
export(create_asu_trajectory)
16-
export(create_parameters)
17-
export(create_rehab_arrivals)
18-
export(create_rehab_los)
19-
export(create_rehab_routing)
208
export(create_rehab_trajectory)
219
export(filter_warmup)
2210
export(get_occupancy_stats)
2311
export(model)
12+
export(parameters)
2413
export(runner)
25-
export(sample_routing)
26-
export(transform_to_lnorm)
2714
export(valid_inputs)
15+
importFrom(R6,R6Class)
2816
importFrom(dplyr,bind_rows)
2917
importFrom(dplyr,filter)
3018
importFrom(dplyr,group_by)
@@ -35,6 +23,7 @@ importFrom(future,multisession)
3523
importFrom(future,plan)
3624
importFrom(future,sequential)
3725
importFrom(future.apply,future_lapply)
26+
importFrom(jsonlite,fromJSON)
3827
importFrom(rlang,.data)
3928
importFrom(simmer,add_generator)
4029
importFrom(simmer,add_resource)

R/add_patient_generator.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ add_patient_generator <- function(env, trajectory, unit, patient_type, param) {
2929
name_prefix = paste0(unit, "_", patient_type),
3030
trajectory = trajectory,
3131
distribution = function() {
32-
rexp(1L, 1L / param[[paste0(unit, "_arrivals")]][[patient_type]])
32+
param[["dist"]][["arrival"]][[unit]][[patient_type]]()
3333
}
3434
)
3535
}

R/create_asu_trajectory.R

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,44 +32,32 @@ create_asu_trajectory <- function(env, patient_type, param) {
3232

3333
# Sample destination after ASU (as destination influences length of stay)
3434
set_attribute("post_asu_destination", function() {
35-
sample_routing(prob_list = param[["asu_routing"]][[patient_type]])
35+
param[["dist"]][["routing"]][["asu"]][[patient_type]]()
3636
}) |>
3737

3838
log_(function() {
39-
# Retrieve attribute, and use to get post-ASU destination as a string
40-
dest_index <- get_attribute(env, "post_asu_destination")
41-
dest_names <- names(param[["asu_routing"]][[patient_type]])
42-
dest <- dest_names[dest_index]
43-
# Create log message
44-
paste0("\U0001F3AF Planned ASU -> ", dest_index, " (", dest, ")")
39+
dest_num <- get_attribute(env, "post_asu_destination")
40+
dest <- param[["map_num2val"]][as.character(dest_num)]
41+
paste0("\U0001F3AF Planned ASU -> ", dest_num, " (", dest, ")")
4542
}, level = 1L) |>
4643

44+
# Sample ASU LOS. For stroke patients, LOS distribution is based on
45+
# the planned destination after the ASU.
4746
set_attribute("asu_los", function() {
48-
# Retrieve attribute, and use to get post-ASU destination as a string
49-
dest_index <- get_attribute(env, "post_asu_destination")
50-
dest_names <- names(param[["asu_routing"]][[patient_type]])
51-
dest <- dest_names[dest_index]
52-
53-
# Determine which LOS distribution to use
47+
dest_num <- get_attribute(env, "post_asu_destination")
48+
dest <- param[["map_num2val"]][as.character(dest_num)]
5449
if (patient_type == "stroke") {
55-
los_params <- switch(
50+
switch(
5651
dest,
57-
esd = param[["asu_los_lnorm"]][["stroke_esd"]],
58-
rehab = param[["asu_los_lnorm"]][["stroke_no_esd"]],
59-
other = param[["asu_los_lnorm"]][["stroke_mortality"]],
52+
esd = param[["dist"]][["los"]][["asu"]][["stroke_esd"]](),
53+
rehab = param[["dist"]][["los"]][["asu"]][["stroke_noesd"]](),
54+
other = param[["dist"]][["los"]][["asu"]][["stroke_mortality"]](),
6055
stop("Stroke post-asu destination '", dest, "' invalid",
6156
call. = FALSE)
6257
)
6358
} else {
64-
los_params <- param[["asu_los_lnorm"]][[patient_type]]
59+
param[["dist"]][["los"]][["asu"]][[patient_type]]()
6560
}
66-
67-
# Sample LOS from lognormal
68-
rlnorm(
69-
n = 1L,
70-
meanlog = los_params[["meanlog"]],
71-
sdlog = los_params[["sdlog"]]
72-
)
7361
}) |>
7462

7563
log_(function() {
@@ -86,11 +74,8 @@ create_asu_trajectory <- function(env, patient_type, param) {
8674
# If that patient's destination is rehab, then start on that trajectory
8775
branch(
8876
option = function() {
89-
# Retrieve attribute, and use to get post-ASU destination as a string
90-
dest_index <- get_attribute(env, "post_asu_destination")
91-
dest_names <- names(param[["asu_routing"]][[patient_type]])
92-
dest <- dest_names[dest_index]
93-
# Return 1 for rehab and 0 otherwise
77+
dest_num <- get_attribute(env, "post_asu_destination")
78+
dest <- param[["map_num2val"]][as.character(dest_num)]
9479
if (dest == "rehab") 1L else 0L
9580
},
9681
continue = FALSE, # Do not continue main trajectory after branch

R/create_rehab_trajectory.R

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,43 +32,31 @@ create_rehab_trajectory <- function(env, patient_type, param) {
3232

3333
# Sample destination after rehab (as destination influences length of stay)
3434
set_attribute("post_rehab_destination", function() {
35-
sample_routing(prob_list = param[["rehab_routing"]][[patient_type]])
35+
param[["dist"]][["routing"]][["rehab"]][[patient_type]]()
3636
}) |>
3737

3838
log_(function() {
39-
# Retrieve attribute, and use to get post-rehab destination as a string
40-
dest_index <- get_attribute(env, "post_rehab_destination")
41-
dest_names <- names(param[["rehab_routing"]][[patient_type]])
42-
dest <- dest_names[dest_index]
43-
# Create log message
44-
paste0("\U0001F3AF Planned rehab -> ", dest_index, " (", dest, ")")
39+
dest_num <- get_attribute(env, "post_rehab_destination")
40+
dest <- param[["map_num2val"]][as.character(dest_num)]
41+
paste0("\U0001F3AF Planned rehab -> ", dest_num, " (", dest, ")")
4542
}, level = 1L) |>
4643

44+
# Sample rehab LOS. For stroke patients, LOS distribution is based on
45+
# the planned destination after rehab.
4746
set_attribute("rehab_los", function() {
48-
# Retrieve attribute, and use to get post-rehab destination as a string
49-
dest_index <- get_attribute(env, "post_rehab_destination")
50-
dest_names <- names(param[["rehab_routing"]][[patient_type]])
51-
dest <- dest_names[dest_index]
52-
53-
# Determine which LOS distribution to use
47+
dest_num <- get_attribute(env, "post_rehab_destination")
48+
dest <- param[["map_num2val"]][as.character(dest_num)]
5449
if (patient_type == "stroke") {
55-
los_params <- switch(
50+
switch(
5651
dest,
57-
esd = param[["rehab_los_lnorm"]][["stroke_esd"]],
58-
other = param[["rehab_los_lnorm"]][["stroke_no_esd"]],
52+
esd = param[["dist"]][["los"]][["rehab"]][["stroke_esd"]](),
53+
other = param[["dist"]][["los"]][["rehab"]][["stroke_noesd"]](),
5954
stop("Stroke post-rehab destination '", dest, "' invalid",
6055
call. = FALSE)
6156
)
6257
} else {
63-
los_params <- param[["rehab_los_lnorm"]][[patient_type]]
58+
param[["dist"]][["los"]][["rehab"]][[patient_type]]()
6459
}
65-
66-
# Sample LOS from lognormal
67-
rlnorm(
68-
n = 1L,
69-
meanlog = los_params[["meanlog"]],
70-
sdlog = los_params[["sdlog"]]
71-
)
7260
}) |>
7361

7462
log_(function() {

R/distribution_registry.R

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#' Registry for parametrised probability distributions
#'
#' @description
#' The DistributionRegistry manages and generates parameterised samplers for a
#' variety of probability distributions. Common distributions are included by
#' default, and more can be added.
#'
#' Once defined, you can create sampler objects for each distribution -
#' individually (`create`) or in batches (`create_batch`) - and then easily
#' draw random samples from these objects.
#'
#' To add more built-in distributions, edit `initialize()`. To add custom ones
#' at any time, use `register()`.
#'
#' @docType class
#' @importFrom R6 R6Class
#' @export
DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_name_linter

  #' @field registry Named list of registered distribution generator functions.
  registry = list(),

  #' @description
  #' Pre-registers a set of common base R distribution generators. Each
  #' generator takes the distribution parameters and returns a sampler
  #' `function(size = 1L)` that draws `size` values.
  initialize = function() {
    self$register("exponential", function(mean) {
      # Parameterised by the mean, so the rate is its reciprocal.
      function(size = 1L) rexp(size, rate = 1L / mean)
    })
    self$register("uniform", function(min, max) {
      function(size = 1L) runif(size, min = min, max = max)
    })
    self$register("discrete", function(values, prob) {
      # Inputs may arrive as lists (e.g. from parsed JSON), so flatten them.
      values <- unlist(values)
      prob <- unlist(prob)
      # Validation (as not using a pre-made distribution function)
      stopifnot(length(values) == length(prob))
      stopifnot(all(prob >= 0))
      # The deviation is rounded to 2 d.p. before comparison so that floating
      # point noise around the 0.01 tolerance does not trigger the error.
      if (round(abs(sum(prob) - 1), 2) > 0.01) {
        stop(sprintf("'prob' must sum to 1 ± 0.01. Sum: %s", sum(prob)))
      }
      # Sample positions rather than calling sample(values, ...): sample()
      # treats a single number x >= 1 as the range 1:x, which would break a
      # one-value numeric distribution.
      function(size = 1L) {
        idx <- sample.int(
          length(values), size = size, replace = TRUE, prob = prob
        )
        values[idx]
      }
    })
    self$register("normal", function(mean, sd) {
      function(size = 1L) rnorm(size, mean = mean, sd = sd)
    })
    self$register("lognormal", function(meanlog, sdlog) {
      function(size = 1L) rlnorm(size, meanlog = meanlog, sdlog = sdlog)
    })
    self$register("poisson", function(lambda) {
      function(size = 1L) rpois(size, lambda = lambda)
    })
    self$register("binomial", function(size_param, prob) {
      # `size_param` is the number of trials; the sampler's own `size` is the
      # number of draws.
      function(size = 1L) rbinom(size, size = size_param, prob = prob)
    })
    self$register("geometric", function(prob) {
      function(size = 1L) rgeom(size, prob = prob)
    })
    self$register("beta", function(shape1, shape2) {
      function(size = 1L) rbeta(size, shape1 = shape1, shape2 = shape2)
    })
    self$register("gamma", function(shape, rate) {
      function(size = 1L) rgamma(size, shape = shape, rate = rate)
    })
    self$register("chisq", function(df) {
      function(size = 1L) rchisq(size, df = df)
    })
    self$register("t", function(df) {
      function(size = 1L) rt(size, df = df)
    })
  },

  #' @description
  #' Register a distribution generator under a name. Registering under an
  #' existing name replaces the previous generator.
  #'
  #' Typically, the generator should be a function that:
  #' 1. Accepts parameters for a distribution.
  #' 2. Returns another function - the *sampler* - which takes a `size`
  #' argument and produces that many random values from the specified
  #' distribution.
  #'
  #' By storing generators rather than fixed samplers, you can create as many
  #' different samplers as you want later, each with different parameters,
  #' while reusing the same generator code.
  #'
  #' @param name Distribution name (string)
  #' @param generator Function to create a sampler given its parameters.
  register = function(name, generator) {
    self$registry[[name]] <- generator
  },

  #' @description
  #' Get a registered distribution generator by name.
  #'
  #' @param name Distribution name (string)
  #' @return Generator function for the distribution. An error listing the
  #' available distributions is raised if `name` is not registered.
  get = function(name) {
    if (!is.character(name) || length(name) != 1L || is.na(name)) {
      stop("'name' must be a single string.", call. = FALSE)
    }
    if (!(name %in% names(self$registry))) {
      # Build ONE format string: passing a character vector as `fmt` makes
      # sprintf() recycle over it and warn about unused arguments.
      fmt <- paste0(
        "Distribution '%s' not found.\nAvailable distributions:\n\t%s\n",
        "Use register() to add new distributions."
      )
      stop(
        sprintf(fmt, name, paste(names(self$registry), collapse = ",\n\t")),
        call. = FALSE
      )
    }
    self$registry[[name]]
  },

  #' @description
  #' Convert mean/sd to lognormal parameters, returning the corresponding
  #' \code{meanlog} and \code{sdlog} parameters used by R's \code{rlnorm()}.
  #'
  #' @param params Named list with two elements: mean and sd. Other elements
  #' are ignored.
  #' @return A named list with elements \code{meanlog} and \code{sdlog}.
  transform_to_lnorm = function(params) {
    # Exact-name lookup: `$` partial-matches on lists, so `params$sd` would
    # silently pick up an element called `sdlog`.
    mean_val <- params[["mean"]]
    sd_val <- params[["sd"]]
    if (!is.numeric(mean_val) || !is.numeric(sd_val)) {
      stop("'params' must contain numeric 'mean' and 'sd'.", call. = FALSE)
    }
    if (any(mean_val <= 0, na.rm = TRUE) || any(sd_val < 0, na.rm = TRUE)) {
      stop("Lognormal 'mean' must be > 0 and 'sd' must be >= 0.",
           call. = FALSE)
    }
    # Moment matching: sdlog^2 = log(1 + sd^2 / mean^2),
    # meanlog = log(mean) - sdlog^2 / 2.
    sigma_sq <- log(sd_val^2L / mean_val^2L + 1L)
    list(meanlog = log(mean_val) - sigma_sq / 2L, sdlog = sqrt(sigma_sq))
  },

  #' @description
  #' Create a parameterised sampler for a distribution.
  #'
  #' The returned function draws random samples of a specified size from
  #' the given distribution with fixed parameters.
  #'
  #' For "lognormal", if "meanlog" and "sdlog" are present in the parameters,
  #' they will be used as-is. If not, but "mean" and "sd" are provided, these
  #' will be transformed into "meanlog"/"sdlog" automatically.
  #'
  #' @param name Distribution name
  #' @param ... Parameters for the generator
  #' @return A function that draws samples when called.
  create = function(name, ...) {
    dots <- list(...)
    if (identical(name, "lognormal")) {
      # `[[` gives exact matching, so a mix such as mean + sdlog is rejected
      # below instead of being mistaken for mean + sd.
      has_log <- !is.null(dots[["meanlog"]]) && !is.null(dots[["sdlog"]])
      has_raw <- !is.null(dots[["mean"]]) && !is.null(dots[["sd"]])
      if (!has_log) {
        if (!has_raw) {
          stop("Please supply either 'meanlog' and 'sdlog', or 'mean' and 'sd' ",
               "for a lognormal distribution.", call. = FALSE)
        }
        # Swap mean/sd for meanlog/sdlog, keeping any other arguments.
        transformed <- self$transform_to_lnorm(dots)
        dots <- c(transformed, dots[setdiff(names(dots), c("mean", "sd"))])
      }
    }
    # Calls the `get()` method above which finds the distribution generator
    # function, then do.call() populates it with dots (a list of arguments).
    generator <- self$get(name)
    do.call(generator, dots)
  },

  #' @description
  #' Batch-create samplers from a configuration list.
  #'
  #' The configuration should be a list of lists, each sublist specifying a
  #' `class_name` (distribution) and `params` (parameter list for that
  #' distribution).
  #'
  #' @param config Named or unnamed list. Each entry is a list with
  #' 'class_name' and 'params'.
  #' @return List of parameterised samplers (named if config is named).
  create_batch = function(config) {
    if (!is.list(config)) {
      stop("config must be a list (named or unnamed)", call. = FALSE)
    }
    # Calls `create()` for each distribution specified in config
    lapply(config, function(cfg) {
      # Wrap the name in a list and coerce params with as.list() so the
      # arguments stay a list even when params is NULL or an atomic vector.
      args <- c(list(cfg[["class_name"]]), as.list(cfg[["params"]]))
      do.call(self$create, args)
    })
  }
)
)

0 commit comments

Comments
 (0)