Skip to content

Commit 05e688f

Browse files
authored
Merge branch 'cmu-delphi:main' into main
2 parents d3a8d9c + e8cfd8e commit 05e688f

35 files changed

+2063
-48
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
^_pkgdown\.yml$
99
^docs$
1010
^pkgdown$
11+
^musings$

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
.Ruserdata
55
docs
66
inst/doc
7+
.DS_Store

DESCRIPTION

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,44 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
33
Version: 0.0.0.9000
4-
Authors@R:
5-
c(
6-
person(given = "Jacob",
7-
family = "Bien",
8-
role = "aut"),
9-
person(given = "Daniel",
10-
family = "McDonald",
11-
role = "aut"),
12-
person(given = "Ryan",
13-
family = "Tibshirani",
14-
role = c("aut", "cre"),
15-
email = "[email protected]"))
4+
Authors@R: c(
5+
person("Jacob", "Bien", role = "aut"),
6+
person("Daniel", "McDonald", role = "aut"),
7+
person("Ryan", "Tibshirani", , "[email protected]", role = c("aut", "cre"))
8+
)
169
Description: What the package does (one paragraph).
1710
License: MIT + file LICENSE
18-
Encoding: UTF-8
19-
Roxygen: list(markdown = TRUE)
20-
RoxygenNote: 7.1.2
21-
Remotes:
22-
cmu-delphi/epiprocess#58
11+
URL: https://github.com/cmu-delphi/epipredict/,
12+
https://cmu-delphi.github.io/epiprocess
2313
Imports:
14+
assertthat,
15+
cli,
2416
dplyr,
17+
glue,
2518
magrittr,
26-
tibble,
27-
rlang,
2819
purrr,
29-
cli,
20+
recipes,
21+
rlang,
3022
stats,
23+
tibble,
3124
tidyr,
32-
assertthat,
33-
tidyselect
25+
tidyselect,
26+
tensr
3427
Suggests:
35-
epiprocess,
36-
data.table,
3728
covidcast,
29+
data.table,
30+
epiprocess,
3831
ggplot2,
3932
knitr,
4033
lubridate,
4134
RcppRoll,
4235
rmarkdown,
4336
testthat (>= 3.0.0)
37+
VignetteBuilder:
38+
knitr
39+
Remotes:
40+
dajmcdon/epiprocess
4441
Config/testthat/edition: 3
45-
URL: https://github.com/cmu-delphi/epipredict/,
46-
https://cmu-delphi.github.io/epiprocess
47-
VignetteBuilder: knitr
42+
Encoding: UTF-8
43+
Roxygen: list(markdown = TRUE)
44+
RoxygenNote: 7.2.0

NAMESPACE

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,41 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(bake,step_epi_ahead)
4+
S3method(bake,step_epi_lag)
5+
S3method(epi_keys,default)
6+
S3method(epi_keys,epi_df)
7+
S3method(epi_keys,recipe)
8+
S3method(epi_recipe,default)
9+
S3method(epi_recipe,epi_df)
10+
S3method(epi_recipe,formula)
11+
S3method(prep,step_epi_ahead)
12+
S3method(prep,step_epi_lag)
13+
S3method(print,step_epi_ahead)
14+
S3method(print,step_epi_lag)
315
export("%>%")
416
export(arx_args_list)
517
export(arx_forecaster)
618
export(create_lags_and_leads)
719
export(df_mat_mul)
20+
export(epi_keys)
21+
export(epi_recipe)
822
export(get_precision)
923
export(grab_names)
24+
export(knn_iteraive_ar_args_list)
25+
export(knn_iteraive_ar_forecaster)
26+
export(knnarx_args_list)
27+
export(knnarx_forecaster)
1028
export(smooth_arx_args_list)
1129
export(smooth_arx_forecaster)
30+
export(step_epi_ahead)
31+
export(step_epi_lag)
32+
import(recipes)
1233
importFrom(magrittr,"%>%")
1334
importFrom(rlang,"!!")
1435
importFrom(rlang,":=")
36+
importFrom(stats,as.formula)
1537
importFrom(stats,lm)
38+
importFrom(stats,model.frame)
1639
importFrom(stats,poly)
1740
importFrom(stats,predict)
1841
importFrom(stats,quantile)

R/arx_forecaster.R

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ arx_forecaster <- function(x, y, key_vars, time_value,
3737
if (intercept) dat$x0 <- 1
3838

3939
obj <- stats::lm(
40-
y1 ~ . + 0, data = dat %>% dplyr::select(starts_with(c("x","y"))))
40+
y1 ~ . + 0,
41+
data = dat %>% dplyr::select(starts_with(c("x", "y")))
42+
)
4143

4244
point <- make_predictions(obj, dat, time_value, keys)
4345

@@ -50,8 +52,9 @@ arx_forecaster <- function(x, y, key_vars, time_value,
5052
# Harder case requires handling failures of 1 and or 2, neither implemented
5153
# 1. different quantiles by key, need to bind the keys, then group_modify
5254
# 2 fails. need to bind the keys, grab, y and yhat, subtract
53-
if (nonneg)
55+
if (nonneg) {
5456
q <- dplyr::mutate(q, dplyr::across(dplyr::everything(), ~ pmax(.x, 0)))
57+
}
5558

5659
return(
5760
dplyr::bind_cols(distinct_keys, q) %>%
@@ -80,12 +83,11 @@ arx_forecaster <- function(x, y, key_vars, time_value,
8083
#' arx_args_list()
8184
#' arx_args_list(symmetrize = FALSE)
8285
#' arx_args_list(levels = c(.1, .3, .7, .9), min_train_window = 120)
83-
arx_args_list <- function(
84-
lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
85-
levels = c(0.05, 0.95), intercept = TRUE,
86-
symmetrize = TRUE,
87-
nonneg = TRUE,
88-
quantile_by_key = FALSE) {
86+
arx_args_list <- function(lags = c(0, 7, 14), ahead = 7, min_train_window = 20,
87+
levels = c(0.05, 0.95), intercept = TRUE,
88+
symmetrize = TRUE,
89+
nonneg = TRUE,
90+
quantile_by_key = FALSE) {
8991

9092
# error checking if lags is a list
9193
.lags <- lags
@@ -94,13 +96,15 @@ arx_args_list <- function(
9496
arg_is_scalar(ahead, min_train_window)
9597
arg_is_nonneg_int(ahead, min_train_window, lags)
9698
arg_is_lgl(intercept, symmetrize, nonneg)
97-
arg_is_probabilities(levels, allow_null=TRUE)
99+
arg_is_probabilities(levels, allow_null = TRUE)
98100

99101
max_lags <- max(lags)
100102

101-
list(lags = .lags, ahead = as.integer(ahead),
102-
min_train_window = min_train_window,
103-
levels = levels, intercept = intercept,
104-
symmetrize = symmetrize, nonneg = nonneg,
105-
max_lags = max_lags)
106-
}
103+
list(
104+
lags = .lags, ahead = as.integer(ahead),
105+
min_train_window = min_train_window,
106+
levels = levels, intercept = intercept,
107+
symmetrize = symmetrize, nonneg = nonneg,
108+
max_lags = max_lags
109+
)
110+
}

R/compat-recipes.R

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# These are copied from `recipes` where they are unexported
2+
3+
fun_calls <- function (f) {
4+
if (is.function(f)) fun_calls(body(f))
5+
else if (rlang::is_quosure(f)) fun_calls(rlang::quo_get_expr(f))
6+
else if (is.call(f)) {
7+
fname <- as.character(f[[1]])
8+
if (identical(fname, ".Internal"))
9+
return(fname)
10+
unique(c(fname, unlist(lapply(f[-1], fun_calls), use.names = FALSE)))
11+
}
12+
}
13+
14+
inline_check <- function(x) {
15+
funs <- fun_calls(x)
16+
funs <- funs[!(funs %in% c("~", "+", "-"))]
17+
if (length(funs) > 0) {
18+
rlang::abort(paste0(
19+
"No in-line functions should be used here; ",
20+
"use steps to define baking actions."
21+
))
22+
}
23+
invisible(x)
24+
}
25+
26+
#' @importFrom stats as.formula
27+
get_lhs_vars <- function(formula, data) {
28+
if (!rlang::is_formula(formula)) {
29+
formula <- as.formula(formula)
30+
}
31+
## Want to make sure that multiple outcomes can be expressed as
32+
## additions with no cbind business and that `.` works too (maybe)
33+
new_formula <- rlang::new_formula(lhs = NULL, rhs = rlang::f_lhs(formula))
34+
get_rhs_vars(new_formula, data)
35+
}
36+
37+
#' @importFrom stats model.frame
38+
get_rhs_vars <- function(formula, data, no_lhs = FALSE) {
39+
if (!rlang::is_formula(formula)) {
40+
formula <- as.formula(formula)
41+
}
42+
if (no_lhs) {
43+
formula <- rlang::new_formula(lhs = NULL, rhs = rlang::f_rhs(formula))
44+
}
45+
46+
## This will need a lot of work to account for cases with `.`
47+
## or embedded functions like `Sepal.Length + poly(Sepal.Width)`.
48+
## or should it? what about Y ~ log(x)?
49+
## Answer: when called from `form2args`, the function
50+
## `inline_check` stops when in-line functions are used.
51+
data_info <- attr(model.frame(formula, data[1, ]), "terms")
52+
response_info <- attr(data_info, "response")
53+
predictor_names <- names(attr(data_info, "dataClasses"))
54+
if (length(response_info) > 0 && all(response_info > 0)) {
55+
predictor_names <- predictor_names[-response_info]
56+
}
57+
predictor_names
58+
}
59+
60+
## Buckets variables into discrete, mutally exclusive types
61+
get_types <- function(x) {
62+
var_types <-
63+
c(
64+
character = "nominal",
65+
factor = "nominal",
66+
ordered = "nominal",
67+
integer = "numeric",
68+
numeric = "numeric",
69+
double = "numeric",
70+
Surv = "censored",
71+
logical = "logical",
72+
Date = "date",
73+
POSIXct = "date",
74+
list = "list",
75+
textrecipes_tokenlist = "tokenlist"
76+
)
77+
78+
classes <- lapply(x, class)
79+
res <- lapply(
80+
classes,
81+
function(x, types) {
82+
in_types <- x %in% names(types)
83+
if (sum(in_types) > 0) {
84+
# not sure what to do with multiple matches; right now
85+
## pick the first match which favors "factor" over "ordered"
86+
out <- unname(types[min(which(names(types) %in% x))])
87+
} else {
88+
out <- "other"
89+
}
90+
out
91+
},
92+
types = var_types
93+
)
94+
res <- unlist(res)
95+
tibble(variable = names(res), type = unname(res))
96+
}

0 commit comments

Comments
 (0)