Skip to content

Commit a293ff1

Browse files
authored
Merge pull request #13 from pythonhealthdatascience/dev
Dev
2 parents 2c20ae5 + e72d784 commit a293ff1

File tree

14 files changed

+644
-302
lines changed

14 files changed

+644
-302
lines changed

CITATION.cff

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ repository-code: >-
1919
abstract: >-
2020
Reproducible analytical pipeline (RAP) for R discrete-event simulation (DES)
2121
implementing the Stroke Capacity Planning Model from Monks et al. 2016.
22-
version: 0.1.0
23-
date-released: '2025-07-11'
22+
version: 0.2.0
23+
date-released: '2025-08-12'

NEWS.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
# Stroke capacity planning model: R DES RAP 0.2.0
2+
3+
Introduces `DistributionRegistry` with JSON-based parameters, replacing individual parameter functions and CSV. Also add test coverage, add file path check, and documentation and dependency management updates.
4+
5+
## New features
6+
7+
* Add `DistributionRegistry` and `inst/extdata/parameters.json` (and accompanying data dictionary). Amended the package, validation, tests and `rmarkdown` to work with the new syntax for sampling and changing values (as have removed the individual parameter functions - and also removed the CSV).
8+
* Add coverage (`covr`, `DT`, coverage command in README, and GitHub action).
9+
10+
## Bug fixes
11+
12+
* Add check for non-null file path when `log_to_file=TRUE` in model validation.
13+
14+
## Other changes
15+
16+
* Switched to "all" `renv` snapshot type.
17+
* Update `docs/stress_des.md`.
118

219
# Stroke capacity planning model: R DES RAP 0.1.0
320

R/distribution_registry.R

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,16 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
3333
self$register("discrete", function(values, prob) {
3434
values <- unlist(values)
3535
prob <- unlist(prob)
36-
# Validation (as not using a pre-made distribution function)
37-
stopifnot(length(values) == length(prob))
38-
stopifnot(all(prob >= 0))
39-
if (round(abs(sum(prob) - 1), 2) > 0.01) {
36+
stopifnot(length(values) == length(prob), prob >= 0L)
37+
if (round(abs(sum(prob) - 1L), 2L) > 0.01) {
4038
stop(sprintf(
4139
"'prob' must sum to 1 ± 0.01. Sum: %s", abs(sum(unlist(prob)))
42-
))
40+
), call. = FALSE)
4341
}
4442
# Sampling function
45-
function(size = 1L) sample(
46-
values, size = size, replace = TRUE, prob = prob
47-
)
43+
function(size = 1L) {
44+
sample(values, size = size, replace = TRUE, prob = prob)
45+
}
4846
})
4947
self$register("normal", function(mean, sd) {
5048
function(size = 1L) rnorm(size, mean = mean, sd = sd)
@@ -107,7 +105,8 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
107105
"Use register() to add new distributions."),
108106
name, paste(names(self$registry), collapse = ",\n\t")
109107
),
110-
call. = FALSE)
108+
call. = FALSE
109+
)
111110
self$registry[[name]]
112111
},
113112

@@ -149,7 +148,7 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
149148
dots <- c(transformed, dots[setdiff(names(dots), c("mean", "sd"))])
150149
} else {
151150
stop("Please supply either 'meanlog' and 'sdlog', or 'mean' and 'sd' ",
152-
"for a lognormal distribution.")
151+
"for a lognormal distribution.", call. = FALSE)
153152
}
154153
}
155154
# Calls the `get()` method above which finds the distribution generator
@@ -169,14 +168,13 @@ DistributionRegistry <- R6Class("DistributionRegistry", list( # nolint: object_n
169168
#' 'class_name' and 'params'.
170169
#' @return List of parameterised samplers (named if config is named).
171170
create_batch = function(config) {
172-
if (is.list(config)) {
173-
# Calls `create()` for each distribution specified in config
174-
lapply(config, function(cfg) {
175-
do.call(self$create, c(cfg$class_name, cfg$params))
176-
})
177-
} else {
171+
if (!is.list(config)) {
178172
stop("config must be a list (named or unnamed)", call. = FALSE)
179173
}
174+
# Calls `create()` for each distribution specified in config
175+
lapply(config, function(cfg) {
176+
do.call(self$create, c(cfg$class_name, cfg$params))
177+
})
180178
}
181179
)
182180
)

R/model.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ model <- function(run_number, param, set_seed = TRUE) {
5757
param[["dist"]] <- registry$create_batch(as.list(param[["dist_config_num"]]))
5858

5959
# Restructure as dist[type][unit][patient]
60-
dist <- list()
60+
dist_refactor <- list()
6161
for (key in names(param[["dist"]])) {
62-
parts <- strsplit(key, "_")[[1]]
63-
dist_type <- parts[2]
64-
unit <- parts[1]
65-
patient <- paste(parts[-(1:2)], collapse = "_")
66-
dist[[dist_type]][[unit]][[patient]] <- param[["dist"]][[key]]
62+
parts <- strsplit(key, "_", fixed = TRUE)[[1L]]
63+
dist_type <- parts[2L]
64+
unit <- parts[1L]
65+
patient <- paste(parts[-(1L:2L)], collapse = "_")
66+
dist_refactor[[dist_type]][[unit]][[patient]] <- param[["dist"]][[key]]
6767
}
68-
param[["dist"]] <- dist
68+
param[["dist"]] <- dist_refactor
6969

7070
# Create simmer environment - set verbose to FALSE as using custom logs
7171
# (but can change to TRUE if want to see default simmer logs as well)

R/parameters.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,3 @@ parameters <- function(
6464
dist = NULL
6565
)
6666
}
67-

R/validate_model_inputs.R

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
valid_inputs <- function(run_number, param) {
1010
check_run_number(run_number)
1111
check_log_file_path(param)
12+
check_param_names(param)
13+
check_param_values(param)
1214
}
1315

1416

@@ -26,6 +28,7 @@ check_run_number <- function(run_number) {
2628
}
2729
}
2830

31+
2932
#' Check logging file path
3033
#'
3134
#' @param param List containing parameters for the simulation.
@@ -45,3 +48,221 @@ check_log_file_path <- function(param) {
4548
)
4649
}
4750
}
51+
52+
53+
#' Check parameter names.
54+
#'
55+
#' Ensure that all required parameters are present, and no extra parameters are
56+
#' provided.
57+
#'
58+
#' @param param List containing parameters for the simulation.
59+
#'
60+
#' @importFrom jsonlite fromJSON
61+
#'
62+
#' @return None. Throws an error if there are missing or extra parameters.
63+
#' @export
64+
65+
check_param_names <- function(param) {
66+
67+
# Check the distribution names....
68+
# Import JSON with the required names
69+
config <- fromJSON(
70+
system.file("extdata", "parameters.json", package = "simulation"),
71+
simplifyVector = FALSE
72+
)[["simulation_parameters"]]
73+
required <- names(config)
74+
75+
# Check what names are within param, if any missing or extra
76+
missing_names <- setdiff(required, names(param[["dist_config"]]))
77+
extra_names <- setdiff(names(param[["dist_config"]]), required)
78+
if (length(missing_names) > 0L || length(extra_names) > 0L)
79+
stop("Problem in param$dist_config. Missing: ", missing_names, ". ",
80+
"Extra: ", extra_names, ".", call. = FALSE)
81+
82+
# Check the names in param
83+
missing_names <- setdiff(names(parameters()), names(param))
84+
extra_names <- setdiff(names(param), names(parameters()))
85+
if (length(missing_names) > 0L || length(extra_names) > 0L)
86+
stop("Problem in param. Missing: ", missing_names, ". ",
87+
"Extra: ", extra_names, ".", call. = FALSE)
88+
}
89+
90+
91+
#' Validate probability vector
92+
#'
93+
#' Checks that values are between 0 and 1 (inclusive), that they sum to 1
94+
#' (with tolerance of +-0.01).
95+
#'
96+
#' @param vec Numeric vector. The probability vector to be checked.
97+
#' @param name Character string. The name or label for the vector, used in
98+
#' error messages.
99+
#'
100+
#' @return Throws an error if the vector is invalid; otherwise, returns nothing.
101+
#' @export
102+
103+
check_prob_vector <- function(vec, name) {
104+
if (!is.numeric(vec)) {
105+
stop('Routing vector "', name, '" must be numeric.', call. = FALSE)
106+
}
107+
if (any(vec < 0L | vec > 1L)) {
108+
stop('All values in routing vector "', name, '" must be between 0 and 1.',
109+
call. = FALSE)
110+
}
111+
if (sum(vec) < 0.99 || sum(vec) > 1.01) {
112+
stop('Values in routing vector "', name, '" must sum to 1 (+-0.01).',
113+
call. = FALSE)
114+
}
115+
}
116+
117+
118+
#' Check if a value is a positive integer
119+
#'
120+
#' Throws an error if the value is not a positive integer (> 0).
121+
#'
122+
#' @param x The value to check.
123+
#' @param name The name of the parameter (for error messages).
124+
#'
125+
#' @return None. Throws an error if the check fails.
126+
#' @export
127+
128+
check_positive_integer <- function(x, name) {
129+
if (is.null(x) || x <= 0L || x %% 1L != 0L) {
130+
stop(
131+
sprintf('The parameter "%s" must be an integer greater than 0.', name),
132+
call. = FALSE
133+
)
134+
}
135+
}
136+
137+
#' Check if all values are positive
138+
#'
139+
#' Throws an error if any value in the input is not positive (> 0).
140+
#'
141+
#' @param x The vector or list of values to check.
142+
#' @param name The name of the parameter (for error messages).
143+
#'
144+
#' @return None. Throws an error if the check fails.
145+
#' @export
146+
147+
check_all_positive <- function(x, name) {
148+
if (!is.null(x) && any(unlist(x) <= 0L)) {
149+
stop(
150+
sprintf('All values in "%s" must be greater than 0.', name),
151+
call. = FALSE
152+
)
153+
}
154+
}
155+
156+
#' Check if a value is a non-negative integer
157+
#'
158+
#' Throws an error if the value is not a non-negative integer (>= 0).
159+
#'
160+
#' @param x The value to check.
161+
#' @param name The name of the parameter (for error messages).
162+
#'
163+
#' @return None. Throws an error if the check fails.
164+
#' @export
165+
166+
check_nonneg_integer <- function(x, name) {
167+
if (is.null(x) || x < 0L || x %% 1L != 0L) {
168+
stop(
169+
sprintf('The parameter "%s" must be an integer >= 0.', name),
170+
call. = FALSE
171+
)
172+
}
173+
}
174+
175+
176+
#' Check that a parameter list contains only allowed names
177+
#'
178+
#' @param object_name String name of object (for error messages).
179+
#' @param actual_names Character vector of parameter names.
180+
#' @param allowed_names Character vector of allowed parameter names.
181+
#'
182+
#' @return None. Throws an error if unrecognised parameters are present.
183+
#' @export
184+
185+
check_allowed_params <- function(object_name, actual_names, allowed_names) {
186+
extra_names <- setdiff(actual_names, allowed_names)
187+
missing_names <- setdiff(allowed_names, actual_names)
188+
if (length(extra_names) > 0L) {
189+
stop(
190+
sprintf(
191+
"Unrecognised parameter(s) in %s: %s. Allowed: %s",
192+
object_name, toString(extra_names), toString(allowed_names)
193+
), call. = FALSE
194+
)
195+
}
196+
if (length(missing_names) > 0L) {
197+
stop(
198+
sprintf(
199+
"Missing required parameter(s) in %s: %s. Allowed: %s",
200+
object_name, toString(missing_names), toString(allowed_names)
201+
), call. = FALSE
202+
)
203+
}
204+
}
205+
206+
207+
#' Validate parameter values
208+
#'
209+
#' Ensures that specific parameters are positive numbers, non-negative integers,
210+
#' and that probability vectors are in [0, 1] and sum to 1 (within tolerance).
211+
#'
212+
#' @param param List containing parameters for the simulation.
213+
#'
214+
#' @return None. Throws an error if any specified parameter value is invalid.
215+
#' @export
216+
217+
check_param_values <- function(param) {
218+
219+
# High-level parameters (runs, simulation run length)
220+
check_positive_integer(param[["number_of_runs"]], "number_of_runs")
221+
lapply(c("warm_up_period", "data_collection_period"),
222+
function(nm) check_nonneg_integer(param[[nm]], nm))
223+
224+
# Loop through each distribution in dist_config
225+
for (dist_name in names(param$dist_config)) {
226+
entry <- param$dist_config[[dist_name]]
227+
type <- entry$class_name
228+
params <- entry$params
229+
230+
check_allowed_params(object_name = paste0("param$dist_config", dist_name),
231+
actual_names = names(entry),
232+
allowed_names = c("class_name", "params"))
233+
234+
if (type == "exponential") {
235+
check_all_positive(params$mean, paste0(dist_name, "$params$mean"))
236+
check_allowed_params(
237+
object_name = paste0("param$dist_config$", dist_name, "$params"),
238+
actual_names = names(params),
239+
allowed_names = "mean"
240+
)
241+
}
242+
243+
if (type == "lognormal") {
244+
check_all_positive(params$mean, paste0(dist_name, "$params$mean"))
245+
check_all_positive(params$sd, paste0(dist_name, "$params$sd"))
246+
check_allowed_params(
247+
object_name = paste0("param$dist_config$", dist_name, "$params"),
248+
actual_names = names(params),
249+
allowed_names = c("mean", "sd")
250+
)
251+
}
252+
253+
if (type == "discrete") {
254+
vals <- unlist(params$values)
255+
prob <- unlist(params$prob)
256+
if (length(vals) != length(prob)) {
257+
stop(sprintf("Discrete dist '%s' values and prob length mismatch",
258+
dist_name), call. = FALSE)
259+
}
260+
check_prob_vector(prob, paste0(dist_name, "$params$prob"))
261+
check_allowed_params(
262+
object_name = paste0("param$dist_config$", dist_name, "$params"),
263+
actual_names = names(params),
264+
allowed_names = c("values", "prob")
265+
)
266+
}
267+
}
268+
}

outputs/log_example.log

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,30 @@ rlnorm(size, meanlog = meanlog, sdlog = sdlog), tia = function (size = 1)
3333
rlnorm(size, meanlog = meanlog, sdlog = sdlog), neuro = function (size = 1)
3434
rlnorm(size, meanlog = meanlog, sdlog = sdlog), other = function (size = 1)
3535
rlnorm(size, meanlog = meanlog, sdlog = sdlog))), routing = list(asu = list(stroke = function (size = 1)
36-
sample(values, size = size, replace = TRUE, prob = prob), tia = function (size = 1)
37-
sample(values, size = size, replace = TRUE, prob = prob), neuro = function (size = 1)
38-
sample(values, size = size, replace = TRUE, prob = prob), other = function (size = 1)
39-
sample(values, size = size, replace = TRUE, prob = prob)), rehab = list(stroke = function (size = 1)
40-
sample(values, size = size, replace = TRUE, prob = prob), tia = function (size = 1)
41-
sample(values, size = size, replace = TRUE, prob = prob), neuro = function (size = 1)
42-
sample(values, size = size, replace = TRUE, prob = prob), other = function (size = 1)
43-
sample(values, size = size, replace = TRUE, prob = prob))));
36+
{
37+
sample(values, size = size, replace = TRUE, prob = prob)
38+
}, tia = function (size = 1)
39+
{
40+
sample(values, size = size, replace = TRUE, prob = prob)
41+
}, neuro = function (size = 1)
42+
{
43+
sample(values, size = size, replace = TRUE, prob = prob)
44+
}, other = function (size = 1)
45+
{
46+
sample(values, size = size, replace = TRUE, prob = prob)
47+
}), rehab = list(stroke = function (size = 1)
48+
{
49+
sample(values, size = size, replace = TRUE, prob = prob)
50+
}, tia = function (size = 1)
51+
{
52+
sample(values, size = size, replace = TRUE, prob = prob)
53+
}, neuro = function (size = 1)
54+
{
55+
sample(values, size = size, replace = TRUE, prob = prob)
56+
}, other = function (size = 1)
57+
{
58+
sample(values, size = size, replace = TRUE, prob = prob)
59+
})));
4460
map_val2num=1:3;
4561
map_num2val=c(`1` = "rehab", `2` = "esd", `3` = "other");
4662
dist_config_num=list(asu_arrival_stroke = list(class_name = "exponential", params = list(mean = 1.2)), asu_arrival_tia = list(class_name = "exponential", params = list(mean = 9.3)), asu_arrival_neuro = list(class_name = "exponential", params = list(mean = 3.6)), asu_arrival_other = list(class_name = "exponential", params = list(mean = 3.2)), rehab_arrival_stroke = list(class_name = "exponential", params = list(mean = 3)), rehab_arrival_neuro = list(class_name = "exponential", params = list(mean = 31.7)), rehab_arrival_other = list(

0 commit comments

Comments
 (0)