Skip to content

Commit ca9f450

Browse files
committed
feat: add forecast method #293
1 parent cd3fe2e commit ca9f450

File tree

5 files changed

+176
-63
lines changed

5 files changed

+176
-63
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ export(flatline)
152152
export(flatline_args_list)
153153
export(flatline_forecaster)
154154
export(flusight_hub_formatter)
155+
export(forecast)
155156
export(frosting)
156157
export(get_test_data)
157158
export(grab_names)

R/epi_workflow.R

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
197197
#'
198198
#' @export
199199
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
200-
object$fit$meta <- list(max_time_value = max(data$time_value), as_of = attributes(data)$metadata$as_of)
200+
object$fit$meta <- list(
201+
max_time_value = max(data$time_value),
202+
as_of = attributes(data)$metadata$as_of
203+
)
204+
object$original_data <- data
201205

202206
NextMethod()
203207
}
@@ -326,3 +330,40 @@ print.epi_workflow <- function(x, ...) {
326330
print_postprocessor(x)
327331
invisible(x)
328332
}
333+
334+
335+
#' Produce a forecast from an epi workflow
#'
#' Builds test data from the data stored on the fitted workflow (via
#' [get_test_data()]) and calls `predict()` on it.
#'
#' @param epi_workflow An epi workflow. It must already have been fit.
#' @param fill_locf Logical. Should we use locf to fill in missing data?
#' @param n_recent Integer or NULL. If filling missing data with
#'   `fill_locf = TRUE`, how far back are we willing to tolerate missing data?
#'   Larger values allow more filling. The default NULL is passed on to
#'   [get_test_data()] as `Inf`. For example, suppose `n_recent = 3`, then if
#'   the 3 most recent observations in any geo_value are all NA's, we won't be
#'   able to fill anything, and an error message will be thrown.
#' @param forecast_date By default, this is set to the maximum time_value in
#'   the data the workflow was fit on. But if there is data latency such that
#'   recent NA's should be filled, this may be after the last available
#'   time_value.
#'
#' @return A forecast tibble.
#'
#' @export
forecast <- function(epi_workflow, fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
  # isTRUE() sends a missing `trained` flag to the informative error below
  # rather than failing inside `if` with "argument is of length zero".
  if (!isTRUE(epi_workflow$trained)) {
    cli_abort(c(
      "You cannot `forecast()` a {.cls workflow} that has not been trained.",
      i = "Please use `fit()` before forecasting."
    ))
  }

  # Data stored on the workflow when it was fit.
  train_data <- epi_workflow$original_data

  test_data <- get_test_data(
    hardhat::extract_preprocessor(epi_workflow),
    train_data,
    fill_locf = fill_locf,
    n_recent = n_recent %||% Inf,
    forecast_date = forecast_date %||% max(train_data$time_value)
  )

  predict(epi_workflow, new_data = test_data)
}

_pkgdown.yml

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,20 @@ navbar:
2222
type: light
2323

2424
articles:
25-
- title: Get started
26-
navbar: ~
27-
contents:
28-
- epipredict
29-
- preprocessing-and-models
30-
- arx-classifier
31-
- articles/update
32-
33-
- title: Advanced methods
34-
contents:
35-
- articles/sliding
36-
- articles/smooth-qr
37-
- articles/symptom-surveys
38-
- panel-data
25+
- title: Get started
26+
navbar: ~
27+
contents:
28+
- epipredict
29+
- preprocessing-and-models
30+
- arx-classifier
31+
- articles/update
3932

33+
- title: Advanced methods
34+
contents:
35+
- articles/sliding
36+
- articles/smooth-qr
37+
- articles/symptom-surveys
38+
- panel-data
4039

4140
repo:
4241
url:
@@ -47,81 +46,78 @@ repo:
4746

4847
home:
4948
links:
50-
- text: Introduction to Delphi's Tooling Work
51-
href: https://cmu-delphi.github.io/delphi-tooling-book/
52-
- text: The epiprocess R package
53-
href: https://cmu-delphi.github.io/epiprocess/
54-
- text: The epidatr R package
55-
href: https://github.com/cmu-delphi/epidatr/
56-
- text: The epidatasets R package
57-
href: https://cmu-delphi.github.io/epidatasets/
58-
- text: The covidcast R package
59-
href: https://cmu-delphi.github.io/covidcast/covidcastR/
60-
49+
- text: Introduction to Delphi's Tooling Work
50+
href: https://cmu-delphi.github.io/delphi-tooling-book/
51+
- text: The epiprocess R package
52+
href: https://cmu-delphi.github.io/epiprocess/
53+
- text: The epidatr R package
54+
href: https://github.com/cmu-delphi/epidatr/
55+
- text: The epidatasets R package
56+
href: https://cmu-delphi.github.io/epidatasets/
57+
- text: The covidcast R package
58+
href: https://cmu-delphi.github.io/covidcast/covidcastR/
6159

6260
reference:
6361
- title: Simple forecasters
6462
desc: Complete forecasters that produce reasonable baselines
6563
contents:
66-
- contains("forecaster")
67-
- contains("classifier")
64+
- contains("forecaster")
65+
- contains("classifier")
6866
- title: Forecaster modifications
6967
desc: Constructors to modify forecaster arguments and utilities to produce `epi_workflow` objects
7068
contents:
71-
- contains("args_list")
72-
- contains("_epi_workflow")
69+
- contains("args_list")
70+
- contains("_epi_workflow")
7371
- title: Helper functions for Hub submission
7472
contents:
75-
- flusight_hub_formatter
73+
- flusight_hub_formatter
7674
- title: Parsnip engines
7775
desc: Prediction methods not available elsewhere
7876
contents:
79-
- quantile_reg
80-
- smooth_quantile_reg
77+
- quantile_reg
78+
- smooth_quantile_reg
8179
- title: Custom panel data forecasting workflows
8280
contents:
83-
- epi_recipe
84-
- epi_workflow
85-
- add_epi_recipe
86-
- adjust_epi_recipe
87-
- add_model
88-
- predict.epi_workflow
89-
- fit.epi_workflow
90-
- augment.epi_workflow
81+
- epi_recipe
82+
- epi_workflow
83+
- add_epi_recipe
84+
- adjust_epi_recipe
85+
- add_model
86+
- predict.epi_workflow
87+
- fit.epi_workflow
88+
- augment.epi_workflow
89+
- forecast
9190
- title: Epi recipe preprocessing steps
9291
contents:
93-
- starts_with("step_")
94-
- contains("bake")
95-
- contains("juice")
92+
- starts_with("step_")
93+
- contains("bake")
94+
- contains("juice")
9695
- title: Epi recipe verification checks
9796
contents:
98-
- check_enough_train_data
97+
- check_enough_train_data
9998
- title: Forecast postprocessing
10099
desc: Create a series of postprocessing operations
101100
contents:
102-
- frosting
103-
- ends_with("_frosting")
104-
- get_test_data
105-
- tidy.frosting
101+
- frosting
102+
- ends_with("_frosting")
103+
- get_test_data
104+
- tidy.frosting
106105
- title: Frosting layers
107106
contents:
108-
- contains("layer")
109-
- contains("slather")
107+
- contains("layer")
108+
- contains("slather")
110109
- title: Automatic forecast visualization
111110
contents:
112-
- autoplot.epi_workflow
113-
- autoplot.canned_epipred
111+
- autoplot.epi_workflow
112+
- autoplot.canned_epipred
114113
- title: Utilities for quantile distribution processing
115114
contents:
116-
- dist_quantiles
117-
- extrapolate_quantiles
118-
- nested_quantiles
119-
- starts_with("pivot_quantiles")
115+
- dist_quantiles
116+
- extrapolate_quantiles
117+
- nested_quantiles
118+
- starts_with("pivot_quantiles")
120119
- title: Included datasets
121120
contents:
122-
- case_death_rate_subset
123-
- state_census
124-
- grad_employ_subset
125-
126-
127-
121+
- case_death_rate_subset
122+
- state_census
123+
- grad_employ_subset

man/forecast.Rd

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-epi_workflow.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,43 @@ test_that("model can be added/updated/removed from epi_workflow", {
6262
expect_error(extract_spec_parsnip(wf))
6363
expect_equal(wf$fit$actions$model$spec, NULL)
6464
})
65+
66+
test_that("forecast method works", {
  jhu <- filter(
    case_death_rate_subset,
    time_value > "2021-11-01",
    geo_value %in% c("ak", "ca", "ny")
  )

  rec <- epi_recipe(jhu)
  rec <- step_epi_lag(rec, death_rate, lag = c(0, 7, 14))
  rec <- step_epi_ahead(rec, death_rate, ahead = 7)
  rec <- step_epi_naomit(rec)

  fitted_wf <- fit(epi_workflow(rec, parsnip::linear_reg()), jhu)

  # Test data built by hand from the same recipe and training data.
  latest_data <- get_test_data(hardhat::extract_preprocessor(fitted_wf), jhu)

  # forecast() with defaults should agree with predicting on that data.
  expect_equal(
    forecast(fitted_wf),
    predict(fitted_wf, new_data = latest_data)
  )
})
85+
86+
test_that("forecast method errors when workflow not fit", {
  jhu <- case_death_rate_subset %>%
    filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
  r <- epi_recipe(jhu) %>%
    step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
    step_epi_ahead(death_rate, ahead = 7) %>%
    step_epi_naomit()
  # Deliberately NOT fit: forecast() must refuse an untrained workflow.
  wf <- epi_workflow(r, parsnip::linear_reg())

  # The pattern must match the text forecast() actually raises
  # ("... that has not been trained."), not "has not been fit".
  expect_error(
    forecast(wf),
    regexp = "has not been trained"
  )
})

0 commit comments

Comments
 (0)