Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.12.0.9005
Version: 3.12.0.9006
Authors@R: c(
person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre", "aut")),
person("Igor", "Skokan", , "igorskokan@meta.com", c("aut")),
Expand Down
3 changes: 2 additions & 1 deletion R/R/allocator.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ robyn_allocator <- function(robyn_object = NULL,

## set local variables, sort & prompt
# paid_media_spends <- InputCollect$paid_media_spends
paid_media_selected <- InputCollect$paid_media_selected
paid_media_selected <- if ("paid_media_selected" %in% names(InputCollect))
InputCollect$paid_media_selected else InputCollect$paid_media_spends
dep_var_type <- InputCollect$dep_var_type
if (is.null(select_model) && length(OutputCollect$allSolutions == 1)) {
select_model <- OutputCollect$allSolutions
Expand Down
26 changes: 26 additions & 0 deletions R/R/auxiliary.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,29 @@ baseline_vars <- function(InputCollect, baseline_level) {

# Calculate MSE
.mse_loss <- function(y, y_hat) mean((y - y_hat)^2)

# next_date(c("2021-01-01", "2021-02-01"))
# next_date(c("2021-01-01", "2021-01-08", "2021-01-15"))
# next_date(c(Sys.Date() - 1, Sys.Date()))
.next_date <- function(dates) {
dates <- as.Date(dates)
diffs <- diff(dates)
if (all(diffs == 1)) {
frequency <- "daily"
} else if (all(diffs == 7)) {
frequency <- "weekly"
} else if (all(format(dates[-length(dates)], "%Y-%m") != format(dates[-1], "%Y-%m"))) {
frequency <- "monthly"
} else {
warning(paste(
"Unable to determine frequency to calculate next logical date.",
"Returning last available date."))
return(as.Date(tail(dates, 1)))
}
next_date <- switch(
frequency,
"daily" = dates[length(dates)] + 1,
"weekly" = dates[length(dates)] + 7,
"monthly" = seq(dates[length(dates)], by = "1 month", length.out = 2)[2])
return(as.Date(next_date))
}
21 changes: 11 additions & 10 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -395,30 +395,31 @@ check_windows <- function(dt_input, date_var, all_media, window_start, window_en
refreshAddedStart <- window_start

if (is.null(window_end)) {
window_end <- max(dates_vec)
window_end <- .next_date(dates_vec) - 1
} else {
window_end <- as.Date(as.character(window_end), "%Y-%m-%d", origin = "1970-01-01")
if (is.na(window_end)) {
stop(sprintf("Input 'window_end' must have date format, i.e. '%s'", Sys.Date()))
} else if (window_end > max(dates_vec)) {
window_end <- max(dates_vec)
} else if (window_end > .next_date(dates_vec) - 1) {
window_end <- .next_date(dates_vec) - 1
message(paste(
"Input 'window_end' is larger than the latest date in input data.",
"It's automatically set to the latest date:", window_end
"Input 'window_end' is larger than the latest dates available in input data.",
"Automatically set to date:", window_end
))
} else if (window_end < window_start) {
window_end <- max(dates_vec)
window_end <- .next_date(dates_vec) - 1
message(paste(
"Input 'window_end' must be >= 'window_start.",
"It's automatically set to the latest date:", window_end
"Automatically set to date:", window_end
))
}
}

# Find closest date contained in input data
rollingWindowEndWhich <- which.min(abs(difftime(dates_vec, window_end, units = "days")))
if (!(window_end %in% dates_vec)) {
window_end <- dt_input[rollingWindowEndWhich, date_var][[1]]
message("Input 'window_end' is adapted to the closest date contained in input data: ", window_end)
if (!window_end %in% c(dates_vec, .next_date(dates_vec) - 1)) {
window_end <- .next_date(dt_input[seq(rollingWindowEndWhich), date_var][[1]]) - 1
message("Input 'window_end' is adapted to the closest available date from input data: ", window_end)
}
rollingWindowLength <- rollingWindowEndWhich - rollingWindowStartWhich + 1

Expand Down
2 changes: 1 addition & 1 deletion R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ init_msgs_run <- function(InputCollect, refresh, lambda_control = NULL, quiet =
nrow(InputCollect$dt_mod),
InputCollect$intervalType,
min(InputCollect$dt_mod$ds),
max(InputCollect$dt_mod$ds)
.next_date(InputCollect$dt_mod$ds) - 1
))
depth <- ifelse(
"refreshDepth" %in% names(InputCollect),
Expand Down
4 changes: 1 addition & 3 deletions R/R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,6 @@ robyn_onepagers <- function(
## 4. Response curves
dt_scurvePlot <- temp[[sid]]$plot4data$dt_scurvePlot
dt_scurvePlotMean <- temp[[sid]]$plot4data$dt_scurvePlotMean
paid_media_selected <- if ("paid_media_selected" %in% names(InputCollect))
InputCollect$paid_media_selected else InputCollect$paid_media_spends
trim_rate <- 1.3 # maybe enable as a parameter
if (trim_rate > 0) {
dt_scurvePlot <- dt_scurvePlot %>%
Expand All @@ -496,7 +494,7 @@ robyn_onepagers <- function(
filter(
.data$spend < max(dt_scurvePlotMean$mean_spend_adstocked) * trim_rate,
.data$response < max(dt_scurvePlotMean$mean_response) * trim_rate,
.data$channel %in% paid_media_selected
.data$channel %in% InputCollect$paid_media_vars
) %>%
left_join(
dt_scurvePlotMean[, c("channel", "mean_carryover")], "channel"
Expand Down