diff --git a/r/NEWS.md b/r/NEWS.md index 65c4e2205cc..a008088ff82 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -19,6 +19,13 @@ # arrow 3.0.0.9000 +## dplyr methods + +* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first. +* String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported. + +## Other improvements + * `value_counts()` to tabulate values in an `Array` or `ChunkedArray`, similar to `base::table()`. * `StructArray` objects gain data.frame-like methods, including `names()`, `$`, `[[`, and `dim()`. * RecordBatch columns can now be added, replaced, or removed by assigning (`<-`) with either `$` or `[[` diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 66694a97867..818d85c8580 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -30,7 +30,7 @@ "dplyr::", c( "select", "filter", "collect", "summarise", "group_by", "groups", - "group_vars", "ungroup", "mutate", "arrange", "rename", "pull" + "group_vars", "ungroup", "mutate", "transmute", "arrange", "rename", "pull" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 3d0f31ce8f3..790232c8e21 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -744,6 +744,10 @@ dataset___expr__field_ref <- function(name){ .Call(`_arrow_dataset___expr__field_ref`, name) } +dataset___expr__get_field_ref_name <- function(ref){ + .Call(`_arrow_dataset___expr__get_field_ref_name`, ref) +} + dataset___expr__scalar <- function(x){ .Call(`_arrow_dataset___expr__scalar`, x) } diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 45fc968ed08..ec6f85c4bab 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -69,6 +69,10 @@ Scanner$create <- function(dataset, batch_size = NULL, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # To handle mutate() on Table/RecordBatch, we need to collect(as_data_frame=FALSE) now + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + } return(Scanner$create( dataset$.data, dataset$selected_columns, @@ -152,6 +156,12 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) { ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject, public = list( Project = function(cols) { + # cols is either a character vector or a named list of Expressions + if (!is.character(cols)) { + # We don't yet support mutate() on datasets, so this is just a list + # of FieldRefs, and we need to back out the field names + cols <- get_field_names(cols) + } assert_is(cols, "character") dataset___ScannerBuilder__Project(self, cols) self diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index c5c92926715..5078bc3e371 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -62,8 +62,12 @@ write_dataset <- function(dataset, hive_style = TRUE, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # collect() to materialize any mutate/rename + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + } # We can select a subset of columns but we can't rename them - if (!all(dataset$selected_columns == names(dataset$selected_columns))) { + if (!all(get_field_names(dataset) == names(dataset$selected_columns))) { stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) } # partitioning vars need to be in the `select` schema diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 32713741b53..2bd8170a1cb 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -33,11 +33,11 @@ arrow_dplyr_query <- function(.data) { structure( list( .data = .data$clone(), - # selected_columns is a named character vector: - # * vector contents are the names of the columns in the data - # * vector names are the names they should be in the end (i.e. this + # selected_columns is a named list: + # * contents are references/expressions pointing to the data + # * names are the names they should be in the end (i.e. this # records any renaming) - selected_columns = set_names(names(.data)), + selected_columns = make_field_refs(names(.data), dataset = inherits(.data, "Dataset")), # filtered_rows will be an Expression filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) @@ -51,8 +51,15 @@ arrow_dplyr_query <- function(.data) { #' @export print.arrow_dplyr_query <- function(x, ...) { schm <- x$.data$schema - cols <- x$selected_columns - fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) + cols <- get_field_names(x) + # If cols are expressions, they won't be in the schema and will be "" in cols + fields <- map_chr(cols, function(name) { + if (nzchar(name)) { + schm$GetFieldByName(name)$ToString() + } else { + "expr" + } + }) # Strip off the field names as they are in the dataset and add the renamed ones fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") cat(class(x$.data)[1], " (query)\n", sep = "") @@ -73,6 +80,33 @@ print.arrow_dplyr_query <- function(x, ...) { invisible(x) } +get_field_names <- function(selected_cols) { + if (inherits(selected_cols, "arrow_dplyr_query")) { + selected_cols <- selected_cols$selected_columns + } + map_chr(selected_cols, function(x) { + if (inherits(x, "Expression")) { + out <- x$field_name + } else if (inherits(x, "array_expression")) { + out <- x$args$field_name + } else { + out <- NULL + } + # If x isn't some kind of field reference, out is NULL, + # but we always need to return a string + out %||% "" + }) +} + +make_field_refs <- function(field_names, dataset = TRUE) { + if (dataset) { + out <- lapply(field_names, Expression$field_ref) + } else { + out <- lapply(field_names, function(x) array_expression("array_ref", field_name = x)) + } + set_names(out, field_names) +} + # These are the names reflecting all select/rename, not what is in Arrow #' @export names.arrow_dplyr_query <- function(x) names(x$selected_columns) @@ -89,7 +123,7 @@ dim.arrow_dplyr_query <- function(x) { rows <- NA_integer_ } else { # Evaluate the filter expression to a BooleanArray and count - rows <- as.integer(sum(eval_array_expression(x$filtered_rows), na.rm = TRUE)) + rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data), na.rm = TRUE)) } c(rows, cols) } @@ -187,29 +221,8 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { } .data <- arrow_dplyr_query(.data) - # The filter() method works by evaluating the filters to generate Expressions - # with references to Arrays (if .data is Table/RecordBatch) or Fields (if - # .data is a Dataset). - dm <- filter_mask(.data) - filters <- lapply(filts, function (f) { - # This should yield an Expression as long as the filter function(s) are - # implemented in Arrow. - tryCatch(eval_tidy(f, dm), error = function(e) { - # Look for the cases where bad input was given, i.e. this would fail - # in regular dplyr anyway, and let those raise those as errors; - # else, for things not supported by Arrow return a "try-error", - # which we'll handle differently - msg <- conditionMessage(e) - # TODO: internationalization? - if (grepl("object '.*'.not.found", msg)) { - stop(e) - } - if (grepl('could not find function ".*"', msg)) { - stop(e) - } - invisible(structure(msg, class = "try-error", condition = e)) - }) - }) + # tidy-eval the filter expressions inside an Arrow data_mask + filters <- lapply(filts, arrow_eval, arrow_mask(.data)) bad_filters <- map_lgl(filters, ~inherits(., "try-error")) if (any(bad_filters)) { bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE) @@ -238,6 +251,30 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { } filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query +arrow_eval <- function (expr, mask) { + # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate Expressions + # with references to Arrays (if .data is Table/RecordBatch) or Fields (if + # .data is a Dataset). + + # This yields an Expression as long as the `exprs` are implemented in Arrow. + # Otherwise, it returns a try-error + tryCatch(eval_tidy(expr, mask), error = function(e) { + # Look for the cases where bad input was given, i.e. this would fail + # in regular dplyr anyway, and let those raise those as errors; + # else, for things not supported by Arrow return a "try-error", + # which we'll handle differently + msg <- conditionMessage(e) + # TODO(ARROW-11700): internationalization + if (grepl("object '.*'.not.found", msg)) { + stop(e) + } + if (grepl('could not find function ".*"', msg)) { + stop(e) + } + invisible(structure(msg, class = "try-error", condition = e)) + }) +} + # Helper to assemble the functions that go in the NSE data mask # The only difference between the Dataset and the Table/RecordBatch versions # is that they use a different wrapping function (FUN) to hold the unevaluated @@ -271,23 +308,32 @@ build_function_list <- function(FUN) { dataset_function_list <- build_function_list(build_dataset_expression) array_function_list <- build_function_list(build_array_expression) -# Create a data mask for evaluating a filter expression -filter_mask <- function(.data) { +# Create a data mask for evaluating a dplyr expression +arrow_mask <- function(.data) { if (query_on_dataset(.data)) { f_env <- new_environment(dataset_function_list) - var_binder <- function(x) Expression$field_ref(x) } else { f_env <- new_environment(array_function_list) - var_binder <- function(x) .data$.data[[x]] } - # Add the column references - # Renaming is handled automatically by the named list - data_pronoun <- lapply(.data$selected_columns, var_binder) - env_bind(f_env, !!!data_pronoun) - # Then bind the data pronoun - env_bind(f_env, .data = data_pronoun) - new_data_mask(f_env) + # Add functions that need to error hard and clear. + # Some R functions will still try to evaluate on an Expression + # and return NA with a warning + fail <- function(...) stop("Not implemented") + for (f in c("mean")) { + f_env[[f]] <- fail + } + + # Add the column references and make the mask + out <- new_data_mask( + new_environment(.data$selected_columns, parent = f_env), + f_env + ) + # Then insert the data pronoun + # TODO: figure out what rlang::as_data_pronoun does/why we should use it + # (because if we do we get `Error: Can't modify the data pronoun` in mutate()) + out$.data <- .data$selected_columns + out } set_filters <- function(.data, expressions) { @@ -309,8 +355,27 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { # See dataset.R for Dataset and Scanner(Builder) classes tab <- Scanner$create(x)$ToTable() } else { - # This is a Table/RecordBatch. See record-batch.R for the [ method - tab <- x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE] + # This is a Table or RecordBatch + + # Filter and select the data referenced in selected columns + if (isTRUE(x$filtered_rows)) { + filter <- TRUE + } else { + filter <- eval_array_expression(x$filtered_rows, x$.data) + } + # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? + tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # Now evaluate those expressions on the filtered table + cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + if (length(cols) == 0) { + tab <- tab[, integer(0)] + } else { + if (inherits(x$.data, "Table")) { + tab <- Table$create(!!!cols) + } else { + tab <- RecordBatch$create(!!!cols) + } + } } if (as_data_frame) { df <- as.data.frame(tab) @@ -327,7 +392,13 @@ ensure_group_vars <- function(x) { if (inherits(x, "arrow_dplyr_query")) { # Before pulling data from Arrow, make sure all group vars are in the projection gv <- set_names(setdiff(dplyr::group_vars(x), names(x))) - x$selected_columns <- c(x$selected_columns, gv) + if (length(gv)) { + # Add them back + x$selected_columns <- c( + x$selected_columns, + make_field_refs(gv, dataset = query_on_dataset(.data)) + ) + } } x } @@ -337,21 +408,20 @@ restore_dplyr_features <- function(df, query) { # After calling collect(), make sure these features are carried over grouped <- length(query$group_by_vars) > 0 - renamed <- !identical(names(df), names(query)) - if (is.data.frame(df)) { + renamed <- ncol(df) && !identical(names(df), names(query)) + if (renamed) { # In case variables were renamed, apply those names - if (renamed && ncol(df)) { - names(df) <- names(query) - } + names(df) <- names(query) + } + if (grouped) { # Preserve groupings, if present - if (grouped) { + if (is.data.frame(df)) { df <- dplyr::grouped_df(df, dplyr::group_vars(query)) + } else { + # This is a Table, via collect(as_data_frame = FALSE) + df <- arrow_dplyr_query(df) + df$group_by_vars <- query$group_by_vars } - } else if (grouped || renamed) { - # This is a Table, via collect(as_data_frame = FALSE) - df <- arrow_dplyr_query(df) - names(df$selected_columns) <- names(query) - df$group_by_vars <- query$group_by_vars } df } @@ -423,26 +493,117 @@ ungroup.arrow_dplyr_query <- function(x, ...) { } ungroup.Dataset <- ungroup.ArrowTabular <- force -mutate.arrow_dplyr_query <- function(.data, ...) { +mutate.arrow_dplyr_query <- function(.data, + ..., + .keep = c("all", "used", "unused", "none"), + .before = NULL, + .after = NULL) { + call <- match.call() + exprs <- quos(...) + if (length(exprs) == 0) { + # Nothing to do + return(.data) + } + .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("mutate()") } - # TODO: see if we can defer evaluating the expressions and not collect here. - # It's different from filters (as currently implemented) because the basic - # vector transformation functions aren't yet implemented in Arrow C++. - dplyr::mutate(dplyr::collect(.data), ...) + + .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) + # Restrict the cases we support for now + if (!quo_is_null(.before) || !quo_is_null(.after)) { + # TODO(ARROW-11701) + return(abandon_ship(call, .data, '.before and .after arguments are not supported in Arrow')) + } else if (length(group_vars(.data)) > 0) { + # mutate() on a grouped dataset does calculations within groups + # This doesn't matter on scalar ops (arithmetic etc.) but it does + # for things with aggregations (e.g. subtracting the mean) + return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) + } + + # Check for unnamed expressions and fix if any + unnamed <- !nzchar(names(exprs)) + # Deparse and take the first element in case they're long expressions + names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label) + + mask <- arrow_mask(.data) + results <- list() + for (i in seq_along(exprs)) { + # Iterate over the indices and not the names because names may be repeated + # (which overwrites the previous name) + new_var <- names(exprs)[i] + results[[new_var]] <- arrow_eval(exprs[[i]], mask) + if (inherits(results[[new_var]], "try-error")) { + msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') + return(abandon_ship(call, .data, msg)) + } + # Put it in the data mask too + mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] + } + + # Assign the new columns into the .data$selected_columns, respecting the .keep param + if (.keep == "none") { + .data$selected_columns <- results + } else { + if (.keep != "all") { + # "used" or "unused" + used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) + old_vars <- names(.data$selected_columns) + if (.keep == "used") { + .data$selected_columns <- .data$selected_columns[intersect(old_vars, used_vars)] + } else { + # "unused" + .data$selected_columns <- .data$selected_columns[setdiff(old_vars, used_vars)] + } + } + # Note that this is names(exprs) not names(results): + # if results$new_var is NULL, that means we are supposed to remove it + for (new_var in names(exprs)) { + .data$selected_columns[[new_var]] <- results[[new_var]] + } + } + # Even if "none", we still keep group vars + ensure_group_vars(.data) } mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query -# TODO: add transmute() that does what summarise() does (select only the vars we need) + +transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ..., .keep = "none") +transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query + +# Helper to handle unsupported dplyr features +# * For Table/RecordBatch, we collect() and then call the dplyr method in R +# * For Dataset, we just error +abandon_ship <- function(call, .data, msg = NULL) { + dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]])) + if (query_on_dataset(.data)) { + if (is.null(msg)) { + # Default message: function not implemented + not_implemented_for_dataset(paste0(dplyr_fun_name, "()")) + } else { + stop(msg, call. = FALSE) + } + } + + # else, collect and call dplyr method + if (!is.null(msg)) { + warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE) + } + call$.data <- dplyr::collect(.data) + call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr")) + eval.parent(call, 2) +} arrange.arrow_dplyr_query <- function(.data, ...) { .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - - dplyr::arrange(dplyr::collect(.data), ...) + # TODO(ARROW-11703) move this to Arrow + call <- match.call() + abandon_ship(call, .data) } arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query diff --git a/r/R/expression.R b/r/R/expression.R index 878b800c652..74c1aefcae1 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -143,7 +143,14 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) { .array_function_map <- c(.unary_function_map, .binary_function_map) -eval_array_expression <- function(x) { +eval_array_expression <- function(x, data = NULL) { + if (!is.null(data)) { + x <- bind_array_refs(x, data) + } + if (!inherits(x, "array_expression")) { + # Nothing to evaluate + return(x) + } x$args <- lapply(x$args, function (a) { if (inherits(a, "array_expression")) { eval_array_expression(a) @@ -154,6 +161,27 @@ eval_array_expression <- function(x) { call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } +find_array_refs <- function(x) { + if (identical(x$fun, "array_ref")) { + out <- x$args$field_name + } else { + out <- lapply(x$args, find_array_refs) + } + unlist(out) +} + +# Take an array_expression and replace array_refs with arrays/chunkedarrays from data +bind_array_refs <- function(x, data) { + if (inherits(x, "array_expression")) { + if (identical(x$fun, "array_ref")) { + x <- data[[x$args$field_name]] + } else { + x$args <- lapply(x$args, bind_array_refs, data) + } + } + x +} + #' @export is.na.array_expression <- function(x) array_expression("is.na", x) @@ -181,9 +209,13 @@ print.array_expression <- function(x, ...) { deparse(arg) } }) - # Prune this for readability - function_name <- sub("_kleene", "", x$fun) - paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + if (identical(x$fun, "array_ref")) { + x$args$field_name + } else { + # Prune this for readability + function_name <- sub("_kleene", "", x$fun) + paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + } } ########### @@ -217,6 +249,9 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ) Expression$create("cast", self, options = modifyList(opts, list(...))) } + ), + active = list( + field_name = function() dataset___expr__get_field_ref_name(self) ) ) Expression$create <- function(function_name, diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 839c9d6c173..73ee64844a6 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1569,6 +1569,14 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +std::string dataset___expr__get_field_ref_name(const std::shared_ptr& ref); +extern "C" SEXP _arrow_dataset___expr__get_field_ref_name(SEXP ref_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type ref(ref_sexp); + return cpp11::as_sexp(dataset___expr__get_field_ref_name(ref)); +END_CPP11 +} +// expression.cpp std::shared_ptr dataset___expr__scalar(const std::shared_ptr& x); extern "C" SEXP _arrow_dataset___expr__scalar(SEXP x_sexp){ BEGIN_CPP11 @@ -3702,6 +3710,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, + { "_arrow_dataset___expr__get_field_ref_name", (DL_FUNC) &_arrow_dataset___expr__get_field_ref_name, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index ddb1e72c309..76d8222967b 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -47,6 +47,13 @@ std::shared_ptr dataset___expr__field_ref(std::string name) { return std::make_shared(ds::field_ref(std::move(name))); } +// [[arrow::export]] +std::string dataset___expr__get_field_ref_name( + const std::shared_ptr& ref) { + auto refname = ref->field_ref()->name(); + return *refname; +} + // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index ce0f9de8a54..76edea61f57 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -59,3 +59,66 @@ verify_output <- function(...) { } testthat::verify_output(...) } + +expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + skip_record_batch = NULL, # Msg, if should skip RB test + skip_table = NULL, # Msg, if should skip Table test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + + skip_msg <- NULL + + if (is.null(skip_record_batch)) { + via_batch <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ) + expect_equivalent(via_batch, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_record_batch) + } + + if (is.null(skip_table)) { + via_table <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ) + expect_equivalent(via_table, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} + +expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + ...) { + expr <- rlang::enquo(expr) + msg <- tryCatch( + rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), + error = function (e) conditionMessage(e) + ) + expect_is(msg, "character", label = "dplyr on data.frame did not error") + + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ), + msg, + ... + ) + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ), + msg, + ... + ) +} \ No newline at end of file diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index aeee66d8710..a017823ce34 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -416,6 +416,14 @@ test_that("record_batch() handles null type (ARROW-7064)", { expect_equivalent(batch$schema, schema(a = int32(), n = null())) }) +test_that("record_batch() scalar recycling", { + skip("Not implemented (ARROW-11705)") + expect_data_frame( + record_batch(a = 1:10, b = 5), + tibble::tibble(a = 1:10, b = 5) + ) +}) + test_that("RecordBatch$Equals", { df <- tibble::tibble(x = 1:10, y = letters[1:10]) a <- record_batch(df) diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R new file mode 100644 index 00000000000..f73589496be --- /dev/null +++ b/r/tests/testthat/test-dplyr-filter.R @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("filter() on is.na()", { + expect_dplyr_equal( + input %>% + filter(is.na(lgl)) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filter() with NAs in selection", { + expect_dplyr_equal( + input %>% + filter(lgl) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { + expect_dplyr_equal( + input %>% + filter(false) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression", { + char_sym <- "b" + expect_dplyr_equal( + input %>% + filter(chr == char_sym) %>% + select(string = chr, int) %>% + collect(), + tbl + ) +}) + +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("More complex select/filter", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr == "d" | chr == "f") %>% + select(chr, int, lgl) %>% + filter(int < 5) %>% + select(int, chr) %>% + collect(), + tbl + ) +}) + +test_that("filter() with %in%", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr %in% c("d", "f")) %>% + collect(), + tbl + ) +}) + +test_that("filter() with string ops", { + # Extra instrumentation to ensure that we're calling Arrow compute here + # because many base R string functions implicitly call as.character, + # which means they still work on Arrays but actually force data into R + # 1) wrapper that raises a warning if as.character is called. Can't wrap + # the whole test because as.character apparently gets called in other + # (presumably legitimate) places + # 2) Wrap the test in expect_warning(expr, NA) to catch the warning + + with_no_as_character <- function(expr) { + trace( + "as.character", + tracer = quote(warning("as.character was called")), + print = FALSE, + where = toupper + ) + on.exit(untrace("as.character", where = toupper)) + force(expr) + } + + expect_warning( + expect_dplyr_equal( + input %>% + filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% + collect(), + tbl + ), + NA) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(verses) > 25) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% + collect(), + tbl + ) +}) + +test_that("filter environment scope", { + # "object 'b_var' not found" + expect_dplyr_error(input %>% filter(batch, chr == b_var)) + + b_var <- "b" + expect_dplyr_equal( + input %>% + filter(chr == b_var) %>% + collect(), + tbl + ) + # Also for functions + # 'could not find function "isEqualTo"' because we haven't defined it yet + expect_dplyr_error(filter(batch, isEqualTo(int, 4))) + + skip("Need to substitute in user defined function too") + # TODO: fix this: this isEqualTo function is eagerly evaluating; it should + # instead yield array_expressions. Probably bc the parent env of the function + # has the Ops.Array methods defined; we need to move it so that the parent + # env is the data mask we use in the dplyr eval + isEqualTo <- function(x, y) x == y & !is.na(x) + expect_dplyr_equal( + input %>% + select(-fct) %>% # factor levels aren't identical + filter(isEqualTo(int, 4)) %>% + collect(), + tbl + ) +}) + +test_that("Filtering on a column that doesn't exist errors correctly", { + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") + expect_error( + batch %>% filter(not_a_col == 42) %>% collect(), + "object 'not_a_col' not found" + ) +}) + +test_that("Filtering with a function that doesn't have an Array/expr method still works", { + expect_warning( + expect_dplyr_equal( + input %>% + filter(int > 2, pnorm(dbl) > .99) %>% + collect(), + tbl + ), + 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', + fixed = TRUE + ) +}) + +test_that("filter() with .data pronoun", { + expect_dplyr_equal( + input %>% + filter(.data$dbl > 4) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(is.na(.data$lgl)) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # and the .env pronoun too! + chr <- 4 + expect_dplyr_equal( + input %>% + filter(.data$dbl > .env$chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # but there is an error if we don't override the masking with `.env` + expect_dplyr_error( + tbl %>% + filter(.data$dbl > chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect() + ) +}) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R new file mode 100644 index 00000000000..56d7e368520 --- /dev/null +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -0,0 +1,350 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("mutate() is lazy", { + expect_is( + tbl %>% record_batch() %>% mutate(int = int + 6L), + "arrow_dplyr_query" + ) +}) + +test_that("basic mutate", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + mutate(int = int + 6L) %>% + collect(), + tbl + ) +}) + +test_that("transmute", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + transmute(int = int + 6L) %>% + collect(), + tbl + ) +}) + +test_that("mutate and refer to previous mutants", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + +test_that("mutate with .data pronoun", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = .data$line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + +test_that("mutate with unnamed expressions", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + int, # bare column name + nchar(padded_strings) # expression + ) %>% + filter(int > 5) %>% + collect(), + tbl + ) +}) + +test_that("mutate with reassigning same name", { + expect_dplyr_equal( + input %>% + transmute( + new = lgl, + new = chr + ) %>% + collect(), + tbl + ) +}) + +test_that("mutate with single value for recycling", { + skip("Not implemented (ARROW-11705") + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + dr_bronner = 1 # ALL ONE! + ) %>% + collect(), + tbl + ) +}) + +test_that("dplyr::mutate's examples", { + # Newly created variables are available immediately + expect_dplyr_equal( + input %>% + select(name, mass) %>% + mutate( + mass2 = mass * 2, + mass2_squared = mass2 * mass2 + ) %>% + collect(), + starwars # this is a test dataset that ships with dplyr + ) + + # As well as adding new variables, you can use mutate() to + # remove variables and modify existing variables. + expect_dplyr_equal( + input %>% + select(name, height, mass, homeworld) %>% + mutate( + mass = NULL, + height = height * 0.0328084 # convert to feet + ) %>% + collect(), + starwars + ) + + # Examples we don't support should succeed + # but warn that they're pulling data into R to do so + + # across + autosplicing: ARROW-11699 + expect_warning( + expect_dplyr_equal( + input %>% + select(name, homeworld, species) %>% + mutate(across(!name, as.factor)) %>% + collect(), + starwars + ), + "Expression across.*not supported in Arrow" + ) + + # group_by then mutate + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, homeworld) %>% + group_by(homeworld) %>% + mutate(rank = min_rank(desc(mass))) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) + + # `.before` and `.after` experimental args: ARROW-11701 + df <- tibble(x = 1, y = 2) + expect_dplyr_equal( + input %>% mutate(z = x + y) %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .before = 1) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> z x y + #> + #> 1 3 1 2 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .after = x) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> x z y + #> + #> 1 1 3 2 + + # By default, mutate() keeps all columns from the input data. + # Experimental: You can override with `.keep` + df <- tibble(x = 1, y = 2, a = "a", b = "b") + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "all") %>% collect(), # the default + df + ) + #> # A tibble: 1 x 5 + #> x y a b z + #> + #> 1 1 2 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "used") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "unused") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> a b z + #> + #> 1 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "none") %>% collect(), # same as transmute() + df + ) + #> # A tibble: 1 x 1 + #> z + #> + #> 1 3 + + # Grouping ---------------------------------------- + # The mutate operation may yield different results on grouped + # tibbles because the expressions are computed within groups. + # The following normalises `mass` by the global average: + # TODO(ARROW-11702) + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, species) %>% + mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) +}) + +test_that("handle bad expressions", { + # TODO: search for functions other than mean() (see above test) + # that need to be forced to fail because they error ambiguously + + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") + expect_error( + Table$create(tbl) %>% mutate(newvar = NOTAVAR + 2), + "object 'NOTAVAR' not found" + ) +}) + +test_that("print a mutated dataset", { + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(twice = int * 2) %>% + print(), +'Table (query) +int: int32 +twice: expr + +See $.data for the source Arrow object', + fixed = TRUE) + + # Handling non-expressions/edge cases + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(again = 1:10) %>% + print(), +'Table (query) +int: int32 +again: expr + +See $.data for the source Arrow object', + fixed = TRUE) +}) + +test_that("mutate and write_dataset", { + # See related test in test-dataset.R + + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651 + + first_date <- lubridate::ymd_hms("2015-04-29 03:12:39") + df1 <- tibble( + int = 1:10, + dbl = as.numeric(1:10), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[1:10], + fct = factor(LETTERS[1:10]), + ts = first_date + lubridate::days(1:10) + ) + + second_date <- lubridate::ymd_hms("2017-03-09 07:01:02") + df2 <- tibble( + int = 101:110, + dbl = c(as.numeric(51:59), NaN), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[10:1], + fct = factor(LETTERS[10:1]), + ts = second_date + lubridate::days(10:1) + ) + + dst_dir <- tempfile() + stacked <- record_batch(rbind(df1, df2)) + stacked %>% + mutate(twice = int * 2) %>% + group_by(int) %>% + write_dataset(dst_dir, format = "feather") + expect_true(dir.exists(dst_dir)) + expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep = "="))) + + new_ds <- open_dataset(dst_dir, format = "feather") + + expect_equivalent( + new_ds %>% + select(string = chr, integer = int, twice) %>% + filter(integer > 6 & integer < 11) %>% + collect() %>% + summarize(mean = mean(integer)), + df1 %>% + select(string = chr, integer = int) %>% + mutate(twice = integer * 2) %>% + filter(integer > 6) %>% + summarize(mean = mean(integer)) + ) +}) \ No newline at end of file diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 6d9945a115a..13610f1c6f1 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -15,74 +15,9 @@ # specific language governing permissions and limitations # under the License. -context("dplyr verbs") - library(dplyr) library(stringr) -expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - skip_record_batch = NULL, # Msg, if should skip RB test - skip_table = NULL, # Msg, if should skip Table test - ...) { - expr <- rlang::enquo(expr) - expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) - - skip_msg <- NULL - - if (is.null(skip_record_batch)) { - via_batch <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ) - expect_equivalent(via_batch, expected, ...) - } else { - skip_msg <- c(skip_msg, skip_record_batch) - } - - if (is.null(skip_table)) { - via_table <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ) - expect_equivalent(via_table, expected, ...) - } else { - skip_msg <- c(skip_msg, skip_table) - } - - if (!is.null(skip_msg)) { - skip(paste(skip_msg, collpase = "\n")) - } -} - -expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - ...) { - expr <- rlang::enquo(expr) - msg <- tryCatch( - rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), - error = function (e) conditionMessage(e) - ) - expect_is(msg, "character", label = "dplyr on data.frame did not error") - - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ), - msg, - ... - ) - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ), - msg, - ... - ) -} - tbl <- example_data # Add some better string data tbl$verses <- verses[[1]] @@ -104,127 +39,6 @@ test_that("basic select/filter/collect", { expect_identical(collect(batch), tbl) }) -test_that("filter() on is.na()", { - expect_dplyr_equal( - input %>% - filter(is.na(lgl)) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filter() with NAs in selection", { - expect_dplyr_equal( - input %>% - filter(lgl) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { - expect_dplyr_equal( - input %>% - filter(false) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression", { - char_sym <- "b" - expect_dplyr_equal( - input %>% - filter(chr == char_sym) %>% - select(string = chr, int) %>% - collect(), - tbl - ) -}) - -test_that("filtering with arithmetic", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2L > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2L > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl %/% 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression + autocasting", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("More complex select/filter", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr == "d" | chr == "f") %>% - select(chr, int, lgl) %>% - filter(int < 5) %>% - select(int, chr) %>% - collect(), - tbl - ) -}) - test_that("dim() on query", { expect_dplyr_equal( input %>% @@ -247,151 +61,12 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) +* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))), less(int, 5)) See $.data for the source Arrow object', fixed = TRUE ) }) -test_that("filter() with %in%", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr %in% c("d", "f")) %>% - collect(), - tbl - ) -}) - -test_that("filter() with string ops", { - # Extra instrumentation to ensure that we're calling Arrow compute here - # because many base R string functions implicitly call as.character, - # which means they still work on Arrays but actually force data into R - # 1) wrapper that raises a warning if as.character is called. Can't wrap - # the whole test because as.character apparently gets called in other - # (presumably legitimate) places - # 2) Wrap the test in expect_warning(expr, NA) to catch the warning - - with_no_as_character <- function(expr) { - trace( - "as.character", - tracer = quote(warning("as.character was called")), - print = FALSE, - where = toupper - ) - on.exit(untrace("as.character", where = toupper)) - force(expr) - } - - expect_warning( - expect_dplyr_equal( - input %>% - filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% - collect(), - tbl - ), - NA) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(verses) > 25) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% - collect(), - tbl - ) -}) - -test_that("filter environment scope", { - # "object 'b_var' not found" - expect_dplyr_error(input %>% filter(batch, chr == b_var)) - - b_var <- "b" - expect_dplyr_equal( - input %>% - filter(chr == b_var) %>% - collect(), - tbl - ) - # Also for functions - # 'could not find function "isEqualTo"' - expect_dplyr_error(filter(batch, isEqualTo(int, 4))) - - # TODO: fix this: this isEqualTo function is eagerly evaluating; it should - # instead yield array_expressions. Probably bc the parent env of the function - # has the Ops.Array methods defined; we need to move it so that the parent - # env is the data mask we use in the dplyr eval - isEqualTo <- function(x, y) x == y & !is.na(x) - expect_dplyr_equal( - input %>% - select(-fct) %>% # factor levels aren't identical - filter(isEqualTo(int, 4)) %>% - collect(), - tbl - ) -}) - -test_that("Filtering on a column that doesn't exist errors correctly", { - skip("Error handling in filter() needs to be internationalized") - expect_error( - batch %>% filter(not_a_col == 42) %>% collect(), - "object 'not_a_col' not found" - ) -}) - -test_that("Filtering with a function that doesn't have an Array/expr method still works", { - expect_warning( - expect_dplyr_equal( - input %>% - filter(int > 2, pnorm(dbl) > .99) %>% - collect(), - tbl - ), - 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', - fixed = TRUE - ) -}) - -test_that("filter() with .data pronoun", { - expect_dplyr_equal( - input %>% - filter(.data$dbl > 4) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(is.na(.data$lgl)) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # and the .env pronoun too! - chr <- 4 - expect_dplyr_equal( - input %>% - filter(.data$dbl > .env$chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # but there is an error if we don't override the masking with `.env` - expect_dplyr_error( - tbl %>% - filter(.data$dbl > chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect() - ) -}) - test_that("summarize", { expect_dplyr_equal( input %>% @@ -410,29 +85,6 @@ test_that("summarize", { ) }) -test_that("mutate", { - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - mutate(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - -test_that("transmute", { - skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works via mutate)") - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - transmute(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - test_that("group_by groupings are recorded", { expect_dplyr_equal( input %>% @@ -599,7 +251,7 @@ test_that("collect(as_data_frame=FALSE)", { select(int, strng = chr) %>% filter(int > 5) %>% collect(as_data_frame = FALSE) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -632,7 +284,7 @@ test_that("head", { select(int, strng = chr) %>% filter(int > 5) %>% head(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -665,7 +317,7 @@ test_that("tail", { select(int, strng = chr) %>% filter(int > 5) %>% tail(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 3c100812ff1..3df7270f4c5 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -34,8 +34,20 @@ test_that("array_expression print method", { ) }) +test_that("array_refs", { + tab <- Table$create(a = 1:5) + ex <- build_array_expression(">", array_expression("array_ref", field_name = "a"), 4) + expect_is(ex, "array_expression") + expect_identical(ex$args[[1]]$args$field_name, "a") + expect_identical(find_array_refs(ex), "a") + out <- eval_array_expression(ex, tab) + expect_is(out, "ChunkedArray") + expect_equal(as.vector(out), c(FALSE, FALSE, FALSE, FALSE, TRUE)) +}) + test_that("C++ expressions", { f <- Expression$field_ref("f") + expect_identical(f$field_name, "f") g <- Expression$field_ref("g") date <- Expression$scalar(as.Date("2020-01-15")) ts <- Expression$scalar(as.POSIXct("2020-01-17 11:11:11"))