Skip to content

Commit 6ddffb2

Browse files
authored
Merge pull request #283 from cmu-delphi/ds/check_enough_train_data
feat: check_enough_train_data
2 parents ee11b1e + b869222 commit 6ddffb2

14 files changed

+474
-34
lines changed

.Rbuildignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
^renv$
2+
^renv\.lock$
13
^epipredict\.Rproj$
24
^\.Rproj\.user$
35
^LICENSE\.md$

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ inst/doc
77
.DS_Store
88
/doc/
99
/Meta/
10+
.Rprofile
11+
renv.lock
12+
renv/

DESCRIPTION

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.7
3+
Version: 0.0.8
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -22,11 +22,11 @@ License: MIT + file LICENSE
2222
URL: https://github.com/cmu-delphi/epipredict/,
2323
https://cmu-delphi.github.io/epipredict
2424
BugReports: https://github.com/cmu-delphi/epipredict/issues/
25-
Depends:
25+
Depends:
2626
epiprocess (>= 0.6.0),
2727
parsnip (>= 1.0.0),
2828
R (>= 3.5.0)
29-
Imports:
29+
Imports:
3030
cli,
3131
distributional,
3232
dplyr,
@@ -48,7 +48,7 @@ Imports:
4848
usethis,
4949
vctrs,
5050
workflows (>= 1.0.0)
51-
Suggests:
51+
Suggests:
5252
covidcast,
5353
data.table,
5454
epidatr (>= 1.0.0),
@@ -61,7 +61,7 @@ Suggests:
6161
rmarkdown,
6262
testthat (>= 3.0.0),
6363
xgboost
64-
VignetteBuilder:
64+
VignetteBuilder:
6565
knitr
6666
Remotes:
6767
cmu-delphi/epidatr,

NAMESPACE

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting)
1010
S3method(apply_frosting,default)
1111
S3method(apply_frosting,epi_workflow)
1212
S3method(augment,epi_workflow)
13+
S3method(bake,check_enough_train_data)
1314
S3method(bake,epi_recipe)
1415
S3method(bake,step_epi_ahead)
1516
S3method(bake,step_epi_lag)
@@ -48,6 +49,7 @@ S3method(mean,dist_quantiles)
4849
S3method(median,dist_quantiles)
4950
S3method(predict,epi_workflow)
5051
S3method(predict,flatline)
52+
S3method(prep,check_enough_train_data)
5153
S3method(prep,epi_recipe)
5254
S3method(prep,step_epi_ahead)
5355
S3method(prep,step_epi_lag)
@@ -60,6 +62,7 @@ S3method(print,arx_class)
6062
S3method(print,arx_fcast)
6163
S3method(print,canned_epipred)
6264
S3method(print,cdc_baseline_fcast)
65+
S3method(print,check_enough_train_data)
6366
S3method(print,epi_recipe)
6467
S3method(print,epi_workflow)
6568
S3method(print,flat_fcast)
@@ -104,6 +107,7 @@ S3method(snap,default)
104107
S3method(snap,dist_default)
105108
S3method(snap,dist_quantiles)
106109
S3method(snap,distribution)
110+
S3method(tidy,check_enough_train_data)
107111
S3method(tidy,frosting)
108112
S3method(tidy,layer)
109113
S3method(update,layer)
@@ -127,6 +131,7 @@ export(arx_forecaster)
127131
export(bake)
128132
export(cdc_baseline_args_list)
129133
export(cdc_baseline_forecaster)
134+
export(check_enough_train_data)
130135
export(create_layer)
131136
export(default_epi_recipe_blueprint)
132137
export(detect_layer)
@@ -191,6 +196,12 @@ import(epiprocess)
191196
import(parsnip)
192197
import(recipes)
193198
importFrom(cli,cli_abort)
199+
importFrom(dplyr,across)
200+
importFrom(dplyr,all_of)
201+
importFrom(dplyr,group_by)
202+
importFrom(dplyr,n)
203+
importFrom(dplyr,summarise)
204+
importFrom(dplyr,ungroup)
194205
importFrom(epiprocess,growth_rate)
195206
importFrom(generics,augment)
196207
importFrom(generics,fit)
@@ -225,6 +236,7 @@ importFrom(stats,residuals)
225236
importFrom(tibble,as_tibble)
226237
importFrom(tibble,is_tibble)
227238
importFrom(tibble,tibble)
239+
importFrom(tidyr,drop_na)
228240
importFrom(vctrs,as_list_of)
229241
importFrom(vctrs,field)
230242
importFrom(vctrs,new_rcrd)

NEWS.md

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,49 @@
11
# epipredict (development)
22

3+
# epipredict 0.0.8
4+
5+
- add `check_enough_train_data` that will error if training data is too small
6+
- add `check_enough_train_data` to `arx_forecaster` and `arx_classifier`
7+
38
# epipredict 0.0.7
49

5-
* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
10+
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
611

712
# epipredict 0.0.6
813

9-
* rename the `dist_quantiles()` to be more descriptive, breaking change)
10-
* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
11-
* add `pivot_quantiles_wider()` for easier plotting
12-
* add complement `pivot_quantiles_longer()`
13-
* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
14+
- rename the `dist_quantiles()` to be more descriptive (breaking change)
15+
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
16+
- add `pivot_quantiles_wider()` for easier plotting
17+
- add complement `pivot_quantiles_longer()`
18+
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
1419

1520
# epipredict 0.0.5
1621

17-
* add `smooth_quantile_reg()`
18-
* improved printing of various methods / internals
19-
* canned forecasters get a class
20-
* fixed quantile bug in `flatline_forecaster()`
21-
* add functionality to output the unfit workflow from the canned forecasters
22+
- add `smooth_quantile_reg()`
23+
- improved printing of various methods / internals
24+
- canned forecasters get a class
25+
- fixed quantile bug in `flatline_forecaster()`
26+
- add functionality to output the unfit workflow from the canned forecasters
2227

2328
# epipredict 0.0.4
2429

25-
* add quantile_reg()
26-
* clean up documentation bugs
27-
* add smooth_quantile_reg()
28-
* add classifier
29-
* training window step debugged
30-
* `min_train_window` argument removed from canned forecasters
30+
- add quantile_reg()
31+
- clean up documentation bugs
32+
- add smooth_quantile_reg()
33+
- add classifier
34+
- training window step debugged
35+
- `min_train_window` argument removed from canned forecasters
3136

3237
# epipredict 0.0.3
3338

34-
* add forecasters
35-
* implement postprocessing
36-
* vignettes avaliable
37-
* arx_forecaster
38-
* pkgdown
39+
- add forecasters
40+
- implement postprocessing
41+
- vignettes available
42+
- arx_forecaster
43+
- pkgdown
3944

4045
# epipredict 0.0.0.9000
4146

42-
* Publish public for easy navigation
43-
* Two simple forecasters as test beds
44-
* Working vignette
47+
- Publish public for easy navigation
48+
- Two simple forecasters as test beds
49+
- Working vignette

R/arx_classifier.R

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,21 @@ arx_class_epi_workflow <- function(
180180
role = "outcome"
181181
) %>%
182182
step_epi_naomit() %>%
183-
step_training_window(n_recent = args_list$n_training)
183+
step_training_window(n_recent = args_list$n_training) %>%
184+
{
185+
if (!is.null(args_list$check_enough_data_n)) {
186+
check_enough_train_data(
187+
.,
188+
all_predictors(),
189+
!!outcome,
190+
n = args_list$check_enough_data_n,
191+
epi_keys = args_list$check_enough_data_epi_keys,
192+
drop_na = FALSE
193+
)
194+
} else {
195+
.
196+
}
197+
}
184198

185199
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
186200
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
@@ -228,6 +242,11 @@ arx_class_epi_workflow <- function(
228242
#' @param additional_gr_args List. Optional arguments controlling growth rate
229243
#' calculation. See [epiprocess::growth_rate()] and the related Vignette for
230244
#' more details.
245+
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
246+
#' epi_key that are required for training. If `NULL`, this check is ignored.
247+
#' @param check_enough_data_epi_keys Character vector. The
248+
#' column names on which to group the data and check threshold within each
249+
#' group. Useful if training per group (for example, per geo_value).
231250
#'
232251
#' @return A list containing updated parameter choices with class `arx_clist`.
233252
#' @export
@@ -251,6 +270,8 @@ arx_class_args_list <- function(
251270
log_scale = FALSE,
252271
additional_gr_args = list(),
253272
nafill_buffer = Inf,
273+
check_enough_data_n = NULL,
274+
check_enough_data_epi_keys = NULL,
254275
...) {
255276
rlang::check_dots_empty()
256277
.lags <- lags
@@ -275,6 +296,8 @@ arx_class_args_list <- function(
275296
)
276297
)
277298
}
299+
arg_is_pos(check_enough_data_n, allow_null = TRUE)
300+
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
278301

279302
breaks <- sort(breaks)
280303
if (min(breaks) > -Inf) breaks <- c(-Inf, breaks)
@@ -296,7 +319,9 @@ arx_class_args_list <- function(
296319
method,
297320
log_scale,
298321
additional_gr_args,
299-
nafill_buffer
322+
nafill_buffer,
323+
check_enough_data_n,
324+
check_enough_data_epi_keys
300325
),
301326
class = c("arx_class", "alist")
302327
)

R/arx_forecaster.R

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,21 @@ arx_fcast_epi_workflow <- function(
126126
r <- r %>%
127127
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
128128
step_epi_naomit() %>%
129-
step_training_window(n_recent = args_list$n_training)
129+
step_training_window(n_recent = args_list$n_training) %>%
130+
{
131+
if (!is.null(args_list$check_enough_data_n)) {
132+
check_enough_train_data(
133+
.,
134+
all_predictors(),
135+
!!outcome,
136+
n = args_list$check_enough_data_n,
137+
epi_keys = args_list$check_enough_data_epi_keys,
138+
drop_na = FALSE
139+
)
140+
} else {
141+
.
142+
}
143+
}
130144

131145
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
132146
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
@@ -199,6 +213,11 @@ arx_fcast_epi_workflow <- function(
199213
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
200214
#' will be treated as _additional_ allowed recent data rather than the
201215
#' total amount of recent data to examine.
216+
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
217+
#' epi_key that are required for training. If `NULL`, this check is ignored.
218+
#' @param check_enough_data_epi_keys Character vector. The
219+
#' column names on which to group the data and check threshold within each
220+
#' group. Useful if training per group (for example, per geo_value).
202221
#' @param ... Space to handle future expansions (unused).
203222
#'
204223
#'
@@ -220,6 +239,8 @@ arx_args_list <- function(
220239
nonneg = TRUE,
221240
quantile_by_key = character(0L),
222241
nafill_buffer = Inf,
242+
check_enough_data_n = NULL,
243+
check_enough_data_epi_keys = NULL,
223244
...) {
224245
# error checking if lags is a list
225246
rlang::check_dots_empty()
@@ -236,6 +257,8 @@ arx_args_list <- function(
236257
arg_is_pos(n_training)
237258
if (is.finite(n_training)) arg_is_pos_int(n_training)
238259
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
260+
arg_is_pos(check_enough_data_n, allow_null = TRUE)
261+
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
239262

240263
max_lags <- max(lags)
241264
structure(
@@ -250,7 +273,9 @@ arx_args_list <- function(
250273
nonneg,
251274
max_lags,
252275
quantile_by_key,
253-
nafill_buffer
276+
nafill_buffer,
277+
check_enough_data_n,
278+
check_enough_data_epi_keys
254279
),
255280
class = c("arx_fcast", "alist")
256281
)

0 commit comments

Comments
 (0)