Skip to content

Commit 63a520b

Browse files
authored
Merge pull request #358 from brookslogan/lcb/layer_predict-passing
Make `layer_predict` forward stored dots_list to `predict()`
2 parents 9d57a62 + 7fd4094 commit 63a520b

30 files changed

+221
-32
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.16
3+
Version: 0.0.17
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ import(parsnip)
208208
import(recipes)
209209
importFrom(checkmate,assert)
210210
importFrom(checkmate,assert_character)
211+
importFrom(checkmate,assert_class)
211212
importFrom(checkmate,assert_date)
212213
importFrom(checkmate,assert_function)
213214
importFrom(checkmate,assert_int)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
4848
`{usethis}`)
4949
- Replaced old version-faithful example in sliding AR & ARX forecasters vignette
5050
- `epi_recipe()` will now warn when given non-`epi_df` data
51+
- `layer_predict()` and `predict.epi_workflow()` will now appropriately forward
52+
`...` args intended for `predict.model_fit()`
53+
- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the
54+
steps has changed the appropriate values

R/arx_classifier.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Direct autoregressive classifier with covariates
22
#'
33
#' This is an autoregressive classification model for
4-
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
4+
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning
55
#' that it estimates a class at a particular target horizon.
66
#'
77
#' @inheritParams arx_forecaster

R/arx_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Direct autoregressive forecaster with covariates
22
#'
33
#' This is an autoregressive forecasting model for
4-
#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning
4+
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning
55
#' that it estimates a model for a particular target horizon.
66
#'
77
#'

R/cdc_baseline_forecaster.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Predict the future with the most recent value
22
#'
33
#' This is a simple forecasting model for
4-
#' [epiprocess::epi_df] data. It uses the most recent observation as the
4+
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the
55
#' forecast for any future date, and produces intervals by shuffling the quantiles
66
#' of the residuals of such a "flatline" forecast and incrementing these
77
#' forward over all available training data.
@@ -12,7 +12,7 @@
1212
#' This forecaster is meant to produce exactly the CDC Baseline used for
1313
#' [COVID19ForecastHub](https://covid19forecasthub.org)
1414
#'
15-
#' @param epi_data An [`epiprocess::epi_df`]
15+
#' @param epi_data An [`epiprocess::epi_df`][epiprocess::as_epi_df]
1616
#' @param outcome A scalar character for the column name we wish to predict.
1717
#' @param args_list A list of additional arguments as created by the
1818
#' [cdc_baseline_args_list()] constructor function.

R/data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
#' Subset of Statistics Canada median employment income for postsecondary graduates
6161
#'
62-
#' @format An [epiprocess::epi_df] with 10193 rows and 8 variables:
62+
#' @format An [epiprocess::epi_df][epiprocess::as_epi_df] with 10193 rows and 8 variables:
6363
#' \describe{
6464
#' \item{geo_value}{The province in Canada associated with each
6565
#' row of measurements.}

R/epi_recipe.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ is_epi_recipe <- function(x) {
245245
#' @details
246246
#' `add_epi_recipe` has the same behaviour as
247247
#' [workflows::add_recipe()] but sets a different
248-
#' default blueprint to automatically handle [epiprocess::epi_df] data.
248+
#' default blueprint to automatically handle [epiprocess::epi_df][epiprocess::as_epi_df] data.
249249
#'
250250
#' @param x A `workflow` or `epi_workflow`
251251
#'
@@ -572,9 +572,13 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
572572
}
573573
new_data <- NextMethod("bake")
574574
if (!is.null(meta)) {
575+
# Baking should have dropped epi_df-ness and metadata. Re-infer some
576+
# metadata and assume others remain the same as the object/template:
575577
new_data <- as_epi_df(
576-
new_data, meta$geo_type, meta$time_type, meta$as_of,
577-
meta$additional_metadata %||% list()
578+
new_data,
579+
as_of = meta$as_of,
580+
# avoid NULL if meta is from saved older epi_df:
581+
additional_metadata = meta$additional_metadata %||% list()
578582
)
579583
}
580584
new_data

R/epi_workflow.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,18 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
119119
#' - Call [parsnip::predict.model_fit()] for you using the underlying fit
120120
#' parsnip model.
121121
#'
122-
#' - Ensure that the returned object is an [epiprocess::epi_df] where
122+
#' - Ensure that the returned object is an [epiprocess::epi_df][epiprocess::as_epi_df] where
123123
#' possible. Specifically, the output will have `time_value` and
124124
#' `geo_value` columns as well as the prediction.
125125
#'
126-
#' @inheritParams parsnip::predict.model_fit
127-
#'
128126
#' @param object An epi_workflow that has been fit by
129127
#' [workflows::fit.workflow()]
130128
#'
131129
#' @param new_data A data frame containing the new predictors to preprocess
132130
#' and predict on
133131
#'
132+
#' @inheritParams parsnip::predict.model_fit
133+
#'
134134
#' @return
135135
#' A data frame of model predictions, with as many rows as `new_data` has.
136136
#' If `new_data` is an `epi_df` or a data frame with `time_value` or
@@ -152,7 +152,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
152152
#'
153153
#' preds <- predict(wf, latest)
154154
#' preds
155-
predict.epi_workflow <- function(object, new_data, ...) {
155+
predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) {
156156
if (!workflows::is_trained_workflow(object)) {
157157
cli::cli_abort(c(
158158
"Can't predict on an untrained epi_workflow.",
@@ -168,7 +168,7 @@ predict.epi_workflow <- function(object, new_data, ...) {
168168
components$forged,
169169
components$mold, new_data
170170
)
171-
components <- apply_frosting(object, components, new_data, ...)
171+
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
172172
components$predictions
173173
}
174174

R/epipredict-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#' @importFrom cli cli_abort
77
#' @importFrom checkmate assert assert_character assert_int assert_scalar
88
#' assert_logical assert_numeric assert_number assert_integer
9-
#' assert_integerish assert_date assert_function
9+
#' assert_integerish assert_date assert_function assert_class
1010
#' @import epiprocess parsnip
1111
## usethis namespace: end
1212
NULL

R/flatline_forecaster.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Predict the future with today's value
22
#'
33
#' This is a simple forecasting model for
4-
#' [epiprocess::epi_df] data. It uses the most recent observation as the
4+
#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the
55
#' forcast for any future date, and produces intervals based on the quantiles
66
#' of the residuals of such a "flatline" forecast over all available training
77
#' data.
@@ -13,7 +13,7 @@
1313
#' This forecaster is very similar to that used by the
1414
#' [COVID19ForecastHub](https://covid19forecasthub.org)
1515
#'
16-
#' @param epi_data An [epiprocess::epi_df]
16+
#' @param epi_data An [epiprocess::epi_df][epiprocess::as_epi_df]
1717
#' @param outcome A scalar character for the column name we wish to predict.
1818
#' @param args_list A list of dditional arguments as created by the
1919
#' [flatline_args_list()] constructor function.

R/frosting.R

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,11 @@ apply_frosting.default <- function(workflow, components, ...) {
355355
#' @rdname apply_frosting
356356
#' @importFrom rlang is_null
357357
#' @importFrom rlang abort
358+
#' @param type,opts forwarded (along with `...`) to [`predict.model_fit()`] and
359+
#' [`slather()`] for supported layers
358360
#' @export
359361
apply_frosting.epi_workflow <-
360-
function(workflow, components, new_data, ...) {
362+
function(workflow, components, new_data, type = NULL, opts = list(), ...) {
361363
the_fit <- workflows::extract_fit_parsnip(workflow)
362364

363365
if (!has_postprocessor(workflow)) {
@@ -376,7 +378,7 @@ apply_frosting.epi_workflow <-
376378
"Returning unpostprocessed predictions."
377379
))
378380
components$predictions <- predict(
379-
the_fit, components$forged$predictors, ...
381+
the_fit, components$forged$predictors, type, opts, ...
380382
)
381383
components$predictions <- dplyr::bind_cols(
382384
components$keys, components$predictions
@@ -397,10 +399,28 @@ apply_frosting.epi_workflow <-
397399
layers
398400
)
399401
}
402+
if (length(layers) > 1L &&
403+
(!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) {
404+
cli_abort("
405+
Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not
406+
supported if you have frosting layers other than `layer_predict`. Please
407+
provide these arguments earlier (i.e. while constructing the frosting
408+
object) by passing them into an explicit call to `layer_predict(), and
409+
adjust the remaining layers to account for resulting differences in
410+
output format from these settings.
411+
", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers")
412+
}
400413

401414
for (l in seq_along(layers)) {
402415
la <- layers[[l]]
403-
components <- slather(la, components, workflow, new_data)
416+
if (inherits(la, "layer_predict")) {
417+
components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...)
418+
} else {
419+
# The check above should ensure we have default `type` and `opts`, and
420+
# empty `...`; don't forward these default `type` and `opts`, to avoid
421+
# upsetting some slather method validation.
422+
components <- slather(la, components, workflow, new_data)
423+
}
404424
}
405425

406426
return(components)

R/get_test_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Get test data for prediction based on longest lag period
22
#'
33
#' Based on the longest lag period in the recipe,
4-
#' `get_test_data()` creates an [epi_df]
4+
#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df]
55
#' with columns `geo_value`, `time_value`
66
#' and other variables in the original dataset,
77
#' which will be used to create features necessary to produce forecasts.

R/layer_add_forecast_date.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) {
8686

8787
#' @export
8888
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
89+
rlang::check_dots_empty()
8990
if (is.null(object$forecast_date)) {
9091
max_time_value <- as.Date(max(
9192
workflows::extract_preprocessor(workflow)$max_time_value,

R/layer_naomit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ layer_naomit_new <- function(terms, id) {
4545

4646
#' @export
4747
slather.layer_naomit <- function(object, components, workflow, new_data, ...) {
48+
rlang::check_dots_empty()
4849
exprs <- rlang::expr(c(!!!object$terms))
4950
pos <- tidyselect::eval_select(exprs, components$predictions)
5051
col_names <- names(pos)

R/layer_point_from_distn.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ layer_point_from_distn_new <- function(type, name, id) {
7676
#' @export
7777
slather.layer_point_from_distn <-
7878
function(object, components, workflow, new_data, ...) {
79-
rlang::check_dots_empty()
8079
dstn <- components$predictions$.pred
8180
if (!inherits(dstn, "distribution")) {
8281
rlang::warn(
@@ -86,6 +85,7 @@ slather.layer_point_from_distn <-
8685
)
8786
return(components)
8887
}
88+
rlang::check_dots_empty()
8989

9090
dstn <- match.fun(object$type)(dstn)
9191
if (is.null(object$name)) {

R/layer_population_scaling.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ layer_population_scaling_new <-
128128
#' @export
129129
slather.layer_population_scaling <-
130130
function(object, components, workflow, new_data, ...) {
131-
rlang::check_dots_empty()
132131
stopifnot(
133132
"Only one population column allowed for scaling" =
134133
length(object$df_pop_col) == 1
135134
)
135+
rlang::check_dots_empty()
136136

137137
if (is.null(object$by)) {
138138
object$by <- intersect(

R/layer_predict.R

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,19 @@ layer_predict <-
4545
id = rand_id("predict_default")) {
4646
arg_is_chr_scalar(id)
4747
arg_is_chr_scalar(type, allow_null = TRUE)
48+
assert_class(opts, "list")
49+
dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE)
50+
if (any(rlang::names2(dots_list) == "")) {
51+
cli_abort("All `...` arguments must be named.",
52+
class = "epipredict__layer_predict__unnamed_dot"
53+
)
54+
}
4855
add_layer(
4956
frosting,
5057
layer_predict_new(
5158
type = type,
5259
opts = opts,
53-
dots_list = rlang::list2(...), # can't figure how to use this
60+
dots_list = dots_list,
5461
id = id
5562
)
5663
)
@@ -62,14 +69,27 @@ layer_predict_new <- function(type, opts, dots_list, id) {
6269
}
6370

6471
#' @export
65-
slather.layer_predict <- function(object, components, workflow, new_data, ...) {
72+
slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) {
73+
arg_is_chr_scalar(type, allow_null = TRUE)
74+
if (!is.null(object$type) && !is.null(type) && !identical(object$type, type)) {
75+
cli_abort("
76+
Conflicting `type` settings were specified during frosting construction
77+
(in call to `layer_predict()`) and while slathering (in call to
78+
`slather()`/ `predict()`/etc.): {object$type} vs. {type}. Please remove
79+
one of these `type` settings.
80+
", class = "epipredict__layer_predict__conflicting_type_settings")
81+
}
82+
assert_class(opts, "list")
83+
6684
the_fit <- workflows::extract_fit_parsnip(workflow)
6785

68-
components$predictions <- predict(
86+
components$predictions <- rlang::inject(predict(
6987
the_fit,
7088
components$forged$predictors,
71-
type = object$type, opts = object$opts
72-
)
89+
type = object$type %||% type,
90+
opts = c(object$opts, opts),
91+
!!!object$dots_list, ...
92+
))
7393
components$predictions <- dplyr::bind_cols(
7494
components$keys, components$predictions
7595
)

R/layer_predictive_distn.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) {
7373
slather.layer_predictive_distn <-
7474
function(object, components, workflow, new_data, ...) {
7575
the_fit <- workflows::extract_fit_parsnip(workflow)
76+
rlang::check_dots_empty()
7677

7778
m <- components$predictions$.pred
7879
r <- grab_residuals(the_fit, components)

R/layer_quantile_distn.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ slather.layer_quantile_distn <-
7979
"These are of class {.cls {class(dstn)}}."
8080
))
8181
}
82+
rlang::check_dots_empty()
83+
8284
dstn <- dist_quantiles(
8385
quantile(dstn, object$quantile_levels),
8486
object$quantile_levels

R/layer_residual_quantiles.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ layer_residual_quantiles_new <- function(
7575
#' @export
7676
slather.layer_residual_quantiles <-
7777
function(object, components, workflow, new_data, ...) {
78+
rlang::check_dots_empty()
79+
7880
the_fit <- workflows::extract_fit_parsnip(workflow)
7981

8082
if (is.null(object$quantile_levels)) {

R/layer_threshold_preds.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) {
9898
#' @export
9999
slather.layer_threshold <-
100100
function(object, components, workflow, new_data, ...) {
101+
rlang::check_dots_empty()
101102
exprs <- rlang::expr(c(!!!object$terms))
102103
pos <- tidyselect::eval_select(exprs, components$predictions)
103104
col_names <- names(pos)

R/layer_unnest.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) {
2828
#' @export
2929
slather.layer_unnest <-
3030
function(object, components, workflow, new_data, ...) {
31+
rlang::check_dots_empty()
3132
exprs <- rlang::expr(c(!!!object$terms))
3233
pos <- tidyselect::eval_select(exprs, components$predictions)
3334
col_names <- names(pos)

inst/templates/layer.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) {
2929
#' @export
3030
slather.layer_{{{ name }}} <-
3131
function(object, components, workflow, new_data, ...) {
32+
rlang::check_dots_empty()
3233

3334
# if layer_ used ... in tidyselect, we need to evaluate it now
3435
exprs <- rlang::expr(c(!!!object$terms))

man/apply_frosting.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_test_data.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)