Skip to content

Commit 7d22ef8

Browse files
committed
first draft of epi_slide in step_epi_slide
1 parent 3fec57f commit 7d22ef8

File tree

5 files changed

+146
-52
lines changed

5 files changed

+146
-52
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,12 @@ importFrom(checkmate,assert_scalar)
226226
importFrom(cli,cli_abort)
227227
importFrom(dplyr,across)
228228
importFrom(dplyr,all_of)
229+
importFrom(dplyr,bind_cols)
229230
importFrom(dplyr,group_by)
230231
importFrom(dplyr,n)
231232
importFrom(dplyr,summarise)
232233
importFrom(dplyr,ungroup)
234+
importFrom(epiprocess,epi_slide)
233235
importFrom(epiprocess,growth_rate)
234236
importFrom(generics,augment)
235237
importFrom(generics,fit)
@@ -269,6 +271,7 @@ importFrom(stats,qnorm)
269271
importFrom(stats,quantile)
270272
importFrom(stats,residuals)
271273
importFrom(tibble,tibble)
274+
importFrom(tidyr,crossing)
272275
importFrom(tidyr,drop_na)
273276
importFrom(vctrs,as_list_of)
274277
importFrom(vctrs,field)

R/step_epi_slide.R

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -166,30 +166,71 @@ bake.step_epi_slide <- function(object, new_data, ...) {
166166
c("In `step_epi_slide()` a name collision occurred. The following variable names already exist:",
167167
`*` = "{.var {nms}}"
168168
),
169-
call = caller_env()
169+
call = caller_env(),
170+
class = "epipredict__step__name_collision_error"
170171
)
171172
}
172-
173-
ok <- object$keys
174-
names(col_names) <- newnames
175-
gr <- new_data %>%
176-
dplyr::select(dplyr::all_of(c(ok, object$columns))) %>%
177-
group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
178-
dplyr::arrange(time_value) %>%
179-
dplyr::mutate(
180-
dplyr::across(
181-
dplyr::all_of(object$columns),
182-
~ slider::slide_index_vec(
183-
.x,
184-
.i = time_value,
185-
object$.f, .before = object$before, .after = object$after
186-
)
187-
)
173+
if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) {
174+
cli_warn(
175+
c("There is an optimized version of both mean and sum. See `step_epi_slide_mean`, `step_epi_slide_sum`, or `step_epi_slide_opt`."
176+
),
177+
class = "epipredict__step_epi_slide__optimized_version"
178+
)
179+
}
180+
epi_slide_wrapper(
181+
new_data,
182+
object$before,
183+
object$after,
184+
object$columns,
185+
c(object$.f),
186+
object$f_name,
187+
object$keys[-1],
188+
object$prefix
189+
)
190+
}
191+
#' wrapper to handle epi_slide particulars
192+
#' @description
193+
#' This should simplify somewhat in the future when we can run `epi_slide` on
194+
#' columns. Surprisingly, lapply is several orders of magnitude faster than
195+
#' using roughly equivalent tidy select style.
196+
#' @param fns vector of functions, even if it's length 1.
197+
#' @param group_keys the keys to group by. likely epi_keys[-1] (to remove time_value)
198+
#' @importFrom tidyr crossing
199+
#' @importFrom dplyr bind_cols group_by ungroup
200+
#' @importFrom epiprocess epi_slide
201+
#' @keywords internal
202+
epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, group_keys, name_prefix) {
203+
cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns)
204+
seq_len(nrow(cols_fns)) %>%
205+
lapply( # iterate over the rows of cols_fns
206+
# takes in the row number, outputs the transformed column
207+
function(comp_i) {
208+
# extract values from the row
209+
col_name <- cols_fns[[comp_i, "col_name"]]
210+
fn_name <- cols_fns[[comp_i, "fn_name"]]
211+
fn <- cols_fns[[comp_i, "fn"]][[1L]]
212+
result_name <- paste(name_prefix, fn_name, col_name, sep="_")
213+
result <- new_data %>%
214+
group_by(across(group_keys)) %>%
215+
epi_slide(
216+
before = before,
217+
after = after,
218+
new_col_name = result_name,
219+
f = function(slice, geo_key, ref_time_value) {
220+
fn(slice[[col_name]])
221+
}
222+
) %>%
223+
ungroup()
224+
# the first result needs to include all of the original columns
225+
if (comp_i == 1L) {
226+
result
227+
} else {
228+
# everything else just needs that column transformed
229+
result[result_name]
230+
}
231+
}
188232
) %>%
189-
dplyr::rename(dplyr::all_of(col_names)) %>%
190-
dplyr::ungroup()
191-
192-
dplyr::left_join(new_data, gr, by = ok)
233+
bind_cols()
193234
}
194235

195236

man/epi_slide_wrapper.Rd

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/step_epi_slide.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-step_epi_slide.R

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,54 @@ rolled_after <- edf %>%
5555

5656

5757
test_that("epi_slide handles classed before/after", {
58-
baseline <- r %>%
59-
step_epi_slide(value, .f = mean, before = 3L) %>%
60-
prep(edf) %>%
61-
bake(new_data = NULL)
58+
expect_warning(
59+
baseline <- r %>%
60+
step_epi_slide(value, .f = mean, before = 3L) %>%
61+
prep(edf) %>%
62+
bake(new_data = NULL),
63+
regexp = "There is an optimized version"
64+
)
6265
expect_equal(baseline[[4]], rolled_before)
6366

64-
pbefore <- r %>%
65-
step_epi_slide(value, .f = mean, before = lubridate::period("3 days")) %>%
66-
prep(edf) %>%
67-
bake(new_data = NULL)
68-
cbefore <- r %>%
69-
step_epi_slide(value, .f = mean, before = "3 days") %>%
70-
prep(edf) %>%
71-
bake(new_data = NULL)
67+
expect_warning(
68+
pbefore <- r %>%
69+
step_epi_slide(value, .f = mean, before = lubridate::period("3 days")) %>%
70+
prep(edf) %>%
71+
bake(new_data = NULL),
72+
regexp = "There is an optimized version"
73+
)
74+
expect_warning(
75+
cbefore <- r %>%
76+
step_epi_slide(value, .f = mean, before = "3 days") %>%
77+
prep(edf) %>%
78+
bake(new_data = NULL),
79+
regexp = "There is an optimized version"
80+
)
7281
expect_equal(baseline, pbefore)
7382
expect_equal(baseline, cbefore)
7483

75-
baseline <- r %>%
76-
step_epi_slide(value, .f = mean, after = 3L) %>%
77-
prep(edf) %>%
78-
bake(new_data = NULL)
84+
expect_warning(
85+
baseline <- r %>%
86+
step_epi_slide(value, .f = mean, after = 3L) %>%
87+
prep(edf) %>%
88+
bake(new_data = NULL),
89+
regexp = "There is an optimized version"
90+
)
7991
expect_equal(baseline[[4]], rolled_after)
80-
pafter <- r %>%
81-
step_epi_slide(value, .f = mean, after = lubridate::period("3 days")) %>%
82-
prep(edf) %>%
83-
bake(new_data = NULL)
84-
cafter <- r %>%
85-
step_epi_slide(value, .f = mean, after = "3 days") %>%
86-
prep(edf) %>%
87-
bake(new_data = NULL)
92+
expect_warning(
93+
pafter <- r %>%
94+
step_epi_slide(value, .f = mean, after = lubridate::period("3 days")) %>%
95+
prep(edf) %>%
96+
bake(new_data = NULL),
97+
regexp = "There is an optimized version"
98+
)
99+
expect_warning(
100+
cafter <- r %>%
101+
step_epi_slide(value, .f = mean, after = "3 days") %>%
102+
prep(edf) %>%
103+
bake(new_data = NULL),
104+
regexp = "There is an optimized version"
105+
)
88106
expect_equal(baseline, pafter)
89107
expect_equal(baseline, cafter)
90108
})
@@ -99,10 +117,13 @@ test_that("epi_slide handles different function specs", {
99117
step_epi_slide(value, .f = mean, before = 3L) %>%
100118
prep(edf) %>%
101119
bake(new_data = NULL)
102-
lfun <- r %>%
103-
step_epi_slide(value, .f = ~ mean(.x, na.rm = TRUE), before = 3L) %>%
104-
prep(edf) %>%
105-
bake(new_data = NULL)
120+
# formula NOT currently supported
121+
expect_error(
122+
lfun <- r %>%
123+
step_epi_slide(value, .f = ~ mean(.x, na.rm = TRUE), before = 3L) %>%
124+
prep(edf) %>%
125+
bake(new_data = NULL)
126+
)
106127
blfun <- r %>%
107128
step_epi_slide(value, .f = function(x) mean(x, na.rm = TRUE), before = 3L) %>%
108129
prep(edf) %>%
@@ -114,7 +135,7 @@ test_that("epi_slide handles different function specs", {
114135

115136
expect_equal(cfun[[4]], rolled_before)
116137
expect_equal(ffun[[4]], rolled_before)
117-
expect_equal(lfun[[4]], rolled_before)
138+
#expect_equal(lfun[[4]], rolled_before)
118139
expect_equal(blfun[[4]], rolled_before)
119140
expect_equal(nblfun[[4]], rolled_before)
120141
})

0 commit comments

Comments
 (0)