Skip to content

Commit 990269d

Browse files
committed
fix: review updates
1 parent 3600791 commit 990269d

12 files changed

+38
-43
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ S3method(extrapolate_quantiles,distribution)
4545
S3method(fit,epi_workflow)
4646
S3method(flusight_hub_formatter,canned_epipred)
4747
S3method(flusight_hub_formatter,data.frame)
48+
S3method(forecast,epi_workflow)
4849
S3method(format,dist_quantiles)
4950
S3method(is.na,dist_quantiles)
5051
S3method(is.na,distribution)
@@ -220,6 +221,7 @@ importFrom(dplyr,ungroup)
220221
importFrom(epiprocess,growth_rate)
221222
importFrom(generics,augment)
222223
importFrom(generics,fit)
224+
importFrom(generics,forecast)
223225
importFrom(ggplot2,autoplot)
224226
importFrom(hardhat,refresh_blueprint)
225227
importFrom(hardhat,run_mold)

R/arx_classifier.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ arx_classifier <- function(
5151
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
5252
}
5353

54-
wf <- arx_class_epi_workflow(
55-
epi_data, outcome, predictors, trainer, args_list
56-
)
54+
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5755
wf <- generics::fit(wf, epi_data)
5856

5957
preds <- forecast(

R/arx_forecaster.R

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,19 @@
3838
#' trainer = quantile_reg(),
3939
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
4040
#' )
41-
arx_forecaster <- function(epi_data,
42-
outcome,
43-
predictors = outcome,
44-
trainer = parsnip::linear_reg(),
45-
args_list = arx_args_list()) {
41+
arx_forecaster <- function(
42+
epi_data,
43+
outcome,
44+
predictors = outcome,
45+
trainer = parsnip::linear_reg(),
46+
args_list = arx_args_list()) {
4647
if (!is_regression(trainer)) {
4748
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
4849
}
4950

50-
wf <- arx_fcast_epi_workflow(
51-
epi_data, outcome, predictors, trainer, args_list
52-
)
51+
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5352
wf <- generics::fit(wf, epi_data)
5453

55-
latest <- get_test_data(
56-
hardhat::extract_preprocessor(wf), epi_data,
57-
)
58-
5954
preds <- forecast(
6055
wf,
6156
fill_locf = TRUE,

R/epi_workflow.R

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ print.epi_workflow <- function(x, ...) {
334334

335335
#' Produce a forecast from an epi workflow
336336
#'
337-
#' @param epi_workflow An epi workflow
337+
#' @param object An epi workflow.
338+
#' @param ... Not used.
338339
#' @param fill_locf Logical. Should we use locf to fill in missing data?
339340
#' @param n_recent Integer or NULL. If filling missing data with locf = TRUE,
340341
#' how far back are we willing to tolerate missing data? Larger values allow
@@ -349,21 +350,23 @@ print.epi_workflow <- function(x, ...) {
349350
#' @return A forecast tibble.
350351
#'
351352
#' @export
352-
forecast <- function(epi_workflow, fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
353-
if (!epi_workflow$trained) {
353+
forecast.epi_workflow <- function(object, ..., fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
354+
rlang::check_dots_empty()
355+
356+
if (!object$trained) {
354357
cli_abort(c(
355358
"You cannot `forecast()` a {.cls workflow} that has not been trained.",
356359
i = "Please use `fit()` before forecasting."
357360
))
358361
}
359362

360363
test_data <- get_test_data(
361-
hardhat::extract_preprocessor(epi_workflow),
362-
epi_workflow$original_data,
364+
hardhat::extract_preprocessor(object),
365+
object$original_data,
363366
fill_locf = fill_locf,
364367
n_recent = n_recent %||% Inf,
365-
forecast_date = forecast_date %||% max(epi_workflow$original_data$time_value)
368+
forecast_date = forecast_date %||% max(object$original_data$time_value)
366369
)
367370

368-
predict(epi_workflow, new_data = test_data)
371+
predict(object, new_data = test_data)
369372
}

R/layer_cdc_flatline_quantiles.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
7272
#'
7373
#' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset)
74-
#' preds <- suppressWarnings(forecast(wf)) %>%
74+
#' preds <- forecast(wf) %>%
7575
#' dplyr::select(-time_value) %>%
7676
#' dplyr::mutate(forecast_date = forecast_date)
7777
#' preds

R/reexports-tidymodels.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#' @export
33
generics::fit
44

5+
#' @importFrom generics forecast
6+
#' @export
7+
generics::forecast
8+
59
#' @importFrom recipes prep
610
#' @export
711
recipes::prep

_pkgdown.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ reference:
7070
- predict.epi_workflow
7171
- fit.epi_workflow
7272
- augment.epi_workflow
73-
- forecast
73+
- forecast.epi_workflow
7474
- title: Epi recipe preprocessing steps
7575
contents:
7676
- starts_with("step_")

man/forecast.Rd renamed to man/forecast.epi_workflow.Rd

Lines changed: 6 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_cdc_flatline_quantiles.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-population_scaling.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ test_that("Postprocessing workflow works and values correct", {
119119
fit(jhu) %>%
120120
add_frosting(f)
121121

122-
suppressWarnings(p <- forecast(wf))
122+
p <- forecast(wf)
123123
expect_equal(nrow(p), 2L)
124124
expect_equal(ncol(p), 4L)
125125
expect_equal(p$.pred_scaled, p$.pred * c(20000, 30000))
@@ -136,7 +136,7 @@ test_that("Postprocessing workflow works and values correct", {
136136
wf <- epi_workflow(r, parsnip::linear_reg()) %>%
137137
fit(jhu) %>%
138138
add_frosting(f)
139-
suppressWarnings(p <- forecast(wf))
139+
p <- forecast(wf)
140140
expect_equal(nrow(p), 2L)
141141
expect_equal(ncol(p), 4L)
142142
expect_equal(p$.pred_scaled, p$.pred * c(2, 3))
@@ -178,7 +178,7 @@ test_that("Postprocessing to get cases from case rate", {
178178
fit(jhu) %>%
179179
add_frosting(f)
180180

181-
suppressWarnings(p <- forecast(wf))
181+
p <- forecast(wf)
182182
expect_equal(nrow(p), 2L)
183183
expect_equal(ncol(p), 4L)
184184
expect_equal(p$.pred_scaled, p$.pred * c(1 / 20000, 1 / 30000))

vignettes/epipredict.Rmd

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ out <- arx_forecaster(
110110
)
111111
```
112112

113-
This call produces a warning, which we'll ignore for now. But essentially, it's telling us that our data comes from May 2022 but we're trying to do a forecast for January 2022. The result is likely not an accurate measure of real-time forecast performance, because the data have been revised over time.
114-
115113
The `out` object has two components:
116114

117115
1. The predictions which is just another `epi_df`. It contains the predictions for each location along with additional columns. By default, these are a 90% predictive interval, the `forecast_date` (the date on which the forecast was putatively made) and the `target_date` (the date for which the forecast is being made).
@@ -123,9 +121,6 @@ out$predictions
123121
out$epi_workflow
124122
```
125123

126-
Note that the `time_value` in the predictions is not necessarily meaningful,
127-
but it is a required column in an `epi_df`, so it remains here.
128-
129124
By default, the forecaster predicts the outcome (`death_rate`) 1-week ahead, using 3 lags of each predictor (`case_rate` and `death_rate`) at 0 (today), 1 week back and 2 weeks back. The predictors and outcome can be changed directly. The rest of the defaults are encapsulated into a list of arguments. This list is produced by `arx_args_list()`.
130125

131126
## Simple adjustments

0 commit comments

Comments
 (0)