Skip to content

Commit 22386e2

Browse files
committed
refactor: replace predict with forecast in vignettes, tests, examples
1 parent 95cb6ea commit 22386e2

34 files changed

+114
-255
lines changed

R/arx_classifier.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ arx_classifier <- function(
5454
wf <- arx_class_epi_workflow(
5555
epi_data, outcome, predictors, trainer, args_list
5656
)
57-
58-
latest <- get_test_data(
59-
hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer,
60-
args_list$forecast_date %||% max(epi_data$time_value)
61-
)
62-
6357
wf <- generics::fit(wf, epi_data)
64-
preds <- predict(wf, new_data = latest) %>%
58+
59+
preds <- forecast(
60+
wf,
61+
fill_locf = TRUE,
62+
n_recent = args_list$nafill_buffer,
63+
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
64+
) %>%
6565
tibble::as_tibble() %>%
6666
dplyr::select(-time_value)
6767

R/arx_forecaster.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,18 @@ arx_forecaster <- function(epi_data,
5050
wf <- arx_fcast_epi_workflow(
5151
epi_data, outcome, predictors, trainer, args_list
5252
)
53+
wf <- generics::fit(wf, epi_data)
5354

5455
latest <- get_test_data(
55-
hardhat::extract_preprocessor(wf), epi_data, TRUE, args_list$nafill_buffer,
56-
args_list$forecast_date %||% max(epi_data$time_value)
56+
hardhat::extract_preprocessor(wf), epi_data,
5757
)
5858

59-
wf <- generics::fit(wf, epi_data)
60-
preds <- predict(wf, new_data = latest) %>%
59+
preds <- forecast(
60+
wf,
61+
fill_locf = TRUE,
62+
n_recent = args_list$nafill_buffer,
63+
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
64+
) %>%
6165
tibble::as_tibble() %>%
6266
dplyr::select(-time_value)
6367

R/autoplot.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ ggplot2::autoplot
6161
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
6262
#' step_epi_naomit()
6363
#' ewf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu)
64-
#' td <- get_test_data(r, jhu)
65-
#' predict(ewf, new_data = td)
64+
#' forecast(ewf)
6665
#' })
6766
#'
6867
#' p <- do.call(rbind, p)

R/flatline_forecaster.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ flatline_forecaster <- function(
4949
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
5050
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5151

52-
latest <- get_test_data(
53-
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
54-
forecast_date
55-
)
56-
5752
f <- frosting() %>%
5853
layer_predict() %>%
5954
layer_residual_quantiles(
@@ -69,7 +64,12 @@ flatline_forecaster <- function(
6964

7065
wf <- epi_workflow(r, eng, f)
7166
wf <- generics::fit(wf, epi_data)
72-
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
67+
preds <- suppressWarnings(forecast(
68+
wf,
69+
fill_locf = TRUE,
70+
n_recent = args_list$nafill_buffer,
71+
forecast_date = forecast_date
72+
)) %>%
7373
tibble::as_tibble() %>%
7474
dplyr::select(-time_value)
7575

R/frosting.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,15 +275,14 @@ new_frosting <- function() {
275275
#' step_epi_naomit()
276276
#'
277277
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
278-
#' latest <- get_test_data(recipe = r, x = jhu)
279278
#'
280279
#' f <- frosting() %>%
281280
#' layer_predict() %>%
282281
#' layer_naomit(.pred)
283282
#'
284283
#' wf1 <- wf %>% add_frosting(f)
285284
#'
286-
#' p <- predict(wf1, latest)
285+
#' p <- forecast(wf1)
287286
#' p
288287
frosting <- function(layers = NULL, requirements = NULL) {
289288
if (!is_null(layers) || !is_null(requirements)) {

R/layer_add_target_date.R

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#' step_epi_naomit()
2929
#'
3030
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
31-
#' latest <- get_test_data(r, jhu)
3231
#'
3332
#' # Use ahead + forecast date
3433
#' f <- frosting() %>%
@@ -38,7 +37,7 @@
3837
#' layer_naomit(.pred)
3938
#' wf1 <- wf %>% add_frosting(f)
4039
#'
41-
#' p <- predict(wf1, latest)
40+
#' p <- forecast(wf1)
4241
#' p
4342
#'
4443
#' # Use ahead + max time value from pre, fit, post
@@ -49,7 +48,7 @@
4948
#' layer_naomit(.pred)
5049
#' wf2 <- wf %>% add_frosting(f2)
5150
#'
52-
#' p2 <- predict(wf2, latest)
51+
#' p2 <- forecast(wf2)
5352
#' p2
5453
#'
5554
#' # Specify own target date
@@ -59,7 +58,7 @@
5958
#' layer_naomit(.pred)
6059
#' wf3 <- wf %>% add_frosting(f3)
6160
#'
62-
#' p3 <- predict(wf3, latest)
61+
#' p3 <- forecast(wf3)
6362
#' p3
6463
layer_add_target_date <-
6564
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {

R/layer_cdc_flatline_quantiles.R

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,14 @@
6464
#'
6565
#' forecast_date <- max(case_death_rate_subset$time_value)
6666
#'
67-
#' latest <- get_test_data(
68-
#' epi_recipe(case_death_rate_subset), case_death_rate_subset
69-
#' )
70-
#'
7167
#' f <- frosting() %>%
7268
#' layer_predict() %>%
7369
#' layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE)
7470
#'
7571
#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
7672
#'
7773
#' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset)
78-
#' preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
74+
#' preds <- suppressWarnings(forecast(wf)) %>%
7975
#' dplyr::select(-time_value) %>%
8076
#' dplyr::mutate(forecast_date = forecast_date)
8177
#' preds

R/layer_naomit.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
#'
2121
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
2222
#'
23-
#' latest <- get_test_data(recipe = r, x = jhu)
24-
#'
2523
#' f <- frosting() %>%
2624
#' layer_predict() %>%
2725
#' layer_naomit(.pred)
2826
#'
2927
#' wf1 <- wf %>% add_frosting(f)
3028
#'
31-
#' p <- predict(wf1, latest)
29+
#' p <- forecast(wf1)
3230
#' p
3331
layer_naomit <- function(frosting, ..., id = rand_id("naomit")) {
3432
arg_is_chr_scalar(id)

R/layer_point_from_distn.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,14 @@
2626
#'
2727
#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% fit(jhu)
2828
#'
29-
#' latest <- get_test_data(recipe = r, x = jhu)
30-
#'
3129
#' f1 <- frosting() %>%
3230
#' layer_predict() %>%
3331
#' layer_quantile_distn() %>% # puts the other quantiles in a different col
3432
#' layer_point_from_distn() %>% # mutate `.pred` to contain only a point prediction
3533
#' layer_naomit(.pred)
3634
#' wf1 <- wf %>% add_frosting(f1)
3735
#'
38-
#' p1 <- predict(wf1, latest)
36+
#' p1 <- forecast(wf1)
3937
#' p1
4038
#'
4139
#' f2 <- frosting() %>%
@@ -44,7 +42,7 @@
4442
#' layer_naomit(.pred)
4543
#' wf2 <- wf %>% add_frosting(f2)
4644
#'
47-
#' p2 <- predict(wf2, latest)
45+
#' p2 <- forecast(wf2)
4846
#' p2
4947
layer_point_from_distn <- function(frosting,
5048
...,

R/layer_population_scaling.R

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,7 @@
7878
#' fit(jhu) %>%
7979
#' add_frosting(f)
8080
#'
81-
#' latest <- get_test_data(
82-
#' recipe = r,
83-
#' x = epiprocess::jhu_csse_daily_subset %>%
84-
#' dplyr::filter(
85-
#' time_value > "2021-11-01",
86-
#' geo_value %in% c("ca", "ny")
87-
#' ) %>%
88-
#' dplyr::select(geo_value, time_value, cases)
89-
#' )
90-
#'
91-
#' predict(wf, latest)
81+
#' forecast(wf)
9282
layer_population_scaling <- function(frosting,
9383
...,
9484
df,

R/layer_predictive_distn.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,13 @@
3030
#'
3131
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
3232
#'
33-
#' latest <- get_test_data(recipe = r, x = jhu)
34-
#'
3533
#' f <- frosting() %>%
3634
#' layer_predict() %>%
3735
#' layer_predictive_distn() %>%
3836
#' layer_naomit(.pred)
3937
#' wf1 <- wf %>% add_frosting(f)
4038
#'
41-
#' p <- predict(wf1, latest)
39+
#' p <- forecast(wf1)
4240
#' p
4341
layer_predictive_distn <- function(frosting,
4442
...,

R/layer_quantile_distn.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@
2828
#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>%
2929
#' fit(jhu)
3030
#'
31-
#' latest <- get_test_data(recipe = r, x = jhu)
32-
#'
3331
#' f <- frosting() %>%
3432
#' layer_predict() %>%
3533
#' layer_quantile_distn() %>%
3634
#' layer_naomit(.pred)
3735
#' wf1 <- wf %>% add_frosting(f)
3836
#'
39-
#' p <- predict(wf1, latest)
37+
#' p <- forecast(wf1)
4038
#' p
4139
layer_quantile_distn <- function(frosting,
4240
...,

R/layer_residual_quantiles.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,21 @@
2424
#'
2525
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
2626
#'
27-
#' latest <- get_test_data(recipe = r, x = jhu)
28-
#'
2927
#' f <- frosting() %>%
3028
#' layer_predict() %>%
3129
#' layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>%
3230
#' layer_naomit(.pred)
3331
#' wf1 <- wf %>% add_frosting(f)
3432
#'
35-
#' p <- predict(wf1, latest)
33+
#' p <- forecast(wf1)
3634
#'
3735
#' f2 <- frosting() %>%
3836
#' layer_predict() %>%
3937
#' layer_residual_quantiles(quantile_levels = c(0.3, 0.7), by_key = "geo_value") %>%
4038
#' layer_naomit(.pred)
4139
#' wf2 <- wf %>% add_frosting(f2)
4240
#'
43-
#' p2 <- predict(wf2, latest)
41+
#' p2 <- forecast(wf2)
4442
layer_residual_quantiles <- function(
4543
frosting, ...,
4644
quantile_levels = c(0.05, 0.95),

R/layer_threshold_preds.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@
3232
#' step_epi_naomit()
3333
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
3434
#'
35-
#' latest <- get_test_data(r, jhu)
36-
#'
3735
#' f <- frosting() %>%
3836
#' layer_predict() %>%
3937
#' layer_threshold(.pred, lower = 0.180, upper = 0.310)
4038
#' wf <- wf %>% add_frosting(f)
41-
#' p <- predict(wf, latest)
39+
#' p <- forecast(wf)
4240
#' p
4341
layer_threshold <-
4442
function(frosting, ..., lower = 0, upper = Inf, id = rand_id("threshold")) {

R/step_population_scaling.R

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,18 +96,7 @@
9696
#' fit(jhu) %>%
9797
#' add_frosting(f)
9898
#'
99-
#' latest <- get_test_data(
100-
#' recipe = r,
101-
#' epiprocess::jhu_csse_daily_subset %>%
102-
#' dplyr::filter(
103-
#' time_value > "2021-11-01",
104-
#' geo_value %in% c("ca", "ny")
105-
#' ) %>%
106-
#' dplyr::select(geo_value, time_value, cases)
107-
#' )
108-
#'
109-
#'
110-
#' predict(wf, latest)
99+
#' forecast(wf)
111100
step_population_scaling <-
112101
function(recipe,
113102
...,

man/autoplot-epipred.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/frosting.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_add_target_date.Rd

Lines changed: 3 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_cdc_flatline_quantiles.Rd

Lines changed: 1 addition & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)