Skip to content

Commit ce3b19d

Browse files
committed
feat: add yeo-johnson
1 parent 7cd135f commit ce3b19d

File tree

5 files changed

+824
-1
lines changed

5 files changed

+824
-1
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#' Unormalizing transformation
2+
#'
3+
#' Will undo a step_epi_YeoJohnson transformation.
4+
#'
5+
#' @param frosting a `frosting` postprocessor. The layer will be added to the
6+
#' sequence of operations for this frosting.
7+
#' @param ... One or more selector functions to scale variables
8+
#' for this step. See [recipes::selections()] for more details.
9+
#' @param df a data frame that contains the population data to be used for
10+
#' inverting the existing scaling.
11+
#' @param by A (possibly named) character vector of variables to join by.
12+
#' @param id a random id string
13+
#'
14+
#' @return an updated `frosting` postprocessor
15+
#' @export
16+
#' @examples
17+
#' library(dplyr)
18+
#' jhu <- epidatasets::cases_deaths_subset %>%
19+
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
20+
#' select(geo_value, time_value, cases)
21+
#'
22+
#' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000))
23+
#'
24+
#' r <- epi_recipe(jhu) %>%
25+
#' step_epi_YeoJohnson(
26+
#' df = pop_data,
27+
#' df_pop_col = "value",
28+
#' by = c("geo_value" = "states"),
29+
#' cases, suffix = "_scaled"
30+
#' ) %>%
31+
#' step_epi_lag(cases_scaled, lag = c(0, 7, 14)) %>%
32+
#' step_epi_ahead(cases_scaled, ahead = 7, role = "outcome") %>%
33+
#' step_epi_naomit()
34+
#'
35+
#' f <- frosting() %>%
36+
#' layer_predict() %>%
37+
#' layer_threshold(.pred) %>%
38+
#' layer_naomit(.pred) %>%
39+
#' layer_epi_YeoJohnson(.pred,
40+
#' df = pop_data,
41+
#' by = c("geo_value" = "states"),
42+
#' df_pop_col = "value"
43+
#' )
44+
#'
45+
#' wf <- epi_workflow(r, linear_reg()) %>%
46+
#' fit(jhu) %>%
47+
#' add_frosting(f)
48+
#'
49+
#' forecast(wf)
50+
layer_epi_YeoJohnson <- function(frosting, ..., lambdas = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
51+
checkmate::assert_tibble(lambdas, min.rows = 1, null.ok = TRUE)
52+
53+
add_layer(
54+
frosting,
55+
layer_epi_YeoJohnson_new(
56+
lambdas = lambdas,
57+
by = by,
58+
terms = dplyr::enquos(...),
59+
id = id
60+
)
61+
)
62+
}
63+
64+
layer_epi_YeoJohnson_new <- function(lambdas, by, terms, id) {
65+
layer("epi_YeoJohnson", lambdas = lambdas, by = by, terms = terms, id = id)
66+
}
67+
68+
#' @export
69+
#' @importFrom workflows extract_preprocessor
70+
slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data, ...) {
71+
rlang::check_dots_empty()
72+
73+
74+
# Get the lambdas from the layer or from the workflow.
75+
lambdas <- object$lambdas %||% get_lambdas_in_layer(workflow)
76+
77+
# If the by is not specified, try to infer it from the lambdas.
78+
if (is.null(object$by)) {
79+
# Assume `layer_predict` has calculated the prediction keys and other
80+
# layers don't change the prediction key colnames:
81+
prediction_key_colnames <- names(components$keys)
82+
lhs_potential_keys <- prediction_key_colnames
83+
rhs_potential_keys <- colnames(select(lambdas, -starts_with("lambda_")))
84+
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
85+
suggested_min_keys <- setdiff(lhs_potential_keys, "time_value")
86+
if (!all(suggested_min_keys %in% object$by)) {
87+
cli_warn(
88+
c(
89+
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
90+
but {?wasn't/weren't} found in the population `df`.",
91+
"i" = "Defaulting to join by {object$by}",
92+
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
93+
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
94+
">" = "Manually specify `by =` to silence"
95+
),
96+
class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys"
97+
)
98+
}
99+
}
100+
101+
# Establish the join columns.
102+
object$by <- object$by %||%
103+
intersect(
104+
epipredict:::epi_keys_only(components$predictions),
105+
colnames(select(lambdas, -starts_with("lambda_")))
106+
)
107+
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
108+
hardhat::validate_column_names(components$predictions, joinby$x)
109+
hardhat::validate_column_names(lambdas, joinby$y)
110+
111+
# TODO: We don't do multiple outcomes, do we? Assume not for now.
112+
# Get the columns to transform. In components$predictions, the output is
113+
# .pred, so col_names should just be ".pred".
114+
exprs <- rlang::expr(c(!!!object$terms))
115+
pos <- tidyselect::eval_select(exprs, components$predictions)
116+
col_names <- names(pos)
117+
118+
# Get the outcome. `outcomes` is a vector of objects like ahead_1_cases,
119+
# ahead_7_cases, etc. We want to extract the cases part.
120+
outcome_col <- names(components$mold$outcomes) %>%
121+
stringr::str_extract("(?<=_)[^_]+$") %>%
122+
unique() %>%
123+
extract(1)
124+
125+
# Join the lambdas.
126+
components$predictions <- inner_join(
127+
components$predictions,
128+
lambdas,
129+
by = object$by,
130+
relationship = "many-to-one",
131+
unmatched = c("error", "drop")
132+
)
133+
# For every column, we need to use the appropriate lambda column, which differs per row.
134+
# Note that yj_inverse() is vectorized.
135+
for (col in col_names) {
136+
components$predictions <- components$predictions %>%
137+
rowwise() %>%
138+
mutate(!!col := yj_inverse(!!sym(col), !!sym(paste0("lambda_", outcome_col))))
139+
}
140+
# Remove the lambda columns.
141+
components$predictions <- components$predictions %>%
142+
select(-any_of(starts_with("lambda_")))
143+
components
144+
}
145+
146+
#' @export
147+
print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30), ...) {
148+
title <- "Yeo-Johnson transformation (see `lambdas` object for values) on "
149+
epipredict:::print_layer(x$terms, title = title, width = width)
150+
}
151+
152+
#' Inverse Yeo-Johnson transformation
153+
#'
154+
#' Inverse of `yj_transform` in step_yeo_johnson.R.
155+
#'
156+
#' @keywords internal
157+
yj_inverse <- function(x, lambda, eps = 0.001) {
158+
if (is.na(lambda)) {
159+
return(x)
160+
}
161+
if (!inherits(x, "tbl_df") || is.data.frame(x)) {
162+
x <- unlist(x, use.names = FALSE)
163+
} else {
164+
if (!is.vector(x)) {
165+
x <- as.vector(x)
166+
}
167+
}
168+
169+
dat_neg <- x < 0
170+
ind_neg <- list(is = which(dat_neg), not = which(!dat_neg))
171+
not_neg <- ind_neg[["not"]]
172+
is_neg <- ind_neg[["is"]]
173+
174+
nn_inv_trans <- function(x, lambda) {
175+
if (abs(lambda) < eps) {
176+
# log(x + 1)
177+
exp(x) - 1
178+
} else {
179+
# ((x + 1)^lambda - 1) / lambda
180+
(lambda * x + 1)^(1 / lambda) - 1
181+
}
182+
}
183+
184+
ng_inv_trans <- function(x, lambda) {
185+
if (abs(lambda - 2) < eps) {
186+
# -log(-x + 1)
187+
-(exp(-x) - 1)
188+
} else {
189+
# -((-x + 1)^(2 - lambda) - 1) / (2 - lambda)
190+
-(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
191+
}
192+
}
193+
194+
if (length(not_neg) > 0) {
195+
x[not_neg] <- nn_inv_trans(x[not_neg], lambda)
196+
}
197+
198+
if (length(is_neg) > 0) {
199+
x[is_neg] <- ng_inv_trans(x[is_neg], lambda)
200+
}
201+
x
202+
}
203+
204+
get_lambdas_in_layer <- function(workflow) {
205+
this_recipe <- hardhat::extract_recipe(workflow)
206+
if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
207+
cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
208+
}
209+
for (step in this_recipe$steps) {
210+
if (inherits(step, "step_epi_YeoJohnson")) {
211+
lambdas <- step$lambdas
212+
break
213+
}
214+
}
215+
lambdas
216+
}

0 commit comments

Comments
 (0)