diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R index 6a62a058a99..6eddea55294 100644 --- a/r/R/dplyr-group-by.R +++ b/r/R/dplyr-group-by.R @@ -35,19 +35,20 @@ group_by.arrow_dplyr_query <- function(.data, .data <- as_adq(.data) expression_list <- expand_across(.data, quos(...)) - new_groups <- ensure_named_exprs(expression_list) + named_expression_list <- ensure_named_exprs(expression_list) - # set up group names and check which are new + # Set up group names gbp <- dplyr::group_by_prepare(.data, !!!expression_list, .add = .add) - existing_groups <- dplyr::group_vars(gbp$data) - new_group_names <- setdiff(gbp$group_names, existing_groups) - names(new_groups) <- new_group_names - - if (length(new_groups)) { - # Add them to the data - .data <- dplyr::mutate(.data, !!!new_groups) - } + # Add them all (or update them) to the .data via. In theory + # one could calculate which variables do or do not need to be added via a + # complex combination of the expression names, whether they are or are not + # a symbol, and/or whether they currently exist in .data. Instead, we just + # put them all into a mutate(). + existing_groups <- dplyr::groups(gbp$data) + names(existing_groups) <- dplyr::group_vars(gbp$data) + final_groups <- c(unclass(named_expression_list), unclass(existing_groups))[gbp$group_names] + .data <- dplyr::mutate(.data, !!!final_groups) .data$group_by_vars <- gbp$group_names .data$drop_empty_groups <- ifelse(length(gbp$group_names), .drop, dplyr::group_by_drop_default(.data)) diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 9f2869dd106..f17a0139161 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -52,6 +52,24 @@ test_that("group_by supports creating/renaming", { ) }) +test_that("group_by supports re-grouping by overlapping groups", { + compare_dplyr_binding( + .input %>% + group_by(chr, int) %>% + group_by(int, dbl) %>% + collect(), + tbl + ) + + compare_dplyr_binding( + .input %>% + group_by(chr, int) %>% + group_by(int, chr = "some new value") %>% + collect(), + tbl + ) +}) + test_that("ungroup", { compare_dplyr_binding( .input %>%