Skip to content

Commit d754409

Browse files
committed
feat: review updates
* check postprocessor for forecast_date in forecast.epi_workflow * add test
1 parent 3ecb78c commit d754409

15 files changed

+68
-51
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ S3method(extrapolate_quantiles,distribution)
4545
S3method(fit,epi_workflow)
4646
S3method(flusight_hub_formatter,canned_epipred)
4747
S3method(flusight_hub_formatter,data.frame)
48+
S3method(forecast,epi_workflow)
4849
S3method(format,dist_quantiles)
4950
S3method(is.na,dist_quantiles)
5051
S3method(is.na,distribution)
@@ -220,6 +221,7 @@ importFrom(dplyr,ungroup)
220221
importFrom(epiprocess,growth_rate)
221222
importFrom(generics,augment)
222223
importFrom(generics,fit)
224+
importFrom(generics,forecast)
223225
importFrom(ggplot2,autoplot)
224226
importFrom(hardhat,refresh_blueprint)
225227
importFrom(hardhat,run_mold)

R/arx_classifier.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ arx_classifier <- function(
5151
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
5252
}
5353

54-
wf <- arx_class_epi_workflow(
55-
epi_data, outcome, predictors, trainer, args_list
56-
)
54+
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5755
wf <- generics::fit(wf, epi_data)
5856

5957
preds <- forecast(

R/arx_forecaster.R

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,19 @@
3838
#' trainer = quantile_reg(),
3939
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
4040
#' )
41-
arx_forecaster <- function(epi_data,
42-
outcome,
43-
predictors = outcome,
44-
trainer = parsnip::linear_reg(),
45-
args_list = arx_args_list()) {
41+
arx_forecaster <- function(
42+
epi_data,
43+
outcome,
44+
predictors = outcome,
45+
trainer = parsnip::linear_reg(),
46+
args_list = arx_args_list()) {
4647
if (!is_regression(trainer)) {
4748
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
4849
}
4950

50-
wf <- arx_fcast_epi_workflow(
51-
epi_data, outcome, predictors, trainer, args_list
52-
)
51+
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5352
wf <- generics::fit(wf, epi_data)
5453

55-
latest <- get_test_data(
56-
hardhat::extract_preprocessor(wf), epi_data,
57-
)
58-
5954
preds <- forecast(
6055
wf,
6156
fill_locf = TRUE,

R/epi_workflow.R

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ print.epi_workflow <- function(x, ...) {
334334

335335
#' Produce a forecast from an epi workflow
336336
#'
337-
#' @param epi_workflow An epi workflow
337+
#' @param object An epi workflow.
338+
#' @param ... Not used.
338339
#' @param fill_locf Logical. Should we use locf to fill in missing data?
339340
#' @param n_recent Integer or NULL. If filling missing data with locf = TRUE,
340341
#' how far back are we willing to tolerate missing data? Larger values allow
@@ -349,21 +350,34 @@ print.epi_workflow <- function(x, ...) {
349350
#' @return A forecast tibble.
350351
#'
351352
#' @export
352-
forecast <- function(epi_workflow, fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
353-
if (!epi_workflow$trained) {
353+
forecast.epi_workflow <- function(object, ..., fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
354+
rlang::check_dots_empty()
355+
356+
if (!object$trained) {
354357
cli_abort(c(
355358
"You cannot `forecast()` a {.cls workflow} that has not been trained.",
356359
i = "Please use `fit()` before forecasting."
357360
))
358361
}
359362

363+
frosting_fd <- NULL
364+
if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) {
365+
frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date")
366+
if (!is.null(frosting_fd) && !identical(class(frosting_fd), class(object$original_data$time_value))) { # identical(): class() may have length > 1 (e.g. POSIXct), which `!=` would turn into a non-scalar `&&` operand
367+
cli_abort(c(
368+
"Error with layer_add_forecast_date():",
369+
i = "The type of `forecast_date` must match the type of the `time_value` column in the data."
370+
))
371+
}
372+
}
373+
360374
test_data <- get_test_data(
361-
hardhat::extract_preprocessor(epi_workflow),
362-
epi_workflow$original_data,
375+
hardhat::extract_preprocessor(object),
376+
object$original_data,
363377
fill_locf = fill_locf,
364378
n_recent = n_recent %||% Inf,
365-
forecast_date = forecast_date %||% max(epi_workflow$original_data$time_value)
379+
forecast_date = forecast_date %||% frosting_fd %||% max(object$original_data$time_value)
366380
)
367381

368-
predict(epi_workflow, new_data = test_data)
382+
predict(object, new_data = test_data)
369383
}

R/layer_add_target_date.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#' # Use ahead + forecast date
3333
#' f <- frosting() %>%
3434
#' layer_predict() %>%
35-
#' layer_add_forecast_date(forecast_date = "2022-05-31") %>%
35+
#' layer_add_forecast_date(forecast_date = as.Date("2022-05-31")) %>%
3636
#' layer_add_target_date() %>%
3737
#' layer_naomit(.pred)
3838
#' wf1 <- wf %>% add_frosting(f)

R/layer_cdc_flatline_quantiles.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
7272
#'
7373
#' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset)
74-
#' preds <- suppressWarnings(forecast(wf)) %>%
74+
#' preds <- forecast(wf) %>%
7575
#' dplyr::select(-time_value) %>%
7676
#' dplyr::mutate(forecast_date = forecast_date)
7777
#' preds

R/reexports-tidymodels.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#' @export
33
generics::fit
44

5+
#' @importFrom generics forecast
6+
#' @export
7+
generics::forecast
8+
59
#' @importFrom recipes prep
610
#' @export
711
recipes::prep

_pkgdown.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ reference:
8686
- predict.epi_workflow
8787
- fit.epi_workflow
8888
- augment.epi_workflow
89-
- forecast
89+
- forecast.epi_workflow
9090
- title: Epi recipe preprocessing steps
9191
contents:
9292
- starts_with("step_")

man/forecast.Rd renamed to man/forecast.epi_workflow.Rd

Lines changed: 6 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_add_target_date.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_cdc_flatline_quantiles.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-epi_workflow.R

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,26 @@ test_that("forecast method works", {
7171
step_epi_ahead(death_rate, ahead = 7) %>%
7272
step_epi_naomit()
7373
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
74-
75-
latest <- get_test_data(
76-
hardhat::extract_preprocessor(wf),
77-
jhu
74+
expect_equal(
75+
forecast(wf),
76+
predict(wf, new_data = get_test_data(
77+
hardhat::extract_preprocessor(wf),
78+
jhu
79+
))
7880
)
7981

82+
args <- list(
83+
fill_locf = TRUE,
84+
n_recent = 360 * 3,
85+
forecast_date = as.Date("2024-01-01")
86+
)
8087
expect_equal(
81-
forecast(wf),
82-
predict(wf, new_data = latest)
88+
forecast(wf, !!!args),
89+
predict(wf, new_data = get_test_data(
90+
hardhat::extract_preprocessor(wf),
91+
jhu,
92+
!!!args
93+
))
8394
)
8495
})
8596

tests/testthat/test-population_scaling.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ test_that("Postprocessing workflow works and values correct", {
119119
fit(jhu) %>%
120120
add_frosting(f)
121121

122-
suppressWarnings(p <- forecast(wf))
122+
p <- forecast(wf)
123123
expect_equal(nrow(p), 2L)
124124
expect_equal(ncol(p), 4L)
125125
expect_equal(p$.pred_scaled, p$.pred * c(20000, 30000))
@@ -136,7 +136,7 @@ test_that("Postprocessing workflow works and values correct", {
136136
wf <- epi_workflow(r, parsnip::linear_reg()) %>%
137137
fit(jhu) %>%
138138
add_frosting(f)
139-
suppressWarnings(p <- forecast(wf))
139+
p <- forecast(wf)
140140
expect_equal(nrow(p), 2L)
141141
expect_equal(ncol(p), 4L)
142142
expect_equal(p$.pred_scaled, p$.pred * c(2, 3))
@@ -178,7 +178,7 @@ test_that("Postprocessing to get cases from case rate", {
178178
fit(jhu) %>%
179179
add_frosting(f)
180180

181-
suppressWarnings(p <- forecast(wf))
181+
p <- forecast(wf)
182182
expect_equal(nrow(p), 2L)
183183
expect_equal(ncol(p), 4L)
184184
expect_equal(p$.pred_scaled, p$.pred * c(1 / 20000, 1 / 30000))

vignettes/epipredict.Rmd

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ out <- arx_forecaster(
110110
)
111111
```
112112

113-
This call produces a warning, which we'll ignore for now. But essentially, it's telling us that our data comes from May 2022 but we're trying to do a forecast for January 2022. The result is likely not an accurate measure of real-time forecast performance, because the data have been revised over time.
114-
115113
The `out` object has two components:
116114

117115
1. The predictions, stored as just another `epi_df`. It contains the predictions for each location along with additional columns. By default, these are a 90% predictive interval, the `forecast_date` (the date on which the forecast was putatively made) and the `target_date` (the date for which the forecast is being made).
@@ -123,9 +121,6 @@ out$predictions
123121
out$epi_workflow
124122
```
125123

126-
Note that the `time_value` in the predictions is not necessarily meaningful,
127-
but it is a required column in an `epi_df`, so it remains here.
128-
129124
By default, the forecaster predicts the outcome (`death_rate`) 1-week ahead, using 3 lags of each predictor (`case_rate` and `death_rate`) at 0 (today), 1 week back and 2 weeks back. The predictors and outcome can be changed directly. The rest of the defaults are encapsulated into a list of arguments. This list is produced by `arx_args_list()`.
130125

131126
## Simple adjustments

0 commit comments

Comments
 (0)