Skip to content

Commit

Permalink
issue #1373
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Jan 31, 2025
1 parent 263f360 commit d281782
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 12 deletions.
26 changes: 21 additions & 5 deletions R/get_labels.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,44 @@
get_labels <- function(x, idx = NULL, by = NULL, wrap_parens = FALSE) {
get_labels <- function(x, idx = NULL, by = NULL, wrap_parens = FALSE, hypothesis_by = NULL) {
if (!is.data.frame(x) && !is.vector(x)) {
return(NULL)
}

if (is.data.frame(x)) {
x <- data.table::data.table(x)

# Identify relevant columns
lab_cols <- grep("^term$|^group$|^contrast$|^contrast_|^value$|^by$", names(x), value = TRUE)
if (isTRUE(checkmate::check_character(by))) {
lab_cols <- unique(c(lab_cols, by))
}

# Filter columns with more than one unique value
lab_cols <- setdiff(lab_cols, hypothesis_by)

# Keep only columns with more than one unique value (`hypothesis_by` columns are excluded from the labels)
lab_cols <- Filter(function(col) length(unique(x[[col]])) > 1, lab_cols)

if (length(lab_cols) == 0) {
# Default labels if no meaningful columns found
labels <- paste0("b", seq_len(nrow(x)))
} else {
# Create labels by pasting unique combinations of selected columns
lab_df <- data.frame(x)[, lab_cols, drop = FALSE]
lab_df <- x[, ..lab_cols]
labels <- apply(lab_df, 1, paste, collapse = " ")

# Handle duplicate labels
if (anyDuplicated(labels) > 0) {
# duplicated labels (within groups) revert to b1, b2, ...
uniq <- TRUE
if (isTRUE(checkmate::check_character(hypothesis_by, min.len = 1))) {
uniq <- x[, ..hypothesis_by]
uniq[, marginaleffects_unique_labels := labels]
uniq <- uniq[,
.(marginaleffects_uniq_labels = anyDuplicated(marginaleffects_unique_labels) > 0),
by = hypothesis_by]
uniq <- !any(uniq$V1)
} else if (anyDuplicated(labels) > 0) {
uniq <- FALSE
}

if (!isTRUE(uniq)) {
labels <- paste0("b", seq_len(nrow(x)))
}
}
Expand Down
16 changes: 11 additions & 5 deletions R/hypothesis_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,18 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) {
}
}

labels <- get_labels(x, by = by)

form <- sanitize_hypothesis_formula(hypothesis)

group <- form$group

if (isTRUE(checkmate::check_character(by))) {
bycols <- setdiff(by, group)
} else {
bycols <- by
}
labels <- get_labels(x, by = bycols, hypothesis_by = group)

if (isTRUE(form$lhs == "arbitrary_function")) {
fun_comparison <- sprintf("function(x) %s", form$rhs)
fun_label <- sprintf("function(x) suppressWarnings(names(%s))", form$rhs)
Expand All @@ -181,8 +187,7 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) {
groupval <- list()
if (length(col_x) > 0) {
groupval <- c(groupval, list(x[, ..col_x, drop = FALSE]))
}
if (length(col_newdata) > 0) {
} else if (length(col_newdata) > 0) {
groupval <- c(groupval, list(newdata[, ..col_newdata, drop = FALSE]))
}
groupval <- do.call(cbind, Filter(is.data.frame, groupval))
Expand All @@ -203,6 +208,7 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) {

lab <- function(x) suppressWarnings(names(fun_comparison(x)))
lab <- tryCatch(combined[, lapply(.SD, lab), keyby = groupval], error = function(e) NULL)

if (inherits(lab, "data.frame") && nrow(lab) == nrow(estimates)) {
data.table::setnames(lab, old = "estimate", "hypothesis")
cols <- setdiff(colnames(lab), colnames(estimates))
Expand All @@ -222,10 +228,10 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) {
out <- estimates

# Sometimes we get duplicated `term` columns
idx <- grep("term", colnames(out), value = TRUE)
# drop all instances after the first
idx <- grep("^term$", colnames(out))
if (length(idx) > 1) {
idx <- idx[2:length(idx)]
# drop all instances after the first
out <- out[, -..idx]
}

Expand Down
25 changes: 23 additions & 2 deletions inst/tinytest/test-hypothesis.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,29 @@ mod <- lm(mpg ~ factor(cyl) + factor(gear), data = mtcars)
cmp <- avg_comparisons(mod, hypothesis = ~ pairwise | term)
expect_inherits(cmp, "comparisons")
expect_equal(nrow(cmp), 2)
expect_true("(cyl 8 - 4) - (cyl 6 - 4)" %in% cmp$hypothesis)

expect_true("(8 - 4) - (6 - 4)" %in% cmp$hypothesis)
expect_true("(5 - 3) - (4 - 3)" %in% cmp$hypothesis)
expect_equal(cmp$term, c("cyl", "gear"))


# Issue #1373
dat <- get_dataset("thornton")
dat$incentive <- as.factor(dat$incentive)
dat$hiv2004 <- as.factor(dat$hiv2004)
mod <- glm(
outcome ~ incentive * agecat,
data = dat,
family = binomial
)
p <- avg_predictions(
mod,
by = c("incentive", "agecat"),
newdata = datagrid(by = c("incentive", "agecat")),
hypothesis = ~ pairwise | agecat
)
expect_inherits(p, "predictions")
expect_equal(nrow(p), 3)
expect_false(any(grepl("18", p$hypothesis))) # no duplicate label



Expand Down

0 comments on commit d281782

Please sign in to comment.