|
#' Predict the future with the most recent value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' forecast for any future date, and produces intervals by shuffling the
#' quantiles of the residuals of such a "flatline" forecast and incrementing
#' these forward over all available training data.
#'
#' By default, the predictive intervals are computed separately for each
#' combination of `geo_value` in the `epi_data` argument.
#'
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [epiprocess::epi_df]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#'   [cdc_baseline_args_list()] constructor function.
#'
#' @return A list of class `c("cdc_baseline_fcast", "canned_epipred")` with
#'   three elements: `predictions`, a tibble of point and interval forecasts
#'   for all aheads (unique horizons) for each unique combination of
#'   `key_vars`; `epi_workflow`, the fitted workflow; and `metadata`, a list
#'   holding the training metadata and the time the forecast was created.
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#'   select(geo_value, time_value, death_rate) %>%
#'   left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#'   mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
#'   select(-pop, -death_rate) %>%
#'   group_by(geo_value) %>%
#'   epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#'   ungroup() %>%
#'   filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
#'   forecast_date <- unique(preds$forecast_date)
#'   four_states <- c("ca", "pa", "wa", "ny")
#'   preds %>%
#'     filter(geo_value %in% four_states) %>%
#'     ggplot(aes(target_date)) +
#'     geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#'     geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#'     geom_line(aes(y = .pred), color = "orange") +
#'     geom_line(
#'       data = weekly_deaths %>% filter(geo_value %in% four_states),
#'       aes(x = time_value, y = deaths)
#'     ) +
#'     scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#'     labs(x = "Date", y = "Weekly deaths") +
#'     facet_wrap(~geo_value, scales = "free_y") +
#'     theme_bw() +
#'     geom_vline(xintercept = forecast_date)
#' }
cdc_baseline_forecaster <- function(
    epi_data,
    outcome,
    args_list = cdc_baseline_args_list()) {
  validate_forecaster_inputs(epi_data, outcome, "time_value")
  # `inherits()` with a vector of classes is TRUE as soon as *any* of them
  # matches, so each class set by `cdc_baseline_args_list()` is checked on its
  # own. Otherwise every `alist` (including other forecasters' argument lists)
  # would slip through.
  is_cdc_args <- inherits(args_list, "cdc_baseline_fcast") &&
    inherits(args_list, "alist")
  if (!is_cdc_args) {
    cli_stop("args_list was not created using `cdc_baseline_args_list()`.")
  }
  keys <- epi_keys(epi_data)
  outcome <- rlang::sym(outcome)

  # The outcome doubles as the only substantive predictor; the key columns
  # are given the predictor role as well.
  r <- epi_recipe(epi_data) %>%
    step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
    recipes::update_role(!!outcome, new_role = "predictor") %>%
    recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
    step_training_window(n_recent = args_list$n_training)

  # Fall back to the last observed date when no forecast date was supplied.
  forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
  # target_date <- args_list$target_date %||% forecast_date + args_list$ahead

  latest <- get_test_data(
    epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
    forecast_date
  )

  f <- frosting() %>%
    layer_predict() %>%
    layer_cdc_flatline_quantiles(
      aheads = args_list$aheads,
      quantile_levels = args_list$quantile_levels,
      nsims = args_list$nsims,
      by_key = args_list$quantile_by_key,
      symmetrize = args_list$symmetrize,
      nonneg = args_list$nonneg
    ) %>%
    layer_add_forecast_date(forecast_date = forecast_date) %>%
    layer_unnest(.pred_distn_all)
  # layer_add_target_date(target_date = target_date)
  if (args_list$nonneg) {
    f <- layer_threshold(f, ".pred")
  }

  eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")

  wf <- epi_workflow(r, eng, f)
  wf <- generics::fit(wf, epi_data)
  # `ahead` counts periods of `data_frequency` days, so the target date is
  # that many periods past the forecast date.
  preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
    tibble::as_tibble() %>%
    dplyr::select(-time_value) %>%
    dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)

  structure(
    list(
      predictions = preds,
      epi_workflow = wf,
      metadata = list(
        training = attr(epi_data, "metadata"),
        forecast_created = Sys.time()
      )
    ),
    class = c("cdc_baseline_fcast", "canned_epipred")
  )
}
| 122 | + |
| 123 | + |
| 124 | + |
#' CDC baseline forecaster argument constructor
#'
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
#'
#' @inheritParams arx_args_list
#' @param data_frequency Integer or string. This describes the frequency of the
#'   input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
#'   Allowable arguments are integers (taken to mean numbers of days) or a
#'   string like `"7 days"` or `"2 weeks"`. Currently, all other periods
#'   (other than days or weeks) result in an error.
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
#'   any effect on the predicted values.
#'   Predictions are always the most recent observation. This determines the
#'   set of prediction horizons for [layer_cdc_flatline_quantiles()]. It
#'   interacts with the `data_frequency` argument. So, for example, if the data
#'   is daily and you want forecasts for 1:4 days ahead, then you would use
#'   `1:4`. However, if you want one-week predictions, you would set this as
#'   `c(7, 14, 21, 28)`. But if `data_frequency` is `"1 week"`, then you would
#'   set it as `1:4`.
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#'   prediction intervals. These are created by computing the quantiles of
#'   training residuals. A `NULL` value will result in point forecasts only.
#' @param nsims Positive integer. The number of draws from the empirical CDF.
#'   These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
#'   in linear interpolation on the X scale. This is achieved with
#'   [stats::quantile()] Type 7 (the default for that function).
#' @param nonneg Logical. Force all predictive intervals be non-negative.
#'   Because non-negativity is forced _before_ propagating forward, this
#'   has slightly different behaviour than would occur if using
#'   [layer_threshold()].
#'
#' @return A list containing updated parameter choices with class
#'   `cdc_baseline_fcast`.
#' @export
#'
#' @examples
#' cdc_baseline_args_list()
#' cdc_baseline_args_list(symmetrize = FALSE)
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
cdc_baseline_args_list <- function(
    data_frequency = "1 week",
    aheads = 1:4,
    n_training = Inf,
    forecast_date = NULL,
    quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
    nsims = 1e3L,
    symmetrize = TRUE,
    nonneg = TRUE,
    quantile_by_key = "geo_value",
    nafill_buffer = Inf) {
  arg_is_scalar(n_training, nsims, data_frequency)
  # Normalise the frequency to an integer number of days before validating it.
  data_frequency <- parse_period(data_frequency)
  arg_is_pos_int(data_frequency)
  arg_is_chr(quantile_by_key, allow_empty = TRUE)
  arg_is_scalar(forecast_date, allow_null = TRUE)
  arg_is_date(forecast_date, allow_null = TRUE)
  arg_is_nonneg_int(aheads, nsims)
  arg_is_lgl(symmetrize, nonneg)
  arg_is_probabilities(quantile_levels, allow_null = TRUE)
  arg_is_pos(n_training)
  # `Inf` means "no limit" and is not an integer, so only finite values get
  # the integer check.
  if (is.finite(n_training)) {
    arg_is_pos_int(n_training)
  }
  # `is.finite(NULL)` is `logical(0)`, which `if ()` rejects with an error, so
  # NULL has to be ruled out first. NULL is let through unvalidated, matching
  # the `allow_null = TRUE` the original check was called with.
  if (!is.null(nafill_buffer) && is.finite(nafill_buffer)) {
    arg_is_pos_int(nafill_buffer)
  }

  structure(
    enlist(
      data_frequency,
      aheads,
      n_training,
      forecast_date,
      quantile_levels,
      nsims,
      symmetrize,
      nonneg,
      quantile_by_key,
      nafill_buffer
    ),
    class = c("cdc_baseline_fcast", "alist")
  )
}
| 202 | + |
# Print method for objects of class `cdc_baseline_fcast`.
#
# Hands off to the next `print()` method in the object's class chain,
# supplying "CDC Baseline" as the `name` argument and forwarding `...`.

#' @export
print.cdc_baseline_fcast <- function(x, ...) {
  NextMethod(name = "CDC Baseline", ...)
}
| 208 | + |
# Convert a period specification into an integer number of days.
#
# `x` is a scalar: either a whole number (taken to be days) or a string such
# as "7", "7 days" or "2 weeks". The unit is matched on its first three
# letters, ignoring case, and only day- and week-based units are understood.
# Surrounding or repeated whitespace is tolerated.
#
# Aborts when the unit is unknown, the string has more than two parts, the
# count is not numeric, or the resulting number of days is not a finite whole
# number that fits in an integer.
parse_period <- function(x) {
  arg_is_scalar(x)
  # This helper parses the `data_frequency` argument, so that is the name the
  # user needs to see in the error.
  msg <- "incompatible timespan in `data_frequency`."
  if (is.character(x)) {
    parts <- strsplit(trimws(x), "[[:space:]]+")[[1]]
    if (length(parts) == 0L || length(parts) > 2L) {
      cli::cli_abort(msg)
    }
    mult <- 1L
    if (length(parts) == 2L) {
      mult <- switch(substr(tolower(parts[2]), 1, 3),
        day = 1L,
        wee = 7L,
        cli::cli_abort(msg)
      )
    }
    # A non-numeric count becomes NA here and is rejected below; the coercion
    # warning would only add noise to that error.
    x <- suppressWarnings(as.numeric(parts[1])) * mult
  }
  is_whole <- is.numeric(x) && length(x) == 1L && is.finite(x) &&
    x == round(x) && abs(x) <= .Machine$integer.max
  if (!is_whole) {
    cli::cli_abort(msg)
  }
  as.integer(x)
}
0 commit comments