Skip to content

Commit 24972d2

Browse files
committed
wip
1 parent a1b8b5f commit 24972d2

11 files changed

+84
-95
lines changed

R/new_epipredict_steps/step_yeo_johnson.R

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#' `step_YeoJohnson2()` creates a *specification* of a recipe step that will
44
#' transform data using a Yeo-Johnson transformation. This fork works with panel
55
#' data and is meant for epidata.
6+
#' TODO: Do an edit pass on this docstring.
67
#'
78
#' @inheritParams step_center
89
#' @param lambdas A numeric vector of transformation values. This
@@ -69,11 +70,21 @@
6970
#' tidy(yj_transform, number = 1)
7071
#' tidy(yj_estimates, number = 1)
7172
step_YeoJohnson2 <-
72-
function(recipe, ..., role = NA, trained = FALSE,
73-
lambdas = NULL, na_lambda_fill = 1 / 4, limits = c(-5, 5), num_unique = 5,
74-
na_rm = TRUE,
75-
skip = FALSE,
76-
id = rand_id("YeoJohnson2")) {
73+
function(
74+
recipe,
75+
...,
76+
role = NA,
77+
trained = FALSE,
78+
lambdas = NULL,
79+
na_lambda_fill = 1 / 4,
80+
limits = c(-5, 5),
81+
num_unique = 5,
82+
na_rm = TRUE,
83+
skip = FALSE,
84+
id = rand_id("YeoJohnson2")
85+
) {
86+
# TODO: Add arg validations.
87+
# TODO: Improve arg names.
7788
add_step(
7889
recipe,
7990
step_YeoJohnson2_new(
@@ -115,17 +126,18 @@ prep.step_YeoJohnson2 <- function(x, training, info = NULL, ...) {
115126
recipes:::check_number_whole(x$num_unique, args = "num_unique")
116127
recipes:::check_bool(x$na_rm, arg = "na_rm")
117128
if (!is.numeric(x$limits) || any(is.na(x$limits)) || length(x$limits) != 2) {
118-
cli::cli_abort("{.arg limits} should be a numeric vector with two values,
119-
not {.obj_type_friendly {x$limits}}")
129+
cli::cli_abort(
130+
"{.arg limits} should be a numeric vector with two values,
131+
not {.obj_type_friendly {x$limits}}"
132+
)
120133
}
121134

122-
x$limits <- sort(x$limits)
123-
124135
values <- training %>%
125-
group_by(geo_value) %>%
126-
summarise(across(all_of(col_names), ~ estimate_yj(.x, x$limits, x$num_unique, x$na_rm))) %>%
127-
ungroup() %>%
128-
rename_with(~ paste0("lambda_", .x), -geo_value)
136+
summarise(
137+
across(all_of(col_names), ~ estimate_yj(.x, x$limits, x$num_unique, x$na_rm)),
138+
.by = key_colnames(training, exclude = "time_value")
139+
) %>%
140+
rename_with(~ paste0("lambda_", .x), -all_of(key_colnames(training, exclude = "time_value")))
129141

130142
# Check for NAs in any of the lambda_ columns
131143
for (col in col_names) {
@@ -137,17 +149,12 @@ prep.step_YeoJohnson2 <- function(x, training, info = NULL, ...) {
137149
),
138150
call = rlang::caller_fn()
139151
)
140-
values <- values %>%
141-
mutate(
142-
!!sym(paste0("lambda_", col)) := ifelse(
143-
is.na(!!sym(paste0("lambda_", col))),
144-
x$na_lambda_fill,
145-
!!sym(paste0("lambda_", col))
146-
)
147-
)
148152
}
149153
}
150154

155+
values <- values %>%
156+
mutate(across(starts_with("lambda_"), \(col) ifelse(is.na(col), x$na_lambda_fill, col)))
157+
151158
step_YeoJohnson2_new(
152159
terms = x$terms,
153160
role = x$role,
@@ -168,11 +175,12 @@ bake.step_YeoJohnson2 <- function(object, new_data, ...) {
168175
col_names <- object$terms %>% purrr::map_chr(rlang::as_name)
169176
check_new_data(col_names, object, new_data)
170177

171-
new_data %<>% left_join(object$lambdas, by = "geo_value")
178+
new_data %<>% left_join(object$lambdas, by = key_colnames(new_data, exclude = "time_value"))
172179
for (col in col_names) {
173180
new_data <- new_data %>%
174181
rowwise() %>%
175182
mutate(!!col := yj_transform(!!sym(col), !!sym(paste0("lambda_", col))))
183+
# mutate(across(col_names, ~ yj_transform(.x, !!sym(paste0("lambda_", .x)))))
176184
}
177185
new_data %>%
178186
select(-starts_with("lambda_")) %>%
@@ -260,11 +268,7 @@ yj_obj <- function(lam, dat, ind_neg, const) {
260268
#' @keywords internal
261269
#' @rdname recipes-internal
262270
#' @export
263-
estimate_yj <- function(dat,
264-
limits = c(-5, 5),
265-
num_unique = 5,
266-
na_rm = TRUE,
267-
call = caller_env(2)) {
271+
estimate_yj <- function(dat, limits = c(-5, 5), num_unique = 5, na_rm = TRUE, call = caller_env(2)) {
268272
na_rows <- which(is.na(dat))
269273
if (length(na_rows) > 0) {
270274
if (na_rm) {
@@ -305,7 +309,7 @@ estimate_yj <- function(dat,
305309
lam
306310
}
307311

308-
309-
#' @rdname tidy.recipe
310-
#' @export
311-
tidy.step_YeoJohnson2 <- tidy.step_BoxCox2
312+
# #
313+
# #' @rdname tidy.recipe
314+
# #' @export
315+
# tidy.step_YeoJohnson2 <- tidy.step_BoxCox2

renv/activate.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ local({
135135

136136
# R help links
137137
pattern <- "`\\?(renv::(?:[^`])+)`"
138-
replacement <- "`\033]8;;x-r-help:\\1\a?\\1\033]8;;\a`"
138+
replacement <- "`\033]8;;ide:help:\\1\a?\\1\033]8;;\a`"
139139
text <- gsub(pattern, replacement, text, perl = TRUE)
140140

141141
# runnable code
142142
pattern <- "`(renv::(?:[^`])+)`"
143-
replacement <- "`\033]8;;x-r-run:\\1\a\\1\033]8;;\a`"
143+
replacement <- "`\033]8;;ide:run:\\1\a\\1\033]8;;\a`"
144144
text <- gsub(pattern, replacement, text, perl = TRUE)
145145

146146
# return ansified text

tests/testthat/test-daily-weekly-archive.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
# Works correctly if you have exactly one version where the previous Friday data
44
# is the latest so it is ignored and the week before THAT is summed (10-27 to

tests/testthat/test-data-whitening.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22
real_ex <- epidatasets::covid_case_death_rates %>%
33
as_tibble() %>%
44
mutate(source = "same") %>%

tests/testthat/test-forecaster-utils.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
test_that("sanitize_args_predictors_trainer", {
44
epi_data <- epidatasets::covid_case_death_rates

tests/testthat/test-forecasters-basics.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22
testthat::local_edition(3)
33
# TODO better way to do this than copypasta
44
forecasters <- list(

tests/testthat/test-forecasters-data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
testthat::skip("Optional, long-running tests skipped.")
44

tests/testthat/test-latency_adjusting.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
test_that("extend_ahead", {
44
# testing that POSIXct converts correctly (as well as basic types)

tests/testthat/test-step-training-window.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
data <- tribble(
44
~geo_value, ~time_value, ~version, ~value,

tests/testthat/test-transforms.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
n_days <- 20
44
removed_date <- 10

tests/testthat/test-yeo-johnson.R

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,58 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

3-
data <- tribble(
4-
~geo_value, ~time_value, ~version, ~value1,
5-
"us", "2024-11-08", "2024-11-13", 1,
6-
"us", "2024-11-07", "2024-11-13", 2,
7-
"us", "2024-11-06", "2024-11-13", 3,
8-
"us", "2024-11-05", "2024-11-13", 4,
9-
"us", "2024-11-04", "2024-11-13", 5,
10-
"us", "2024-11-03", "2024-11-13", 6,
11-
"us", "2024-11-02", "2024-11-13", 7,
12-
"us", "2024-11-01", "2024-11-13", 8,
13-
"us", "2024-10-31", "2024-11-13", 9,
14-
"us", "2024-10-30", "2024-11-13", 10,
15-
"us", "2024-10-29", "2024-11-13", 11,
16-
"us", "2024-10-28", "2024-11-13", 12,
17-
"us", "2024-10-27", "2024-11-13", 13
18-
) %>%
19-
mutate(value2 = value1 * 11) %>%
20-
bind_rows((.) %>% mutate(geo_value = "ca", value1 = value1 * 3 + 1, value2 = value2 + 50)) %>%
21-
mutate(time_value = as.Date(time_value), version = as.Date(version)) %>%
22-
as_epi_df()
233

24-
r <- epi_recipe(data) %>%
25-
step_YeoJohnson2(value1, value2) %>%
26-
prep(data)
27-
r
28-
r$steps[[1]]$lambdas
29-
outcome <- r %>% bake(data)
30-
31-
httpgd::hgd()
32-
data %>%
33-
pivot_longer(c(value1, value2), names_to = "variable", values_to = "value") %>%
34-
ggplot(aes(time_value, value, color = variable)) +
35-
geom_line() +
36-
geom_line(
37-
data = outcome %>% pivot_longer(c(value1, value2), names_to = "variable", values_to = "value"),
38-
aes(time_value, value, color = variable),
39-
) +
40-
facet_wrap(~geo_value, scales = "free_y") +
41-
theme_minimal() +
42-
labs(title = "Yeo-Johnson transformation", x = "Time", y = "Value")
4+
# Real data test
5+
Sys.setenv(TAR_PROJECT = "flu_hosp_explore")
436

447

8+
# Transform with Yeo-Johnson
459
data <- tar_read(joined_archive_data) %>%
46-
epix_as_of(as.Date("2023-11-01")) %>%
47-
filter(source == "nhsn") %>%
48-
rename(value = hhs)
49-
r <- epi_recipe(data) %>%
50-
step_YeoJohnson2(value) %>%
51-
prep(data)
10+
epix_as_of(as.Date("2023-11-01"))
11+
state_geo_values <- data %>% filter(source == "nhsn") %>% pull(geo_value) %>% unique()
12+
filtered_data <- data %>%
13+
filter(geo_value %in% state_geo_values) %>%
14+
select(geo_value, source, time_value, hhs)
15+
r <- epi_recipe(filtered_data) %>%
16+
step_YeoJohnson2(hhs) %>%
17+
prep(filtered_data)
5218
r
19+
# Inspect the lambda values (a few states have default lambda = 0.25, because
20+
# their lambda estimate came back NA, so prep() filled in na_lambda_fill)
5321
r$steps[[1]]$lambdas %>% print(n = 55)
54-
outcome <- r %>% bake(data)
22+
out1 <- r %>% bake(filtered_data)
5523

56-
httpgd::hgd()
57-
data %>%
58-
ggplot(aes(time_value, value)) +
59-
geom_line(color = "blue") +
60-
geom_line(data = outcome, aes(time_value, value), color = "green") +
24+
# Transform with manual whitening (quarter root scaling)
25+
# learned_params <- calculate_whitening_params(filtered_data, "hhs", scale_method = "none", center_method = "none", nonlin_method = "quart_root")
26+
out2 <- filtered_data %>%
27+
mutate(hhs = (hhs + 0.01)^(1 / 4))
28+
29+
out1 %>%
30+
left_join(out2, by = c("geo_value", "source", "time_value")) %>%
31+
mutate(hhs_diff = hhs.x - hhs.y) %>%
32+
ggplot(aes(time_value, hhs_diff)) +
33+
geom_line() +
6134
facet_wrap(~geo_value, scales = "free_y") +
6235
theme_minimal() +
63-
labs(title = "Yeo-Johnson transformation", x = "Time", y = "Value")
36+
labs(title = "Yeo-Johnson transformation", x = "Time", y = "HHS")
37+
38+
# Plot one geo on a log scale: raw data (blue), Yeo-Johnson output (green), quarter-root output (red)
39+
geo_filter <- "ca"
40+
filtered_data %>%
41+
filter(geo_value == geo_filter, source == "nhsn") %>%
42+
mutate(hhs = log(hhs)) %>%
43+
ggplot(aes(time_value, hhs)) +
44+
geom_line(color = "blue") +
45+
geom_line(data = out1 %>% filter(geo_value == geo_filter, source == "nhsn") %>% mutate(hhs = log(hhs)), aes(time_value, hhs), color = "green") +
46+
geom_line(data = out2 %>% filter(geo_value == geo_filter, source == "nhsn") %>% mutate(hhs = log(hhs)), aes(time_value, hhs), color = "red") +
47+
theme_minimal() +
48+
labs(title = "Yeo-Johnson transformation", x = "Time", y = "HHS")
6449

6550

6651
# TODO: Test this.
6752
## Layer Yeo-Johnson2
6853
postproc <- frosting() %>%
6954
layer_YeoJohnson2()
7055

71-
wf <- epi_workflow(r, linear_reg()) %>%
56+
wf <- epi_workflow(r) %>%
7257
fit(data) %>%
7358
add_frosting(postproc)

0 commit comments

Comments
 (0)