Commit 5721d41

Merge pull request #452 from cmu-delphi/arx_forecastCheckEnough
using check_enough_train_data in practice
2 parents 5372480 + 84f991d commit 5721d41

Showing 18 changed files with 500 additions and 406 deletions.

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Package: epipredict
 Title: Basic epidemiology forecasting methods
-Version: 0.1.13
+Version: 0.1.14
 Authors@R: c(
     person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
     person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ S3method(apply_frosting,epi_workflow)
 S3method(augment,epi_workflow)
 S3method(autoplot,canned_epipred)
 S3method(autoplot,epi_workflow)
-S3method(bake,check_enough_train_data)
+S3method(bake,check_enough_data)
 S3method(bake,epi_recipe)
 S3method(bake,step_adjust_latency)
 S3method(bake,step_climate)
@@ -49,7 +49,7 @@ S3method(key_colnames,recipe)
 S3method(mean,quantile_pred)
 S3method(predict,epi_workflow)
 S3method(predict,flatline)
-S3method(prep,check_enough_train_data)
+S3method(prep,check_enough_data)
 S3method(prep,epi_recipe)
 S3method(prep,step_adjust_latency)
 S3method(prep,step_climate)
@@ -65,7 +65,7 @@ S3method(print,arx_class)
 S3method(print,arx_fcast)
 S3method(print,canned_epipred)
 S3method(print,cdc_baseline_fcast)
-S3method(print,check_enough_train_data)
+S3method(print,check_enough_data)
 S3method(print,climate_fcast)
 S3method(print,epi_recipe)
 S3method(print,epi_workflow)
@@ -109,7 +109,7 @@ S3method(slather,layer_threshold)
 S3method(slather,layer_unnest)
 S3method(snap,default)
 S3method(snap,quantile_pred)
-S3method(tidy,check_enough_train_data)
+S3method(tidy,check_enough_data)
 S3method(tidy,frosting)
 S3method(tidy,layer)
 S3method(update,layer)
@@ -142,7 +142,7 @@ export(autoplot)
 export(bake)
 export(cdc_baseline_args_list)
 export(cdc_baseline_forecaster)
-export(check_enough_train_data)
+export(check_enough_data)
 export(clean_f_name)
 export(climate_args_list)
 export(climatological_forecaster)

NEWS.md

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
 - Removes dependence on the `distributional` package, replacing the quantiles
   with `hardhat::quantile_pred()`. Some associated functions are deprecated with
   `lifecycle` messages.
+- Rename `check_enough_train_data()` to `check_enough_data()`, and generalize it
+  enough to use as a check on either training or testing.
+- Add check for enough data to predict in `arx_forecaster()`

 ## Improvements

@@ -33,6 +36,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
 - Add `climatological_forecaster()` to automatically create climate baselines
 - Replace `dist_quantiles()` with `hardhat::quantile_pred()`
 - Allow `quantile()` to threshold to an interval if desired (#434)
+- `arx_forecaster()` detects if there's enough data to predict

 ## Bug fixes

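A minimal usage sketch of the renamed check, assuming a hypothetical `epi_df` named `toy_epi_df` with columns `geo_value`, `time_value`, `case_rate`, and `death_rate`; the lags and the 90-row threshold are illustrative, not part of this commit:

library(epipredict)
library(recipes)
library(dplyr)

r <- epi_recipe(toy_epi_df) %>%
  step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit() %>%
  # formerly check_enough_train_data(); the old `n` argument is now `min_observations`
  check_enough_data(
    all_predictors(), all_outcomes(),
    min_observations = 90,
    epi_keys = "geo_value"
  )

# prep() runs the check on the training data and aborts with class
# "epipredict__not_enough_data" if any geo_value has fewer than 90 usable rows
r_prepped <- prep(r, toy_epi_df)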
R/arx_classifier.R

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ arx_class_epi_workflow <- function(
     step_training_window(n_recent = args_list$n_training)

   if (!is.null(args_list$check_enough_data_n)) {
-    r <- check_enough_train_data(
+    r <- check_enough_data(
       r,
       recipes::all_predictors(),
       recipes::all_outcomes(),

R/arx_forecaster.R

Lines changed: 6 additions & 4 deletions
@@ -171,12 +171,14 @@ arx_fcast_epi_workflow <- function(
     step_epi_ahead(!!outcome, ahead = args_list$ahead)
   r <- r %>%
     step_epi_naomit() %>%
-    step_training_window(n_recent = args_list$n_training)
+    step_training_window(n_recent = args_list$n_training) %>%
+    check_enough_data(all_predictors(), min_observations = 1, skip = FALSE)
+
   if (!is.null(args_list$check_enough_data_n)) {
-    r <- r %>% check_enough_train_data(
+    r <- r %>% check_enough_data(
       all_predictors(),
-      !!outcome,
-      n = args_list$check_enough_data_n,
+      all_outcomes(),
+      min_observations = args_list$check_enough_data_n,
       epi_keys = args_list$check_enough_data_epi_keys,
       drop_na = FALSE
     )
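Two things change in this workflow: the unconditional `check_enough_data(all_predictors(), min_observations = 1, skip = FALSE)` also runs at predict time, so the forecaster fails early (with class "epipredict__not_enough_data") when the latest data has no usable predictor rows, and the optional training-size threshold is forwarded to the renamed check. A hedged sketch of how that threshold is set through `arx_args_list()`; the data `toy_epi_df`, the column names, and the value 90 are illustrative:

library(epipredict)

out <- arx_forecaster(
  toy_epi_df,
  outcome = "death_rate",
  predictors = c("case_rate", "death_rate"),
  args_list = arx_args_list(
    ahead = 7,
    # enforced per geo_value via check_enough_data() during training
    check_enough_data_n = 90,
    check_enough_data_epi_keys = "geo_value"
  )
)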

R/canned-epipred.R

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ print.canned_epipred <- function(x, name, ...) {
     "At forecast date{?s}: {.val {fds}},",
     "For target date{?s}: {.val {tds}},"
   ))
-  if ("actions" %in% names(x$pre) && "recipe" %in% names(x$pre$actions)) {
+  if ("pre" %in% names(x) && "actions" %in% names(x$pre) && "recipe" %in% names(x$pre$actions)) {
     fit_recipe <- extract_recipe(x$epi_workflow)
     if (detect_step(fit_recipe, "adjust_latency")) {
       is_adj_latency <- map_lgl(fit_recipe$steps, function(x) inherits(x, "step_adjust_latency"))

R/check_enough_data.R

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
+#' Check the dataset contains enough data points.
+#'
+#' `check_enough_data` creates a *specification* of a recipe
+#' operation that will check if variables contain enough data.
+#'
+#' @param recipe A recipe object. The check will be added to the
+#'   sequence of operations for this recipe.
+#' @param ... One or more selector functions to choose variables for this check.
+#'   See [selections()] for more details. You will usually want to use
+#'   [recipes::all_predictors()] and/or [recipes::all_outcomes()] here.
+#' @param min_observations The minimum number of data points required for
+#'   training. If this is NULL, the total number of predictors will be used.
+#' @param epi_keys A character vector of column names on which to group the data
+#'   and check threshold within each group. Useful if your forecaster trains
+#'   per group (for example, per geo_value).
+#' @param drop_na A logical for whether to count NA values as valid rows.
+#' @param role Not used by this check since no new variables are
+#'   created.
+#' @param trained A logical for whether the selectors in `...`
+#'   have been resolved by [prep()].
+#' @param id A character string that is unique to this check to identify it.
+#' @param skip A logical. If `TRUE`, only training data is checked, while if
+#'   `FALSE`, both training and predicting data is checked. Technically, this
+#'   answers the question "should the check be skipped when the recipe is baked
+#'   by [bake()]?" While all operations are baked when [prep()] is run, some
+#'   operations may not be able to be conducted on new data (e.g. processing the
+#'   outcome variable(s)). Care should be taken when using `skip = TRUE` as it
+#'   may affect the computations for subsequent operations.
+#' @family checks
+#' @export
+#' @details This check will break the `prep` and/or bake function if any of the
+#'   checked columns have not enough non-NA values. If the check passes, nothing
+#'   is changed in the data. It is best used after every other step.
+#'
+#'   For checking training data, it is best to set `...` to be
+#'   `all_predictors(), all_outcomes()`, while for checking prediction data, it
+#'   is best to set `...` to be `all_predictors()` only, with `n = 1`.
+#'
+#' # tidy() results
+#'
+#' When you [`tidy()`][tidy.recipe()] this check, a tibble with column
+#' `terms` (the selectors or variables selected) is returned.
+#'
+check_enough_data <-
+  function(recipe,
+           ...,
+           min_observations = NULL,
+           epi_keys = NULL,
+           drop_na = TRUE,
+           role = NA,
+           trained = FALSE,
+           skip = TRUE,
+           id = rand_id("enough_data")) {
+    recipes::add_check(
+      recipe,
+      check_enough_data_new(
+        min_observations = min_observations,
+        epi_keys = epi_keys,
+        drop_na = drop_na,
+        terms = enquos(...),
+        role = role,
+        trained = trained,
+        columns = NULL,
+        skip = skip,
+        id = id
+      )
+    )
+  }
+
+check_enough_data_new <-
+  function(min_observations, epi_keys, drop_na, terms,
+           role, trained, columns, skip, id) {
+    recipes::check(
+      subclass = "enough_data",
+      prefix = "check_",
+      min_observations = min_observations,
+      epi_keys = epi_keys,
+      drop_na = drop_na,
+      terms = terms,
+      role = role,
+      trained = trained,
+      columns = columns,
+      skip = skip,
+      id = id
+    )
+  }
+
+#' @export
+prep.check_enough_data <- function(x, training, info = NULL, ...) {
+  col_names <- recipes::recipes_eval_select(x$terms, training, info)
+  if (is.null(x$min_observations)) {
+    x$min_observations <- length(col_names)
+  }
+
+  check_enough_data_core(training, x, col_names, "train")
+
+  check_enough_data_new(
+    min_observations = x$min_observations,
+    epi_keys = x$epi_keys,
+    drop_na = x$drop_na,
+    terms = x$terms,
+    role = x$role,
+    trained = TRUE,
+    columns = col_names,
+    skip = x$skip,
+    id = x$id
+  )
+}
+
+#' @export
+bake.check_enough_data <- function(object, new_data, ...) {
+  col_names <- object$columns
+  check_enough_data_core(new_data, object, col_names, "predict")
+  new_data
+}
+
+#' @export
+print.check_enough_data <- function(x, width = max(20, options()$width - 30), ...) {
+  title <- paste0("Check enough data (n = ", x$min_observations, ") for ")
+  recipes::print_step(x$columns, x$terms, x$trained, title, width)
+  invisible(x)
+}
+
+#' @export
+tidy.check_enough_data <- function(x, ...) {
+  if (recipes::is_trained(x)) {
+    res <- tibble(terms = unname(x$columns))
+  } else {
+    res <- tibble(terms = recipes::sel2char(x$terms))
+  }
+  res$id <- x$id
+  res$min_observations <- x$min_observations
+  res$epi_keys <- x$epi_keys
+  res$drop_na <- x$drop_na
+  res
+}
+
+check_enough_data_core <- function(epi_df, step_obj, col_names, train_or_predict) {
+  epi_df <- epi_df %>%
+    group_by(across(all_of(.env$step_obj$epi_keys)))
+  if (step_obj$drop_na) {
+    any_missing_data <- epi_df %>%
+      mutate(any_are_na = rowSums(across(any_of(.env$col_names), ~ is.na(.x))) > 0) %>%
+      # count the number of rows where they're all not na
+      summarise(sum(any_are_na == 0) < .env$step_obj$min_observations, .groups = "drop")
+    any_missing_data <- any_missing_data %>%
+      summarize(across(all_of(setdiff(names(any_missing_data), step_obj$epi_keys)), any)) %>%
+      any()
+
+    # figuring out which individual columns (if any) are to blame for this dearth
+    # of data
+    cols_not_enough_data <- epi_df %>%
+      summarise(
+        across(
+          all_of(.env$col_names),
+          ~ sum(!is.na(.x)) < .env$step_obj$min_observations
+        ),
+        .groups = "drop"
+      ) %>%
+      # Aggregate across keys (if present)
+      summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
+      unlist() %>%
+      # Select the names of the columns that are TRUE
+      names(.)[.]
+
+    # Either all columns have enough data, in which case this message won't be
+    # sent later or none of the single columns have enough data, that means its
+    # the combination of all of them.
+    if (length(cols_not_enough_data) == 0) {
+      cols_not_enough_data <-
+        glue::glue("no single column, but the combination of {paste0(col_names, collapse = ', ')}")
+    }
+  } else {
+    # if we're not dropping na values, just count
+    cols_not_enough_data <- epi_df %>%
+      summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$step_obj$min_observations))
+    any_missing_data <- cols_not_enough_data %>%
+      summarize(across(all_of(.env$col_names), all)) %>%
+      all()
+    cols_not_enough_data <- cols_not_enough_data %>%
+      summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
+      unlist() %>%
+      # Select the names of the columns that are TRUE
+      names(.)[.]
+  }
+
+  if (any_missing_data) {
+    cli_abort(
+      "The following columns don't have enough data to {train_or_predict}: {cols_not_enough_data}.",
+      class = "epipredict__not_enough_data"
+    )
+  }
+}
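A short sketch of the predict-time pattern described in the details above: select only predictors, require a single observation, and set `skip = FALSE` so the check also runs when the recipe is baked for prediction. The data `toy_epi_df` and the model choice are illustrative, not part of this commit:

library(epipredict)
library(recipes)
library(dplyr)

r <- epi_recipe(toy_epi_df) %>%
  step_epi_lag(case_rate, lag = c(0, 7)) %>%
  step_epi_ahead(death_rate, ahead = 7) %>%
  step_epi_naomit() %>%
  check_enough_data(all_predictors(), min_observations = 1, skip = FALSE)

wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(toy_epi_df)

# bake.check_enough_data() re-runs the check on the new data, so predicting
# aborts with class "epipredict__not_enough_data" if no row has complete predictors
preds <- predict(wf, new_data = get_test_data(r, toy_epi_df))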
