Skip to content

Commit b2d1e11

Browse files
authored
Merge pull request #241 from cmu-delphi/240-quantile-pivot
240 quantile pivot
2 parents 015b0ea + 8d1e47d commit b2d1e11

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1451
-223
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
^musings$
1313
^data-raw$
1414
^vignettes/articles$
15+
^.git-blame-ignore-revs$

.github/workflows/R-CMD-check.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
# Created with usethis + edited to use API key.
55
on:
66
push:
7-
branches: [main, master]
7+
branches: [main, master, v0.0.6]
88
pull_request:
9-
branches: [main, master]
9+
branches: [main, master, v0.0.6]
1010

1111
name: R-CMD-check
1212

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Imports:
3232
generics,
3333
glue,
3434
hardhat (>= 1.3.0),
35+
lifecycle,
3536
magrittr,
3637
methods,
3738
quantreg,

NAMESPACE

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ S3method(print,alist)
5252
S3method(print,arx_class)
5353
S3method(print,arx_fcast)
5454
S3method(print,canned_epipred)
55+
S3method(print,cdc_baseline_fcast)
5556
S3method(print,epi_workflow)
5657
S3method(print,flat_fcast)
5758
S3method(print,flatline)
@@ -79,6 +80,7 @@ S3method(residuals,flatline)
7980
S3method(run_mold,default_epi_recipe_blueprint)
8081
S3method(slather,layer_add_forecast_date)
8182
S3method(slather,layer_add_target_date)
83+
S3method(slather,layer_cdc_flatline_quantiles)
8284
S3method(slather,layer_naomit)
8385
S3method(slather,layer_point_from_distn)
8486
S3method(slather,layer_population_scaling)
@@ -106,6 +108,8 @@ export(arx_classifier)
106108
export(arx_fcast_epi_workflow)
107109
export(arx_forecaster)
108110
export(bake)
111+
export(cdc_baseline_args_list)
112+
export(cdc_baseline_forecaster)
109113
export(create_layer)
110114
export(default_epi_recipe_blueprint)
111115
export(detect_layer)
@@ -131,6 +135,7 @@ export(is_layer)
131135
export(layer)
132136
export(layer_add_forecast_date)
133137
export(layer_add_target_date)
138+
export(layer_cdc_flatline_quantiles)
134139
export(layer_naomit)
135140
export(layer_point_from_distn)
136141
export(layer_population_scaling)
@@ -143,7 +148,8 @@ export(layer_unnest)
143148
export(nested_quantiles)
144149
export(new_default_epi_recipe_blueprint)
145150
export(new_epi_recipe_blueprint)
146-
export(pivot_quantiles)
151+
export(pivot_quantiles_longer)
152+
export(pivot_quantiles_wider)
147153
export(prep)
148154
export(quantile_reg)
149155
export(remove_frosting)
@@ -167,6 +173,7 @@ importFrom(generics,augment)
167173
importFrom(generics,fit)
168174
importFrom(hardhat,refresh_blueprint)
169175
importFrom(hardhat,run_mold)
176+
importFrom(lifecycle,deprecated)
170177
importFrom(magrittr,"%>%")
171178
importFrom(methods,is)
172179
importFrom(quantreg,rq)
@@ -181,6 +188,7 @@ importFrom(rlang,caller_env)
181188
importFrom(rlang,is_empty)
182189
importFrom(rlang,is_null)
183190
importFrom(rlang,quos)
191+
importFrom(smoothqr,smooth_qr)
184192
importFrom(stats,as.formula)
185193
importFrom(stats,family)
186194
importFrom(stats,lm)

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
* canned forecasters get a class
88
* fixed quantile bug in `flatline_forecaster()`
99
* add functionality to output the unfit workflow from the canned forecasters
10-
* add `pivot_quantiles()` for easier plotting
10+
* add `pivot_quantiles_wider()` for easier plotting
11+
* add complement `pivot_quantiles_longer()`
1112

1213

1314
# epipredict 0.0.4

R/cdc_baseline_forecaster.R

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#' Predict the future with the most recent value
2+
#'
3+
#' This is a simple forecasting model for
4+
#' [epiprocess::epi_df] data. It uses the most recent observation as the
5+
#' forecast for any future date, and produces intervals by shuffling the quantiles
6+
#' of the residuals of such a "flatline" forecast and incrementing these
7+
#' forward over all available training data.
8+
#'
9+
#' By default, the predictive intervals are computed separately for each
10+
#' combination of `geo_value` in the `epi_data` argument.
11+
#'
12+
#' This forecaster is meant to produce exactly the CDC Baseline used for
13+
#' [COVID19ForecastHub](https://covid19forecasthub.org)
14+
#'
15+
#' @param epi_data An [epiprocess::epi_df]
16+
#' @param outcome A scalar character for the column name we wish to predict.
17+
#' @param args_list A list of additional arguments as created by the
18+
#' [cdc_baseline_args_list()] constructor function.
19+
#'
20+
#' @return A data frame of point and interval forecasts at for all
21+
#' aheads (unique horizons) for each unique combination of `key_vars`.
22+
#' @export
23+
#'
24+
#' @examples
25+
#' library(dplyr)
26+
#' weekly_deaths <- case_death_rate_subset %>%
27+
#' select(geo_value, time_value, death_rate) %>%
28+
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
29+
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
30+
#' select(-pop, -death_rate) %>%
31+
#' group_by(geo_value) %>%
32+
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
33+
#' ungroup() %>%
34+
#' filter(weekdays(time_value) == "Saturday")
35+
#'
36+
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
37+
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
38+
#'
39+
#' if (require(ggplot2)) {
40+
#' forecast_date <- unique(preds$forecast_date)
41+
#' four_states <- c("ca", "pa", "wa", "ny")
42+
#' preds %>%
43+
#' filter(geo_value %in% four_states) %>%
44+
#' ggplot(aes(target_date)) +
45+
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
46+
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
47+
#' geom_line(aes(y = .pred), color = "orange") +
48+
#' geom_line(
49+
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
50+
#' aes(x = time_value, y = deaths)
51+
#' ) +
52+
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
53+
#' labs(x = "Date", y = "Weekly deaths") +
54+
#' facet_wrap(~geo_value, scales = "free_y") +
55+
#' theme_bw() +
56+
#' geom_vline(xintercept = forecast_date)
57+
#' }
58+
cdc_baseline_forecaster <- function(
59+
epi_data,
60+
outcome,
61+
args_list = cdc_baseline_args_list()) {
62+
validate_forecaster_inputs(epi_data, outcome, "time_value")
63+
if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) {
64+
cli_stop("args_list was not created using `cdc_baseline_args_list().")
65+
}
66+
keys <- epi_keys(epi_data)
67+
ek <- kill_time_value(keys)
68+
outcome <- rlang::sym(outcome)
69+
70+
71+
r <- epi_recipe(epi_data) %>%
72+
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
73+
recipes::update_role(!!outcome, new_role = "predictor") %>%
74+
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
75+
step_training_window(n_recent = args_list$n_training)
76+
77+
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
78+
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead
79+
80+
81+
latest <- get_test_data(
82+
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
83+
forecast_date
84+
)
85+
86+
f <- frosting() %>%
87+
layer_predict() %>%
88+
layer_cdc_flatline_quantiles(
89+
aheads = args_list$aheads,
90+
quantile_levels = args_list$quantile_levels,
91+
nsims = args_list$nsims,
92+
by_key = args_list$quantile_by_key,
93+
symmetrize = args_list$symmetrize,
94+
nonneg = args_list$nonneg
95+
) %>%
96+
layer_add_forecast_date(forecast_date = forecast_date) %>%
97+
layer_unnest(.pred_distn_all)
98+
# layer_add_target_date(target_date = target_date)
99+
if (args_list$nonneg) f <- layer_threshold(f, ".pred")
100+
101+
eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")
102+
103+
wf <- epi_workflow(r, eng, f)
104+
wf <- generics::fit(wf, epi_data)
105+
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
106+
tibble::as_tibble() %>%
107+
dplyr::select(-time_value) %>%
108+
dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)
109+
110+
structure(
111+
list(
112+
predictions = preds,
113+
epi_workflow = wf,
114+
metadata = list(
115+
training = attr(epi_data, "metadata"),
116+
forecast_created = Sys.time()
117+
)
118+
),
119+
class = c("cdc_baseline_fcast", "canned_epipred")
120+
)
121+
}
122+
123+
124+
125+
#' CDC baseline forecaster argument constructor
126+
#'
127+
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
128+
#'
129+
#' @inheritParams arx_args_list
130+
#' @param data_frequency Integer or string. This describes the frequency of the
131+
#' input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
132+
#' Allowable arguments are integers (taken to mean numbers of days) or a
133+
#' string like `"7 days"` or `"2 weeks"`. Currently, all other periods
134+
#' (other than days or weeks) result in an error.
135+
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
136+
#' any effect on the predicted values.
137+
#' Predictions are always the most recent observation. This determines the
138+
#' set of prediction horizons for [layer_cdc_flatline_quantiles()]`. It interacts
139+
#' with the `data_frequency` argument. So, for example, if the data is daily
140+
#' and you want forecasts for 1:4 days ahead, then you would use `1:4`. However,
141+
#' if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`.
142+
#' But if `data_frequency` is `"1 week"`, then you would set it as `1:4`.
143+
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
144+
#' prediction intervals. These are created by computing the quantiles of
145+
#' training residuals. A `NULL` value will result in point forecasts only.
146+
#' @param nsims Positive integer. The number of draws from the empirical CDF.
147+
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
148+
#' in linear interpolation on the X scale. This is achieved with
149+
#' [stats::quantile()] Type 7 (the default for that function).
150+
#' @param nonneg Logical. Force all predictive intervals be non-negative.
151+
#' Because non-negativity is forced _before_ propagating forward, this
152+
#' has slightly different behaviour than would occur if using
153+
#' [layer_threshold()].
154+
#'
155+
#' @return A list containing updated parameter choices with class `cdc_flat_fcast`.
156+
#' @export
157+
#'
158+
#' @examples
159+
#' cdc_baseline_args_list()
160+
#' cdc_baseline_args_list(symmetrize = FALSE)
161+
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
162+
cdc_baseline_args_list <- function(
163+
data_frequency = "1 week",
164+
aheads = 1:4,
165+
n_training = Inf,
166+
forecast_date = NULL,
167+
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
168+
nsims = 1e3L,
169+
symmetrize = TRUE,
170+
nonneg = TRUE,
171+
quantile_by_key = "geo_value",
172+
nafill_buffer = Inf) {
173+
arg_is_scalar(n_training, nsims, data_frequency)
174+
data_frequency <- parse_period(data_frequency)
175+
arg_is_pos_int(data_frequency)
176+
arg_is_chr(quantile_by_key, allow_empty = TRUE)
177+
arg_is_scalar(forecast_date, allow_null = TRUE)
178+
arg_is_date(forecast_date, allow_null = TRUE)
179+
arg_is_nonneg_int(aheads, nsims)
180+
arg_is_lgl(symmetrize, nonneg)
181+
arg_is_probabilities(quantile_levels, allow_null = TRUE)
182+
arg_is_pos(n_training)
183+
if (is.finite(n_training)) arg_is_pos_int(n_training)
184+
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
185+
186+
structure(
187+
enlist(
188+
data_frequency,
189+
aheads,
190+
n_training,
191+
forecast_date,
192+
quantile_levels,
193+
nsims,
194+
symmetrize,
195+
nonneg,
196+
quantile_by_key,
197+
nafill_buffer
198+
),
199+
class = c("cdc_baseline_fcast", "alist")
200+
)
201+
}
202+
203+
#' @export
204+
print.cdc_baseline_fcast <- function(x, ...) {
205+
name <- "CDC Baseline"
206+
NextMethod(name = name, ...)
207+
}
208+
209+
parse_period <- function(x) {
210+
arg_is_scalar(x)
211+
if (is.character(x)) {
212+
x <- unlist(strsplit(x, " "))
213+
if (length(x) == 1L) x <- as.numeric(x)
214+
if (length(x) == 2L) {
215+
mult <- substr(x[2], 1, 3)
216+
mult <- switch(
217+
mult,
218+
day = 1L,
219+
wee = 7L,
220+
cli::cli_abort("incompatible timespan in `aheads`.")
221+
)
222+
x <- as.numeric(x[1]) * mult
223+
}
224+
if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.")
225+
}
226+
stopifnot(rlang::is_integerish(x))
227+
as.integer(x)
228+
}

R/compat-purrr.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ map_chr <- function(.x, .f, ...) {
3232
.rlang_purrr_map_mold(.x, .f, character(1), ...)
3333
}
3434

35+
map_vec <- function(.x, .f, ...) {
36+
out <- map(.x, .f, ...)
37+
vctrs::list_unchop(out)
38+
}
39+
3540
map_dfr <- function(.x, .f, ..., .id = NULL) {
3641
.f <- rlang::as_function(.f, env = rlang::global_env())
3742
res <- map(.x, .f, ...)

0 commit comments

Comments
 (0)