diff --git a/R/get_labels.R b/R/get_labels.R index 1533f0b00..f46a30ddd 100644 --- a/R/get_labels.R +++ b/R/get_labels.R @@ -1,16 +1,20 @@ -get_labels <- function(x, idx = NULL, by = NULL, wrap_parens = FALSE) { +get_labels <- function(x, idx = NULL, by = NULL, wrap_parens = FALSE, hypothesis_by = NULL) { if (!is.data.frame(x) && !is.vector(x)) { return(NULL) } if (is.data.frame(x)) { + x <- data.table::data.table(x) + # Identify relevant columns lab_cols <- grep("^term$|^group$|^contrast$|^contrast_|^value$|^by$", names(x), value = TRUE) if (isTRUE(checkmate::check_character(by))) { lab_cols <- unique(c(lab_cols, by)) } - # Filter columns with more than one unique value + lab_cols <- setdiff(lab_cols, hypothesis_by) + + # Filter out columns with more than one unique value, within `hypothesis_by` lab_cols <- Filter(function(col) length(unique(x[[col]])) > 1, lab_cols) if (length(lab_cols) == 0) { @@ -18,11 +22,23 @@ get_labels <- function(x, idx = NULL, by = NULL, wrap_parens = FALSE) { labels <- paste0("b", seq_len(nrow(x))) } else { # Create labels by pasting unique combinations of selected columns - lab_df <- data.frame(x)[, lab_cols, drop = FALSE] + lab_df <- x[, ..lab_cols] labels <- apply(lab_df, 1, paste, collapse = " ") - # Handle duplicate labels - if (anyDuplicated(labels) > 0) { + # duplicated labels (within groups) revert to b1, b2, ... + uniq <- TRUE + if (isTRUE(checkmate::check_character(hypothesis_by, min.len = 1))) { + uniq <- x[, ..hypothesis_by] + uniq[, marginaleffects_unique_labels := labels] + uniq <- uniq[, + .(marginaleffects_uniq_labels = anyDuplicated(marginaleffects_unique_labels) > 0), + by = hypothesis_by] + uniq <- !any(uniq$V1) + } else if (anyDuplicated(labels) > 0) { + uniq <- FALSE + } + + if (!isTRUE(uniq)) { labels <- paste0("b", seq_len(nrow(x))) } } diff --git a/R/hypothesis_formula.R b/R/hypothesis_formula.R index 2b6d5fcf0..d899dfbbe 100644 --- a/R/hypothesis_formula.R +++ b/R/hypothesis_formula.R @@ -149,12 +149,18 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) { } } - labels <- get_labels(x, by = by) form <- sanitize_hypothesis_formula(hypothesis) group <- form$group + if (isTRUE(checkmate::check_character(by))) { + bycols <- setdiff(by, group) + } else { + bycols <- by + } + labels <- get_labels(x, by = bycols, hypothesis_by = group) + if (isTRUE(form$lhs == "arbitrary_function")) { fun_comparison <- sprintf("function(x) %s", form$rhs) fun_label <- sprintf("function(x) suppressWarnings(names(%s))", form$rhs) @@ -181,8 +187,7 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) { groupval <- list() if (length(col_x) > 0) { groupval <- c(groupval, list(x[, ..col_x, drop = FALSE])) - } - if (length(col_newdata) > 0) { + } else if (length(col_newdata) > 0) { groupval <- c(groupval, list(newdata[, ..col_newdata, drop = FALSE])) } groupval <- do.call(cbind, Filter(is.data.frame, groupval)) @@ -203,6 +208,7 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) { lab <- function(x) suppressWarnings(names(fun_comparison(x))) lab <- tryCatch(combined[, lapply(.SD, lab), keyby = groupval], error = function(e) NULL) + if (inherits(lab, "data.frame") && nrow(lab) == nrow(estimates)) { data.table::setnames(lab, old = "estimate", "hypothesis") cols <- setdiff(colnames(lab), colnames(estimates)) @@ -222,10 +228,10 @@ hypothesis_formula <- function(x, hypothesis, newdata, by) { out <- estimates # Sometimes we get duplicated `term` columns - idx <- grep("term", colnames(out), value = TRUE) + # drop all instances after the first + idx <- grep("^term$", colnames(out)) if (length(idx) > 1) { idx <- idx[2:length(idx)] - # drop all instances after the first out <- out[, -..idx] } diff --git a/inst/tinytest/test-hypothesis.R b/inst/tinytest/test-hypothesis.R index 8273beaa3..72663282f 100644 --- a/inst/tinytest/test-hypothesis.R +++ b/inst/tinytest/test-hypothesis.R @@ -342,8 +342,29 @@ mod <- lm(mpg ~ factor(cyl) + factor(gear), data = mtcars) cmp <- avg_comparisons(mod, hypothesis = ~ pairwise | term) expect_inherits(cmp, "comparisons") expect_equal(nrow(cmp), 2) -expect_true("(cyl 8 - 4) - (cyl 6 - 4)" %in% cmp$hypothesis) - +expect_true("(8 - 4) - (6 - 4)" %in% cmp$hypothesis) +expect_true("(5 - 3) - (4 - 3)" %in% cmp$hypothesis) +expect_equal(cmp$term, c("cyl", "gear")) + + +# Issue #1373 +dat <- get_dataset("thornton") +dat$incentive <- as.factor(dat$incentive) +dat$hiv2004 <- as.factor(dat$hiv2004) +mod <- glm( + outcome ~ incentive * agecat, + data = dat, + family = binomial +) +p <- avg_predictions( + mod, + by = c("incentive", "agecat"), + newdata = datagrid(by = c("incentive", "agecat")), + hypothesis = ~ pairwise | agecat +) +expect_inherits(p, "predictions") +expect_equal(nrow(p), 3) +expect_false(any(grepl("18", p$hypothesis))) # no duplicate label