diff --git a/r/R/dplyr-across.R b/r/R/dplyr-across.R index 0d85764f7fb..6aeedc18f37 100644 --- a/r/R/dplyr-across.R +++ b/r/R/dplyr-across.R @@ -34,7 +34,11 @@ expand_across <- function(.data, quos_in, exclude_cols = NULL) { ) if (!all(names(across_call[-1]) %in% c(".cols", ".fns", ".names"))) { - abort("`...` argument to `across()` is deprecated in dplyr and not supported in Arrow") + arrow_not_supported( + "`...` argument to `across()` is deprecated in dplyr and", + body = c(">" = "Convert your call into a function or formula including the arguments"), + call = rlang::caller_call() + ) } if (!is.null(across_call[[".cols"]])) { diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R index c8594c77df0..fdc69a708d1 100644 --- a/r/R/dplyr-arrange.R +++ b/r/R/dplyr-arrange.R @@ -19,47 +19,46 @@ # The following S3 methods are registered on load if dplyr is present arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { - call <- match.call() - .data <- as_adq(.data) - exprs <- expand_across(.data, quos(...)) + try_arrow_dplyr({ + .data <- as_adq(.data) + exprs <- expand_across(.data, quos(...)) - if (.by_group) { - # when the data is grouped and .by_group is TRUE, order the result by - # the grouping columns first - exprs <- c(quos(!!!dplyr::groups(.data)), exprs) - } - if (length(exprs) == 0) { - # Nothing to do - return(.data) - } - .data <- as_adq(.data) - # find and remove any dplyr::desc() and tidy-eval - # the arrange expressions inside an Arrow data_mask - sorts <- vector("list", length(exprs)) - descs <- logical(0) - mask <- arrow_mask(.data) - for (i in seq_along(exprs)) { - x <- find_and_remove_desc(exprs[[i]]) - exprs[[i]] <- x[["quos"]] - sorts[[i]] <- arrow_eval(exprs[[i]], mask) - names(sorts)[i] <- format_expr(exprs[[i]]) - if (inherits(sorts[[i]], "try-error")) { - msg <- paste("Expression", names(sorts)[i], "not supported in Arrow") - return(abandon_ship(call, .data, msg)) + if (.by_group) { + # when the data is grouped and .by_group is TRUE, order the result by + # the grouping columns first + exprs <- c(quos(!!!dplyr::groups(.data)), exprs) } - if (length(mask$.aggregations)) { - # dplyr lets you arrange on e.g. x < mean(x), but we haven't implemented it. - # But we could, the same way it works in mutate() via join, if someone asks. - # Until then, just error. - # TODO: add a test for this - msg <- paste("Expression", format_expr(expr), "not supported in arrange() in Arrow") - return(abandon_ship(call, .data, msg)) + if (length(exprs) == 0) { + # Nothing to do + return(.data) } - descs[i] <- x[["desc"]] - } - .data$arrange_vars <- c(sorts, .data$arrange_vars) - .data$arrange_desc <- c(descs, .data$arrange_desc) - .data + .data <- as_adq(.data) + # find and remove any dplyr::desc() and tidy-eval + # the arrange expressions inside an Arrow data_mask + sorts <- vector("list", length(exprs)) + descs <- logical(0) + mask <- arrow_mask(.data) + for (i in seq_along(exprs)) { + x <- find_and_remove_desc(exprs[[i]]) + exprs[[i]] <- x[["quos"]] + sorts[[i]] <- arrow_eval(exprs[[i]], mask) + names(sorts)[i] <- format_expr(exprs[[i]]) + if (length(mask$.aggregations)) { + # dplyr lets you arrange on e.g. x < mean(x), but we haven't implemented it. + # But we could, the same way it works in mutate() via join, if someone asks. + # Until then, just error. + # TODO: add a test for this + arrow_not_supported( + .actual_msg = "Expression not supported in arrange() in Arrow", + call = expr + ) + } + descs[i] <- x[["desc"]] + } + .data$arrange_vars <- c(sorts, .data$arrange_vars) + .data$arrange_desc <- c(descs, .data$arrange_desc) + .data + }) } arrange.Dataset <- arrange.ArrowTabular <- arrange.RecordBatchReader <- arrange.arrow_dplyr_query @@ -73,10 +72,9 @@ find_and_remove_desc <- function(quosure) { expr <- quo_get_expr(quosure) descending <- FALSE if (length(all.vars(expr)) < 1L) { - stop( - "Expression in arrange() does not contain any field names: ", - deparse(expr), - call. = FALSE + validation_error( + "Expression in arrange() does not contain any field names", + call = quosure ) } # Use a while loop to remove any number of nested pairs of enclosing @@ -90,7 +88,10 @@ find_and_remove_desc <- function(quosure) { # ensure desc() has only one argument (when an R expression is a function # call, length == 2 means it has exactly one argument) if (length(expr) > 2) { - stop("desc() expects only one argument", call. = FALSE) + validation_error( + "desc() expects only one argument", + call = expr + ) } # remove desc() and toggle descending expr <- expr[[2]] diff --git a/r/R/dplyr-datetime-helpers.R b/r/R/dplyr-datetime-helpers.R index c153f47cbaf..8e6a7f61853 100644 --- a/r/R/dplyr-datetime-helpers.R +++ b/r/R/dplyr-datetime-helpers.R @@ -18,10 +18,10 @@ check_time_locale <- function(locale = Sys.getlocale("LC_TIME")) { if (tolower(Sys.info()[["sysname"]]) == "windows" && locale != "C") { # MingW C++ std::locale only supports "C" and "POSIX" - stop(paste0( - "On Windows, time locales other than 'C' are not supported in Arrow. ", - "Consider setting `Sys.setlocale('LC_TIME', 'C')`" - )) + arrow_not_supported( + "On Windows, time locales other than 'C'", + body = c(">" = "Consider setting `Sys.setlocale('LC_TIME', 'C')`") + ) } locale } @@ -56,13 +56,15 @@ duration_from_chunks <- function(chunks) { matched_chunks <- accepted_chunks[pmatch(names(chunks), accepted_chunks, duplicates.ok = TRUE)] if (any(is.na(matched_chunks))) { - abort( - paste0( - "named `difftime` units other than: ", - oxford_paste(accepted_chunks, quote_symbol = "`"), - " not supported in Arrow. \nInvalid `difftime` parts: ", + arrow_not_supported( + paste( + "named `difftime` units other than:", + oxford_paste(accepted_chunks, quote_symbol = "`") + ), + body = c(i = paste( + "Invalid `difftime` parts:", oxford_paste(names(chunks[is.na(matched_chunks)]), quote_symbol = "`") - ) + )) ) } @@ -114,7 +116,6 @@ binding_as_date_character <- function(x, } binding_as_date_numeric <- function(x, origin = "1970-01-01") { - # Arrow does not support direct casting from double to date32(), but for # integer-like values we can go via int32() # TODO: revisit after ARROW-15798 @@ -442,7 +443,7 @@ parse_period_unit <- function(x) { unit <- as.integer(pmatch(str_unit_start, known_units)) - 1L if (any(is.na(unit))) { - abort( + validation_error( sprintf( "Invalid period name: '%s'", str_unit, @@ -484,13 +485,13 @@ parse_period_unit <- function(x) { # more special cases: lubridate imposes sensible maximum # values on the number of seconds, minutes and hours if (unit == 3L && multiple > 60) { - abort("Rounding with second > 60 is not supported") + validation_error("Rounding with second > 60 is not supported") } if (unit == 4L && multiple > 60) { - abort("Rounding with minute > 60 is not supported") + validation_error("Rounding with minute > 60 is not supported") } if (unit == 5L && multiple > 24) { - abort("Rounding with hour > 24 is not supported") + validation_error("Rounding with hour > 24 is not supported") } list(unit = unit, multiple = multiple) diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R index 211c26cecce..1997d698c0b 100644 --- a/r/R/dplyr-eval.R +++ b/r/R/dplyr-eval.R @@ -25,30 +25,64 @@ arrow_eval <- function(expr, mask) { add_user_functions_to_mask(expr, mask) # This yields an Expression as long as the `exprs` are implemented in Arrow. - # Otherwise, it returns a try-error + # Otherwise, it raises a classed error, either: + # * arrow_not_supported: the expression is not supported in Arrow; retry with + # regular dplyr may work + # * validation_error: the expression is known to be not valid, so don't + # recommend retrying with regular dplyr tryCatch(eval_tidy(expr, mask), error = function(e) { - # Look for the cases where bad input was given, i.e. this would fail - # in regular dplyr anyway, and let those raise those as errors; - # else, for things not supported in Arrow return a "try-error", - # which we'll handle differently + # Inspect why the expression failed, and add the expr as the `call` + # for better error messages msg <- conditionMessage(e) - if (getOption("arrow.debug", FALSE)) print(msg) - patterns <- .cache$i18ized_error_pattern - if (is.null(patterns)) { - patterns <- i18ize_error_messages() - # Memoize it - .cache$i18ized_error_pattern <- patterns - } - if (grepl(patterns, msg)) { + arrow_debug <- getOption("arrow.debug", FALSE) + if (arrow_debug) print(msg) + + # A few cases: + # 1. Evaluation raised one of our error classes. Add the expr as the call + # and re-raise it. + if (inherits(e, c("validation_error", "arrow_not_supported"))) { + e$call <- expr stop(e) } - out <- structure(msg, class = "try-error", condition = e) - if (grepl("not supported.*Arrow|NotImplemented", msg) || getOption("arrow.debug", FALSE)) { - # One of ours. Mark it so that consumers can handle it differently - class(out) <- c("arrow-try-error", class(out)) + # 2. Error is from assert_that: raise as validation_error + if (inherits(e, "assertError")) { + validation_error(msg, call = expr) + } + + # 3. Check to see if this is a standard R error message (not found etc.). + # Retry with dplyr won't help. + if (grepl(get_standard_error_messages(), msg)) { + # Raise the original error: it's actually helpful here + validation_error(msg, call = expr) + } + # 3b. Check to see if this is from match.arg. Retry with dplyr won't help. + if (is.language(e$call) && identical(as.character(e$call[[1]]), "match.arg")) { + # Raise the original error: it's actually helpful here + validation_error(msg, call = expr) + } + + # 4. Check for NotImplemented error raised from Arrow C++ code. + # Not sure where exactly we may raise this, but if we see it, it means + # that something isn't supported in Arrow. Retry in dplyr may help? + if (grepl("NotImplemented", msg)) { + arrow_not_supported(.actual_msg = msg, call = expr) + } + + + # 5. Otherwise, we're not sure why this errored: it's not an error we raised + # explicitly. We'll assume it's because the function it calls isn't + # supported in arrow, and retry with dplyr may help. + if (arrow_debug) { + arrow_not_supported(.actual_msg = msg, call = expr) + } else { + # Don't show the original error message unless in debug mode because + # it's probably not helpful: like, if you've passed an Expression to a + # regular R function that operates on strings, the way it errors would be + # more confusing than just saying that the expression is not supported + # in arrow. + arrow_not_supported("Expression", call = expr) } - invisible(out) }) } @@ -93,15 +127,12 @@ add_user_functions_to_mask <- function(expr, mask) { invisible() } -handle_arrow_not_supported <- function(err, lab) { - # Look for informative message from the Arrow function version (see above) - if (inherits(err, "arrow-try-error")) { - # Include it if found - paste0("In ", lab, ", ", as.character(err)) - } else { - # Otherwise be opaque (the original error is probably not useful) - paste("Expression", lab, "not supported in Arrow") +get_standard_error_messages <- function() { + if (is.null(.cache$i18ized_error_pattern)) { + # Memoize it + .cache$i18ized_error_pattern <- i18ize_error_messages() } + .cache$i18ized_error_pattern } i18ize_error_messages <- function() { @@ -114,10 +145,101 @@ i18ize_error_messages <- function() { paste(map(out, ~ sub("X_____X", ".*", .)), collapse = "|") } -# Helper to raise a common error -arrow_not_supported <- function(msg) { - # TODO: raise a classed error? - stop(paste(msg, "not supported in Arrow"), call. = FALSE) +#' Helpers to raise classed errors +#' +#' `arrow_not_supported()` and `validation_error()` raise classed errors that +#' allow us to distinguish between things that are not supported in Arrow and +#' things that are just invalid input. Additional wrapping in `arrow_eval()` +#' and `try_arrow_dplyr()` provide more context and suggestions. +#' Importantly, if `arrow_not_supported` is raised, then retrying the same code +#' in regular dplyr in R may work. But if `validation_error` is raised, then we +#' shouldn't recommend retrying with regular dplyr because it will fail there +#' too. +#' +#' Use these in function bindings and in the dplyr methods. Inside of function +#' bindings, you don't need to provide the `call` argument, as it will be +#' automatically filled in with the expression that caused the error in +#' `arrow_eval()`. In dplyr methods, you should provide the `call` argument; +#' `rlang::caller_call()` often is correct, but you may need to experiment to +#' find how far up the call stack you need to look. +#' +#' You may provide additional information in the `body` argument, a named +#' character vector. Use `i` for additional information about the error and `>` +#' to indicate potential solutions or workarounds that don't require pulling the +#' data into R. If you have an `arrow_not_supported()` error with a `>` +#' suggestion, when the error is ultimately raised by `try_error_dplyr()`, +#' `Call collect() first to pull data into R` won't be the only suggestion. +#' +#' You can still use `match.arg()` and `assert_that()` for simple input +#' validation inside of the function bindings. `arrow_eval()` will catch their +#' errors and re-raise them as `validation_error`. +#' +#' @param msg The message to show. `arrow_not_supported()` will append +#' "not supported in Arrow" to this message. +#' @param .actual_msg If you don't want to append "not supported in Arrow" to +#' the message, you can provide the full message here. +#' @param ... Additional arguments to pass to `rlang::abort()`. Useful arguments +#' include `call` to provide the call or expression that caused the error, and +#' `body` to provide additional context about the error. +#' @keywords internal +arrow_not_supported <- function(msg, + .actual_msg = paste(msg, "not supported in Arrow"), + ...) { + abort(.actual_msg, class = "arrow_not_supported", use_cli_format = TRUE, ...) +} + +#' @rdname arrow_not_supported +validation_error <- function(msg, ...) { + abort(msg, class = "validation_error", use_cli_format = TRUE, ...) +} + +# Wrap the contents of an arrow dplyr verb function in a tryCatch block to +# handle arrow_not_supported errors: +# * If it errors because of arrow_not_supported, abandon ship +# * If it's another error, just stop, retry with regular dplyr won't help +try_arrow_dplyr <- function(expr) { + parent <- caller_env() + # Make sure that the call is available in the parent environment + # so that we can use it in abandon_ship, if needed + evalq(call <- match.call(), parent) + + tryCatch( + eval(expr, parent), + arrow_not_supported = function(e) abandon_ship(e, parent) + ) +} + +# Helper to handle unsupported dplyr features +# * For Table/RecordBatch, we collect() and then call the dplyr method in R +# * For Dataset, we error and recommend collect() +# Requires that `env` contains `.data` +# The Table/RB path also requires `call` to be in `env` (try_arrow_dplyr adds it) +# and that the function being called also exists in the dplyr namespace. +abandon_ship <- function(err, env) { + .data <- get(".data", envir = env) + if (query_on_dataset(.data)) { + # Add a note suggesting `collect()` to the error message. + # If there are other suggestions already there (with the > arrow name), + # collect() isn't the only suggestion, so message differently + msg <- ifelse( + ">" %in% names(err$body), + "Or, call collect() first to pull data into R.", + "Call collect() first to pull data into R." + ) + err$body <- c(err$body, ">" = msg) + stop(err) + } + + # Else, warn, collect(), and run in regular dplyr + call <- get("call", envir = env) + rlang::warn( + message = paste0("In ", format_expr(err$call), ": "), + body = c("i" = conditionMessage(err), ">" = "Pulling data into R") + ) + call$.data <- dplyr::collect(.data) + dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]])) + call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr")) + eval(call, env) } # Create a data mask for evaluating a dplyr expression diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R index 69decbd7665..36219e411e5 100644 --- a/r/R/dplyr-filter.R +++ b/r/R/dplyr-filter.R @@ -19,45 +19,45 @@ # The following S3 methods are registered on load if dplyr is present filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { - # TODO something with the .preserve argument - out <- as_adq(.data) + try_arrow_dplyr({ + # TODO something with the .preserve argument + out <- as_adq(.data) - by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") - if (by$from_by) { - out$group_by_vars <- by$names - } - - expanded_filters <- expand_across(out, quos(...)) - if (length(expanded_filters) == 0) { - # Nothing to do - return(as_adq(.data)) - } + if (by$from_by) { + out$group_by_vars <- by$names + } - # tidy-eval the filter expressions inside an Arrow data_mask - mask <- arrow_mask(out) - for (expr in expanded_filters) { - filt <- arrow_eval(expr, mask) - if (inherits(filt, "try-error")) { - msg <- handle_arrow_not_supported(filt, format_expr(expr)) - return(abandon_ship(match.call(), .data, msg)) + expanded_filters <- expand_across(out, quos(...)) + if (length(expanded_filters) == 0) { + # Nothing to do + return(as_adq(.data)) } - if (length(mask$.aggregations)) { - # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it. - # But we could, the same way it works in mutate() via join, if someone asks. - # Until then, just error. - # TODO: add a test for this - msg <- paste("Expression", format_expr(expr), "not supported in filter() in Arrow") - return(abandon_ship(match.call(), .data, msg)) + + # tidy-eval the filter expressions inside an Arrow data_mask + mask <- arrow_mask(out) + for (expr in expanded_filters) { + filt <- arrow_eval(expr, mask) + if (length(mask$.aggregations)) { + # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it. + # But we could, the same way it works in mutate() via join, if someone asks. + # Until then, just error. + # TODO: add a test for this + arrow_not_supported( + .actual_msg = "Expression not supported in filter() in Arrow", + call = expr + ) + } + out <- set_filters(out, filt) } - out <- set_filters(out, filt) - } - if (by$from_by) { - out$group_by_vars <- character() - } + if (by$from_by) { + out$group_by_vars <- character() + } - out + out + }) } filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query diff --git a/r/R/dplyr-funcs-agg.R b/r/R/dplyr-funcs-agg.R index c0c4eb30894..340ebe7adc9 100644 --- a/r/R/dplyr-funcs-agg.R +++ b/r/R/dplyr-funcs-agg.R @@ -155,7 +155,7 @@ register_bindings_aggregate <- function() { set_agg <- function(...) { agg_data <- list2(...) # Find the environment where .aggregations is stored - target <- find_aggregations_env() + target <- find_arrow_mask() aggs <- get(".aggregations", target) lapply(agg_data[["data"]], function(expr) { # If any of the fields referenced in the expression are in .aggregations, @@ -176,8 +176,8 @@ set_agg <- function(...) { Expression$field_ref(tmpname) } -find_aggregations_env <- function() { - # Find the environment where .aggregations is stored, +find_arrow_mask <- function() { + # Find the arrow_mask environment by looking for .aggregations, # it's in parent.env of something in the call stack n <- 1 while (TRUE) { diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R index b9639f00295..3ab955aa8ae 100644 --- a/r/R/dplyr-funcs-conditional.R +++ b/r/R/dplyr-funcs-conditional.R @@ -37,7 +37,7 @@ register_bindings_conditional <- function() { register_binding("dplyr::coalesce", function(...) { args <- list2(...) if (length(args) < 1) { - abort("At least one argument must be supplied to coalesce()") + validation_error("At least one argument must be supplied to coalesce()") } # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all* @@ -102,7 +102,7 @@ register_bindings_conditional <- function() { formulas <- list2(...) n <- length(formulas) if (n == 0) { - abort("No cases provided in case_when()") + validation_error("No cases provided") } query <- vector("list", n) value <- vector("list", n) @@ -110,20 +110,17 @@ register_bindings_conditional <- function() { for (i in seq_len(n)) { f <- formulas[[i]] if (!inherits(f, "formula")) { - abort("Each argument to case_when() must be a two-sided formula") + validation_error("Each argument to case_when() must be a two-sided formula") } query[[i]] <- arrow_eval(f[[2]], mask) value[[i]] <- arrow_eval(f[[3]], mask) if (!call_binding("is.logical", query[[i]])) { - abort("Left side of each formula in case_when() must be a logical expression") - } - if (inherits(value[[i]], "try-error")) { - abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]]))) + validation_error("Left side of each formula in case_when() must be a logical expression") } } if (!is.null(.default)) { if (length(.default) != 1) { - abort(paste0("`.default` must have size 1, not size ", length(.default), ".")) + validation_error(paste0("`.default` must have size 1, not size ", length(.default), ".")) } query[n + 1] <- TRUE @@ -140,6 +137,5 @@ register_bindings_conditional <- function() { value ) ) - }, notes = "`.ptype` and `.size` arguments not supported" - ) + }, notes = "`.ptype` and `.size` arguments not supported") } diff --git a/r/R/dplyr-funcs-datetime.R b/r/R/dplyr-funcs-datetime.R index 440210afd63..5e6ac4a1035 100644 --- a/r/R/dplyr-funcs-datetime.R +++ b/r/R/dplyr-funcs-datetime.R @@ -121,7 +121,7 @@ register_bindings_datetime_utility <- function() { precision <- "ymdhms" } if (!precision %in% names(ISO8601_precision_map)) { - abort( + validation_error( paste( "`precision` must be one of the following values:", paste(names(ISO8601_precision_map), collapse = ", "), @@ -325,10 +325,10 @@ register_bindings_datetime_conversion <- function() { origin = "1970-01-01", tz = "UTC") { if (is.null(format) && length(tryFormats) > 1) { - abort( - paste( - "`as.Date()` with multiple `tryFormats` is not supported in Arrow.", - "Consider using the lubridate specialised parsing functions `ymd()`, `ymd()`, etc." + arrow_not_supported( + "`as.Date()` with multiple `tryFormats`", + body = c( + ">" = "Consider using the lubridate specialised parsing functions `ymd()`, `ymd()`, etc." ) ) } @@ -455,15 +455,13 @@ register_bindings_datetime_timezone <- function() { arrow_not_supported("`roll_dst` must be 1 or 2 items long; other lengths") } - nonexistent <- switch( - roll_dst[1], + nonexistent <- switch(roll_dst[1], "error" = 0L, "boundary" = 2L, arrow_not_supported("`roll_dst` value must be 'error' or 'boundary' for nonexistent times; other values") ) - ambiguous <- switch( - roll_dst[2], + ambiguous <- switch(roll_dst[2], "error" = 0L, "pre" = 1L, "post" = 2L, @@ -651,7 +649,7 @@ register_bindings_duration_helpers <- function() { register_binding( "lubridate::dpicoseconds", function(x = 1) { - abort("Duration in picoseconds not supported in Arrow.") + arrow_not_supported("Duration in picoseconds") }, notes = "not supported" ) diff --git a/r/R/dplyr-funcs-simple.R b/r/R/dplyr-funcs-simple.R index 308a46601a6..4ccc2498435 100644 --- a/r/R/dplyr-funcs-simple.R +++ b/r/R/dplyr-funcs-simple.R @@ -177,7 +177,7 @@ common_type <- function(exprs) { # * pmin/pmax return(first_type) } - stop("There is no common type in these expressions") + validation_error("There is no common type in these expressions") } cast_or_parse <- function(x, type) { diff --git a/r/R/dplyr-funcs-string.R b/r/R/dplyr-funcs-string.R index a21ce78edd1..77e1a5405a6 100644 --- a/r/R/dplyr-funcs-string.R +++ b/r/R/dplyr-funcs-string.R @@ -134,9 +134,9 @@ format_string_replacement <- function(replacement, ignore.case, fixed) { # Arrow locale will be supported with ARROW-14126 stop_if_locale_provided <- function(locale) { if (!identical(locale, "en")) { - stop("Providing a value for 'locale' other than the default ('en') is not supported in Arrow. ", - "To change locale, use 'Sys.setlocale()'", - call. = FALSE + arrow_not_supported( + "Providing a value for 'locale' other than the default ('en')", + body = c(">" = "To change locale, use 'Sys.setlocale()'") ) } } @@ -158,10 +158,11 @@ register_bindings_string_join <- function() { # handle scalar literal args, and cast all args to string for # consistency with base::paste(), base::paste0(), and stringr::str_c() if (!inherits(arg, "Expression")) { - assert_that( - length(arg) == 1, - msg = "Literal vectors of length != 1 not supported in string concatenation" - ) + if (length(arg) != 1) { + arrow_not_supported( + "Literal vectors of length != 1 in string concatenation" + ) + } Expression$scalar(as.character(arg)) } else { call_binding("as.character", arg) @@ -181,12 +182,11 @@ register_bindings_string_join <- function() { register_binding( "base::paste", function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste() with the collapse argument is not yet supported in Arrow" - ) - if (!inherits(sep, "Expression")) { - assert_that(!is.na(sep), msg = "Invalid separator") + if (!is.null(collapse)) { + arrow_not_supported("`collapse` argument") + } + if (!inherits(sep, "Expression") && is.na(sep)) { + validation_error("Invalid separator") } arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) }, @@ -196,10 +196,9 @@ register_bindings_string_join <- function() { register_binding( "base::paste0", function(..., collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste0() with the collapse argument is not yet supported in Arrow" - ) + if (!is.null(collapse)) { + arrow_not_supported("`collapse` argument") + } arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") }, notes = "the `collapse` argument is not yet supported" @@ -208,12 +207,11 @@ register_bindings_string_join <- function() { register_binding( "stringr::str_c", function(..., sep = "", collapse = NULL) { - assert_that( - is.null(collapse), - msg = "str_c() with the collapse argument is not yet supported in Arrow" - ) - if (!inherits(sep, "Expression")) { - assert_that(!is.na(sep), msg = "`sep` must be a single string, not `NA`.") + if (!is.null(collapse)) { + arrow_not_supported("`collapse` argument") + } + if (!inherits(sep, "Expression") && is.na(sep)) { + validation_error("`sep` must be a single string, not `NA`.") } arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) }, @@ -352,10 +350,10 @@ register_bindings_string_regex <- function() { arrow_r_string_replace_function <- function(max_replacements) { function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) { if (length(pattern) != 1) { - stop("`pattern` must be a length 1 character vector") + validation_error("`pattern` must be a length 1 character vector") } if (length(replacement) != 1) { - stop("`replacement` must be a length 1 character vector") + validation_error("`replacement` must be a length 1 character vector") } Expression$create( ifelse(fixed && !ignore.case, "replace_substring", "replace_substring_regex"), @@ -512,14 +510,12 @@ register_bindings_string_other <- function() { register_binding( "base::substr", function(x, start, stop) { - assert_that( - length(start) == 1, - msg = "`start` must be length 1 - other lengths are not supported in Arrow" - ) - assert_that( - length(stop) == 1, - msg = "`stop` must be length 1 - other lengths are not supported in Arrow" - ) + if (length(start) != 1) { + arrow_not_supported("`start` must be length 1 - other lengths") + } + if (length(stop) != 1) { + arrow_not_supported("`stop` must be length 1 - other lengths") + } # substr treats values as if they're on a continuous number line, so values # 0 are effectively blank characters - set `start` to 1 here so Arrow mimics @@ -561,14 +557,12 @@ register_bindings_string_other <- function() { }) register_binding("stringr::str_sub", function(string, start = 1L, end = -1L) { - assert_that( - length(start) == 1, - msg = "`start` must be length 1 - other lengths are not supported in Arrow" - ) - assert_that( - length(end) == 1, - msg = "`end` must be length 1 - other lengths are not supported in Arrow" - ) + if (length(start) != 1) { + arrow_not_supported("`start` must be length 1 - other lengths") + } + if (length(end) != 1) { + arrow_not_supported("`end` must be length 1 - other lengths") + } # In stringr::str_sub, an `end` value of -1 means the end of the string, so # set it to the maximum integer to match this behavior diff --git a/r/R/dplyr-funcs-type.R b/r/R/dplyr-funcs-type.R index efb3c6b756a..85c26ec05c8 100644 --- a/r/R/dplyr-funcs-type.R +++ b/r/R/dplyr-funcs-type.R @@ -105,7 +105,7 @@ register_bindings_type_cast <- function() { } else if (inherits(class2, "DataType")) { object$type() == as_type(class2) } else { - stop("Second argument to is() is not a string or DataType", call. = FALSE) + validation_error("Second argument to is() is not a string or DataType") } }) @@ -219,7 +219,10 @@ register_bindings_type_inspect <- function() { call_binding("is.character", x) }) register_binding("rlang::is_double", function(x, n = NULL, finite = NULL) { - assert_that(is.null(n) && is.null(finite)) + assert_that(is.null(n)) + if (!is.null(finite)) { + arrow_not_supported("`finite` argument") + } call_binding("is.double", x) }) register_binding("rlang::is_integer", function(x, n = NULL) { diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index f0a8c005676..fcb1cedbbb1 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -24,122 +24,116 @@ mutate.arrow_dplyr_query <- function(.data, .keep = c("all", "used", "unused", "none"), .before = NULL, .after = NULL) { - call <- match.call() - out <- as_adq(.data) + try_arrow_dplyr({ + out <- as_adq(.data) - by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") - if (by$from_by) { - out$group_by_vars <- by$names - } - grv <- out$group_by_vars - expression_list <- expand_across(out, quos(...), exclude_cols = grv) - exprs <- ensure_named_exprs(expression_list) + if (by$from_by) { + out$group_by_vars <- by$names + } + grv <- out$group_by_vars + expression_list <- expand_across(out, quos(...), exclude_cols = grv) + exprs <- ensure_named_exprs(expression_list) - .keep <- match.arg(.keep) - .before <- enquo(.before) - .after <- enquo(.after) + .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) - if (.keep %in% c("all", "unused") && length(exprs) == 0) { - # Nothing to do - return(out) - } + if (.keep %in% c("all", "unused") && length(exprs) == 0) { + # Nothing to do + return(out) + } - # Create a mask with aggregation functions in it - # If there are any aggregations, we will need to compute them and - # and join the results back in, for "window functions" like x - mean(x) - mask <- arrow_mask(out) - # Evaluate the mutate expressions - results <- list() - for (i in seq_along(exprs)) { - # Iterate over the indices and not the names because names may be repeated - # (which overwrites the previous name) - new_var <- names(exprs)[i] - results[[new_var]] <- arrow_eval(exprs[[i]], mask) - if (inherits(results[[new_var]], "try-error")) { - msg <- handle_arrow_not_supported( - results[[new_var]], - format_expr(exprs[[i]]) - ) - return(abandon_ship(call, .data, msg)) - } else if (!inherits(results[[new_var]], "Expression") && - !is.null(results[[new_var]])) { - # We need some wrapping to handle literal values - if (length(results[[new_var]]) != 1) { - msg <- paste0("In ", new_var, " = ", format_expr(exprs[[i]]), ", only values of size one are recycled") - return(abandon_ship(call, .data, msg)) + # Create a mask with aggregation functions in it + # If there are any aggregations, we will need to compute them and + # and join the results back in, for "window functions" like x - mean(x) + mask <- arrow_mask(out) + # Evaluate the mutate expressions + results <- list() + for (i in seq_along(exprs)) { + # Iterate over the indices and not the names because names may be repeated + # (which overwrites the previous name) + new_var <- names(exprs)[i] + results[[new_var]] <- arrow_eval(exprs[[i]], mask) + if (!inherits(results[[new_var]], "Expression") && + !is.null(results[[new_var]])) { + # We need some wrapping to handle literal values + if (length(results[[new_var]]) != 1) { + arrow_not_supported("Recycling values of length != 1", call = exprs[[i]]) + } + results[[new_var]] <- Expression$scalar(results[[new_var]]) } - results[[new_var]] <- Expression$scalar(results[[new_var]]) + # Put it in the data mask too + mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] } - # Put it in the data mask too - mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] - } - if (length(mask$.aggregations)) { - # Make a copy of .data, do the aggregations on it, and then left_join on - # the group_by variables. - agg_query <- as_adq(.data) - # These may be computed by .by, make sure they're set - agg_query$group_by_vars <- grv - agg_query$aggregations <- mask$.aggregations - agg_query <- collapse.arrow_dplyr_query(agg_query) - if (length(grv)) { - out <- dplyr::left_join(out, agg_query, by = grv) - } else { - # If there are no group_by vars, add a scalar column to both and join on that - agg_query$selected_columns[["..tempjoin"]] <- Expression$scalar(1L) - out$selected_columns[["..tempjoin"]] <- Expression$scalar(1L) - out <- dplyr::left_join(out, agg_query, by = "..tempjoin") + if (length(mask$.aggregations)) { + # Make a copy of .data, do the aggregations on it, and then left_join on + # the group_by variables. + agg_query <- as_adq(.data) + # These may be computed by .by, make sure they're set + agg_query$group_by_vars <- grv + agg_query$aggregations <- mask$.aggregations + agg_query <- collapse.arrow_dplyr_query(agg_query) + if (length(grv)) { + out <- left_join(out, agg_query, by = grv) + } else { + # If there are no group_by vars, add a scalar column to both and join on that + agg_query$selected_columns[["..tempjoin"]] <- Expression$scalar(1L) + out$selected_columns[["..tempjoin"]] <- Expression$scalar(1L) + out <- left_join(out, agg_query, by = "..tempjoin") + } } - } - old_vars <- names(out$selected_columns) - # Note that this is names(exprs) not names(results): - # if results$new_var is NULL, that means we are supposed to remove it - new_vars <- names(exprs) + old_vars <- names(out$selected_columns) + # Note that this is names(exprs) not names(results): + # if results$new_var is NULL, that means we are supposed to remove it + new_vars <- names(exprs) - # Assign the new columns into the out$selected_columns - for (new_var in new_vars) { - out$selected_columns[[new_var]] <- results[[new_var]] - } + # Assign the new columns into the out$selected_columns + for (new_var in new_vars) { + out$selected_columns[[new_var]] <- results[[new_var]] + } - # Prune any ..temp columns from the result, which would have come from - # .aggregations - temps <- grepl("^\\.\\.temp", names(out$selected_columns)) - out$selected_columns <- out$selected_columns[!temps] + # Prune any ..temp columns from the result, which would have come from + # .aggregations + temps <- grepl("^\\.\\.temp", names(out$selected_columns)) + out$selected_columns <- out$selected_columns[!temps] - # Deduplicate new_vars and remove NULL columns from new_vars - new_vars <- intersect(union(new_vars, grv), names(out$selected_columns)) + # Deduplicate new_vars and remove NULL columns from new_vars + new_vars <- intersect(union(new_vars, grv), names(out$selected_columns)) - # Respect .before and .after - if (!quo_is_null(.before) || !quo_is_null(.after)) { - new <- setdiff(new_vars, old_vars) - out <- dplyr::relocate(out, all_of(new), .before = !!.before, .after = !!.after) - } + # Respect .before and .after + if (!quo_is_null(.before) || !quo_is_null(.after)) { + new <- setdiff(new_vars, old_vars) + out <- dplyr::relocate(out, all_of(new), .before = !!.before, .after = !!.after) + } - # Respect .keep - if (.keep == "none") { - ## for consistency with dplyr, this appends new columns after existing columns - ## by specifying the order - new_cols_last <- c(intersect(old_vars, new_vars), setdiff(new_vars, old_vars)) - out$selected_columns <- out$selected_columns[new_cols_last] - } else if (.keep != "all") { - # "used" or "unused" - used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) - if (.keep == "used") { - out$selected_columns[setdiff(old_vars, used_vars)] <- NULL - } else { - # "unused" - out$selected_columns[intersect(old_vars, used_vars)] <- NULL + # Respect .keep + if (.keep == "none") { + ## for consistency with dplyr, this appends new columns after existing columns + ## by specifying the order + new_cols_last <- c(intersect(old_vars, new_vars), setdiff(new_vars, old_vars)) + out$selected_columns <- out$selected_columns[new_cols_last] + } else if (.keep != "all") { + # "used" or "unused" + used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) + if (.keep == "used") { + out$selected_columns[setdiff(old_vars, used_vars)] <- NULL + } else { + # "unused" + out$selected_columns[intersect(old_vars, used_vars)] <- NULL + } } - } - if (by$from_by) { - out$group_by_vars <- character() - } + if (by$from_by) { + out$group_by_vars <- character() + } - # Even if "none", we still keep group vars - ensure_group_vars(out) + # Even if "none", we still keep group vars + ensure_group_vars(out) + }) } mutate.Dataset <- mutate.ArrowTabular <- mutate.RecordBatchReader <- mutate.arrow_dplyr_query diff --git a/r/R/dplyr-slice.R b/r/R/dplyr-slice.R index bcb6547f7c8..2173d897f1f 100644 --- a/r/R/dplyr-slice.R +++ b/r/R/dplyr-slice.R @@ -148,7 +148,7 @@ prop_to_n <- function(.data, prop) { validate_prop <- function(prop) { if (!is.numeric(prop) || length(prop) != 1 || is.na(prop) || prop < 0 || prop > 1) { - stop("`prop` must be a single numeric value between 0 and 1", call. = FALSE) + validation_error("`prop` must be a single numeric value between 0 and 1") } } diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 58ca849152a..f4fda0f13aa 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -18,39 +18,18 @@ # The following S3 methods are registered on load if dplyr is present summarise.arrow_dplyr_query <- function(.data, ..., .by = NULL, .groups = NULL) { - call <- match.call() - out <- as_adq(.data) + try_arrow_dplyr({ + out <- as_adq(.data) - by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") - - if (by$from_by) { - out$group_by_vars <- by$names - .groups <- "drop" - } - - exprs <- expand_across(out, quos(...), exclude_cols = out$group_by_vars) - - # Only retain the columns we need to do our aggregations - vars_to_keep <- unique(c( - unlist(lapply(exprs, all.vars)), # vars referenced in summarise - dplyr::group_vars(out) # vars needed for grouping - )) - # If exprs rely on the results of previous exprs - # (total = sum(x), mean = total / n()) - # then not all vars will correspond to columns in the data, - # so don't try to select() them (use intersect() to exclude them) - # Note that this select() isn't useful for the Arrow summarize implementation - # because it will effectively project to keep what it needs anyway, - # but the data.frame fallback version does benefit from select here - out <- dplyr::select(out, intersect(vars_to_keep, names(out))) - - # Try stuff, if successful return() - out <- try(do_arrow_summarize(out, !!!exprs, .groups = .groups), silent = TRUE) - if (inherits(out, "try-error")) { - out <- abandon_ship(call, .data, format(out)) - } + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + if (by$from_by) { + out$group_by_vars <- by$names + .groups <- "drop" + } - out + exprs <- expand_across(out, quos(...), exclude_cols = out$group_by_vars) + do_arrow_summarize(out, !!!exprs, .groups = .groups) + }) } summarise.Dataset <- summarise.ArrowTabular <- summarise.RecordBatchReader <- summarise.arrow_dplyr_query @@ -120,11 +99,10 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { # the schema of the data after summarize(). Evaulating its type will # throw an error if it's invalid. tryCatch(post_mutate[[post]]$type(out$.data$schema), error = function(e) { - msg <- paste( - "Expression", as_label(exprs[[post]]), - "is not a valid aggregation expression or is" + arrow_not_supported( + "Expression is not a valid aggregation expression or is", + call = exprs[[post]] ) - arrow_not_supported(msg) }) # If it's valid, add it to the .data object out$selected_columns[[post]] <- post_mutate[[post]] @@ -166,12 +144,18 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { } else if (.groups == "keep") { out$group_by_vars <- .data$group_by_vars } else if (.groups == "rowwise") { - stop(arrow_not_supported('.groups = "rowwise"')) + arrow_not_supported( + '.groups = "rowwise"', + call = rlang::caller_call() + ) } else if (.groups == "drop") { # collapse() preserves groups so remove them out <- dplyr::ungroup(out) } else { - stop(paste("Invalid .groups argument:", .groups)) + validation_error( + paste("Invalid .groups argument:", .groups), + call = rlang::caller_call() + ) } out$drop_empty_groups <- .data$drop_empty_groups if (getOption("arrow.summarise.sort", FALSE)) { @@ -183,16 +167,6 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { out } -arrow_eval_or_stop <- function(expr, mask) { - # TODO: change arrow_eval error handling behavior? - out <- arrow_eval(expr, mask) - if (inherits(out, "try-error")) { - msg <- handle_arrow_not_supported(out, format_expr(expr)) - stop(msg, call. = FALSE) - } - out -} - # This function returns a list of expressions which is used to project the data # before an aggregation. This list includes the fields used in the aggregation # expressions (the "targets") and the group fields. The names of the returned @@ -271,7 +245,7 @@ summarize_eval <- function(name, quosure, mask) { mask[[n]] <- mask$.data[[n]] <- Expression$field_ref(n) } # Evaluate: - value <- arrow_eval_or_stop(quosure, mask) + value <- arrow_eval(quosure, mask) # Handle the result. There are a few different cases. if (!inherits(value, "Expression")) { diff --git a/r/R/dplyr.R b/r/R/dplyr.R index f11b88d301e..93fcfdef28f 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -338,22 +338,6 @@ ensure_arrange_vars <- function(x) { x } -# Helper to handle unsupported dplyr features -# * For Table/RecordBatch, we collect() and then call the dplyr method in R -# * For Dataset, we just error -abandon_ship <- function(call, .data, msg) { - msg <- trimws(msg) - dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]])) - if (query_on_dataset(.data)) { - stop(msg, "\nCall collect() first to pull data into R.", call. = FALSE) - } - # else, collect and call dplyr method - warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE) - call$.data <- dplyr::collect(.data) - call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr")) - eval.parent(call, 2) -} - query_on_dataset <- function(x) { any(map_lgl(all_sources(x), ~ inherits(., c("Dataset", "RecordBatchReader")))) } diff --git a/r/man/arrow_not_supported.Rd b/r/man/arrow_not_supported.Rd new file mode 100644 index 00000000000..be6a001fa1f --- /dev/null +++ b/r/man/arrow_not_supported.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dplyr-eval.R +\name{arrow_not_supported} +\alias{arrow_not_supported} +\alias{validation_error} +\title{Helpers to raise classed errors} +\usage{ +arrow_not_supported( + msg, + .actual_msg = paste(msg, "not supported in Arrow"), + ... +) + +validation_error(msg, ...) +} +\arguments{ +\item{msg}{The message to show. \code{arrow_not_supported()} will append +"not supported in Arrow" to this message.} + +\item{.actual_msg}{If you don't want to append "not supported in Arrow" to +the message, you can provide the full message here.} + +\item{...}{Additional arguments to pass to \code{rlang::abort()}. Useful arguments +include \code{call} to provide the call or expression that caused the error, and +\code{body} to provide additional context about the error.} +} +\description{ +\code{arrow_not_supported()} and \code{validation_error()} raise classed errors that +allow us to distinguish between things that are not supported in Arrow and +things that are just invalid input. Additional wrapping in \code{arrow_eval()} +and \code{try_arrow_dplyr()} provide more context and suggestions. +Importantly, if \code{arrow_not_supported} is raised, then retrying the same code +in regular dplyr in R may work. But if \code{validation_error} is raised, then we +shouldn't recommend retrying with regular dplyr because it will fail there +too. +} +\details{ +Use these in function bindings and in the dplyr methods. Inside of function +bindings, you don't need to provide the \code{call} argument, as it will be +automatically filled in with the expression that caused the error in +\code{arrow_eval()}. In dplyr methods, you should provide the \code{call} argument; +\code{rlang::caller_call()} often is correct, but you may need to experiment to +find how far up the call stack you need to look. + +You may provide additional information in the \code{body} argument, a named +character vector. Use \code{i} for additional information about the error and \code{>} +to indicate potential solutions or workarounds that don't require pulling the +data into R. If you have an \code{arrow_not_supported()} error with a \code{>} +suggestion, when the error is ultimately raised by \code{try_error_dplyr()}, +\verb{Call collect() first to pull data into R} won't be the only suggestion. + +You can still use \code{match.arg()} and \code{assert_that()} for simple input +validation inside of the function bindings. \code{arrow_eval()} will catch their +errors and re-raise them as \code{validation_error}. +} +\keyword{internal} diff --git a/r/tests/testthat/_snaps/dataset-dplyr.md b/r/tests/testthat/_snaps/dataset-dplyr.md new file mode 100644 index 00000000000..a2d9820a4e7 --- /dev/null +++ b/r/tests/testthat/_snaps/dataset-dplyr.md @@ -0,0 +1,9 @@ +# dplyr method not implemented messages + + Code + ds %>% filter(int > 6, dbl > max(dbl)) + Condition + Error in `dbl > max(dbl)`: + ! Expression not supported in filter() in Arrow + > Call collect() first to pull data into R. + diff --git a/r/tests/testthat/_snaps/dplyr-across.md b/r/tests/testthat/_snaps/dplyr-across.md new file mode 100644 index 00000000000..47b5bd61b39 --- /dev/null +++ b/r/tests/testthat/_snaps/dplyr-across.md @@ -0,0 +1,11 @@ +# expand_across correctly expands quosures + + Code + InMemoryDataset$create(example_data) %>% mutate(across(c(dbl, dbl2), round, + digits = -1)) + Condition + Error in `mutate.Dataset()`: + ! `...` argument to `across()` is deprecated in dplyr and not supported in Arrow + > Convert your call into a function or formula including the arguments + > Or, call collect() first to pull data into R. + diff --git a/r/tests/testthat/_snaps/dplyr-eval.md b/r/tests/testthat/_snaps/dplyr-eval.md new file mode 100644 index 00000000000..0b4639f1fe7 --- /dev/null +++ b/r/tests/testthat/_snaps/dplyr-eval.md @@ -0,0 +1,27 @@ +# try_arrow_dplyr/abandon_ship adds the right message about collect() + + Code + tester(ds, i) + Condition + Error in `validation_error()`: + ! arg is 0 + +--- + + Code + tester(ds, i) + Condition + Error in `arrow_not_supported()`: + ! arg == 1 not supported in Arrow + > Call collect() first to pull data into R. + +--- + + Code + tester(ds, i) + Condition + Error in `arrow_not_supported()`: + ! arg greater than 0 not supported in Arrow + > Try setting arg to -1 + > Or, call collect() first to pull data into R. + diff --git a/r/tests/testthat/_snaps/dplyr-funcs-datetime.md b/r/tests/testthat/_snaps/dplyr-funcs-datetime.md new file mode 100644 index 00000000000..036c8b49e80 --- /dev/null +++ b/r/tests/testthat/_snaps/dplyr-funcs-datetime.md @@ -0,0 +1,11 @@ +# `as.Date()` and `as_date()` + + Code + test_df %>% InMemoryDataset$create() %>% transmute(date_char_ymd = as.Date( + character_ymd_var, tryFormats = c("%Y-%m-%d", "%Y/%m/%d"))) %>% collect() + Condition + Error in `as.Date()`: + ! `as.Date()` with multiple `tryFormats` not supported in Arrow + > Consider using the lubridate specialised parsing functions `ymd()`, `ymd()`, etc. + > Or, call collect() first to pull data into R. + diff --git a/r/tests/testthat/_snaps/dplyr-mutate.md b/r/tests/testthat/_snaps/dplyr-mutate.md new file mode 100644 index 00000000000..a5bbc0163bc --- /dev/null +++ b/r/tests/testthat/_snaps/dplyr-mutate.md @@ -0,0 +1,25 @@ +# transmute() defuses dots arguments (ARROW-13262) + + Code + tbl %>% Table$create() %>% transmute(a = stringr::str_c(padded_strings, + padded_strings), b = stringr::str_squish(a)) %>% collect() + Condition + Warning: + In stringr::str_squish(a): + i Expression not supported in Arrow + > Pulling data into R + Output + # A tibble: 10 x 2 + a b + + 1 " a a " a a + 2 " b b " b b + 3 " c c " c c + 4 " d d " d d + 5 " e e " e e + 6 " f f " f f + 7 " g g " g g + 8 " h h " h h + 9 " i i " i i + 10 " j j " j j + diff --git a/r/tests/testthat/_snaps/dplyr-query.md b/r/tests/testthat/_snaps/dplyr-query.md index a9d4da26cca..cf5ac594acb 100644 --- a/r/tests/testthat/_snaps/dplyr-query.md +++ b/r/tests/testthat/_snaps/dplyr-query.md @@ -1,4 +1,6 @@ # Scalars in expressions match the type of the field, if possible - Expression int == "5" not supported in Arrow; pulling data into R + In int == "5": + i Expression not supported in Arrow + > Pulling data into R diff --git a/r/tests/testthat/_snaps/dplyr-summarize.md b/r/tests/testthat/_snaps/dplyr-summarize.md index bbb8e64bfe7..449a194d68f 100644 --- a/r/tests/testthat/_snaps/dplyr-summarize.md +++ b/r/tests/testthat/_snaps/dplyr-summarize.md @@ -3,11 +3,44 @@ Code InMemoryDataset$create(tbl) %>% summarize(distinct = n_distinct()) Condition - Error: - ! Error : In n_distinct(), n_distinct() with 0 arguments not supported in Arrow - Call collect() first to pull data into R. + Error in `n_distinct()`: + ! n_distinct() with 0 arguments not supported in Arrow + > Call collect() first to pull data into R. --- - Error : In n_distinct(int, lgl), Multiple arguments to n_distinct() not supported in Arrow; pulling data into R + In n_distinct(int, lgl): + i Multiple arguments to n_distinct() not supported in Arrow + > Pulling data into R + +# Expressions on aggregations + + Code + record_batch(tbl) %>% summarise(any(any(lgl))) + Condition + Warning: + In any(any(lgl)): + i aggregate within aggregate expression not supported in Arrow + > Pulling data into R + Output + # A tibble: 1 x 1 + `any(any(lgl))` + + 1 TRUE + +# Can use across() within summarise() + + Code + data.frame(x = 1, y = 2) %>% arrow_table() %>% group_by(x) %>% summarise(across( + everything())) %>% collect() + Condition + Warning: + In y: + i Expression is not a valid aggregation expression or is not supported in Arrow + > Pulling data into R + Output + # A tibble: 1 x 2 + x y + + 1 1 2 diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index 090ed36aa7f..63d0163aa31 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -88,7 +88,7 @@ compare_dplyr_binding <- function(expr, tbl, warning = NA, ...) { if (isTRUE(warning)) { # Special-case the simple warning: - warning <- "not supported in Arrow; pulling data into R" + warning <- "> Pulling data into R" } # Evaluate `expr` on a Table object and compare with `expected` @@ -289,3 +289,8 @@ split_vector_as_list <- function(vec) { expect_across_equal <- function(across_expr, expected, tbl) { expect_identical(expand_across(as_adq(tbl), across_expr), new_quosures(expected)) } + +expect_arrow_eval_error <- function(expr, ..., .data = example_data) { + mask <- arrow_mask(as_adq(.data)) + expect_error(arrow_eval({{ expr }}, mask), ...) +} diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index 493eac328e5..d5c8dc9820a 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -323,10 +323,9 @@ test_that("head/tail on query on dataset", { test_that("dplyr method not implemented messages", { ds <- open_dataset(dataset_dir) # This one is more nuanced - expect_error( + expect_snapshot( ds %>% filter(int > 6, dbl > max(dbl)), - "Expression dbl > max(dbl) not supported in filter() in Arrow\nCall collect() first to pull data into R.", - fixed = TRUE + error = TRUE ) }) diff --git a/r/tests/testthat/test-dplyr-across.R b/r/tests/testthat/test-dplyr-across.R index 32476bab06f..cfdad9a1f4c 100644 --- a/r/tests/testthat/test-dplyr-across.R +++ b/r/tests/testthat/test-dplyr-across.R @@ -117,13 +117,11 @@ test_that("expand_across correctly expands quosures", { ) # ellipses (...) are a deprecated argument - expect_error( - expand_across( - as_adq(example_data), - quos(across(c(dbl, dbl2), round, digits = -1)) - ), - regexp = "`...` argument to `across()` is deprecated in dplyr and not supported in Arrow", - fixed = TRUE + # abandon_ship message offers multiple suggestions + expect_snapshot( + InMemoryDataset$create(example_data) %>% + mutate(across(c(dbl, dbl2), round, digits = -1)), + error = TRUE ) # alternative ways of specifying .fns - as a list diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R index f50fa8945db..f658c531e78 100644 --- a/r/tests/testthat/test-dplyr-collapse.R +++ b/r/tests/testthat/test-dplyr-collapse.R @@ -168,19 +168,6 @@ total: int64 extra: int64 (multiply_checked(total, 5)) * Sorted by lgl [asc] -See $.data for the source Arrow object", - fixed = TRUE - ) - expect_output( - print(q$.data), - "Table (query) -int: int32 -lgl: bool - -* Aggregations: -total: sum(int) -* Filter: (dbl > 2) -* Grouped by lgl See $.data for the source Arrow object", fixed = TRUE ) diff --git a/r/tests/testthat/test-dplyr-eval.R b/r/tests/testthat/test-dplyr-eval.R new file mode 100644 index 00000000000..16c56f28cdb --- /dev/null +++ b/r/tests/testthat/test-dplyr-eval.R @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +test_that("various paths in arrow_eval", { + expect_arrow_eval_error( + assert_is(1, "character"), + class = "validation_error" + ) + expect_arrow_eval_error( + NoTaVaRiAbLe, + class = "validation_error" + ) + expect_arrow_eval_error( + match.arg("z", c("a", "b")), + class = "validation_error" + ) + expect_arrow_eval_error( + stop("something something NotImplementedError"), + class = "arrow_not_supported" + ) +}) + +test_that("try_arrow_dplyr/abandon_ship adds the right message about collect()", { + tester <- function(.data, arg) { + try_arrow_dplyr({ + if (arg == 0) { + # This one just stops and doesn't recommend calling collect() + validation_error("arg is 0") + } else if (arg == 1) { + # This one recommends calling collect() + arrow_not_supported("arg == 1") + } else { + # Because this one has an alternative suggested, it adds "Or, collect()" + arrow_not_supported( + "arg greater than 0", + body = c(">" = "Try setting arg to -1") + ) + } + }) + } + + ds <- InMemoryDataset$create(arrow_table(x = 1)) + for (i in 0:2) { + expect_snapshot(tester(ds, i), error = TRUE) + } +}) diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index 535bcb70c4c..ba086133dca 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -317,7 +317,12 @@ test_that("Filtering with unsupported functions", { filter(int > 2, pnorm(dbl) > .99) %>% collect(), tbl, - warning = "Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow; pulling data into R" + warning = paste( + "In pnorm\\(dbl\\) > 0.99: ", + "i Expression not supported in Arrow", + "> Pulling data into R", + sep = "\n" + ) ) compare_dplyr_binding( .input %>% @@ -329,8 +334,10 @@ test_that("Filtering with unsupported functions", { collect(), tbl, warning = paste( - 'In nchar\\(chr, type = "bytes", allowNA = TRUE\\) == 1,', - "allowNA = TRUE not supported in Arrow; pulling data into R" + 'In nchar\\(chr, type = "bytes", allowNA = TRUE\\) == 1: ', + "i allowNA = TRUE not supported in Arrow", + "> Pulling data into R", + sep = "\n" ) ) }) @@ -468,7 +475,12 @@ test_that(".by argument", { filter(int > 2, pnorm(dbl) > .99, .by = chr) %>% collect(), tbl, - warning = "Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow; pulling data into R" + warning = paste( + "In pnorm\\(dbl\\) > 0.99: ", + "i Expression not supported in Arrow", + "> Pulling data into R", + sep = "\n" + ) ) expect_error( tbl %>% diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 3ea1853fec4..d90dc827b40 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -248,75 +248,50 @@ test_that("case_when()", { ) ) - # expected errors (which are caught by abandon_ship() and changed to warnings) - # TODO: Find a way to test these directly without abandon_ship() interfering - expect_error( - # no cases - expect_warning( - tbl %>% - Table$create() %>% - transmute(cw = case_when()), - "case_when" - ) - ) - expect_error( - # argument not a formula - expect_warning( - tbl %>% - Table$create() %>% - transmute(cw = case_when(TRUE ~ FALSE, TRUE)), - "case_when" - ) - ) - expect_error( - # non-logical R scalar on left side of formula - expect_warning( - tbl %>% - Table$create() %>% - transmute(cw = case_when(0L ~ FALSE, TRUE ~ FALSE)), - "case_when" - ) - ) - expect_error( + # validation errors + expect_arrow_eval_error( + case_when(), + "No cases provided", + class = "validation_error" + ) + expect_arrow_eval_error( + case_when(TRUE ~ FALSE, TRUE), + "Each argument to case_when\\(\\) must be a two-sided formula", + class = "validation_error" + ) + expect_arrow_eval_error( + case_when(0L ~ FALSE, TRUE ~ FALSE), + "Left side of each formula in case_when\\(\\) must be a logical expression", + class = "validation_error" + ) + expect_arrow_eval_error( # non-logical Arrow column reference on left side of formula - expect_warning( - tbl %>% - Table$create() %>% - transmute(cw = case_when(int ~ FALSE)), - "case_when" - ) + case_when(int ~ FALSE), + "Left side of each formula in case_when\\(\\) must be a logical expression", + class = "validation_error" ) - expect_error( - # non-logical Arrow expression on left side of formula - expect_warning( - tbl %>% - Table$create() %>% - transmute(cw = case_when(dbl + 3.14159 ~ TRUE)), - "case_when" - ) + expect_arrow_eval_error( + # non-logical Arrow column reference on left side of formula + case_when(dbl + 3.14159 ~ TRUE), + "Left side of each formula in case_when\\(\\) must be a logical expression", + class = "validation_error" ) - - expect_error( - expect_warning( - tbl %>% - arrow_table() %>% - mutate(cw = case_when(int > 5 ~ 1, .default = c(0, 1))) - ), - "`.default` must have size" + expect_arrow_eval_error( + case_when(int > 5 ~ 1, .default = c(0, 1)), + "`.default` must have size 1, not size 2", + class = "validation_error" ) - expect_warning( - tbl %>% - arrow_table() %>% - mutate(cw = case_when(int > 5 ~ 1, .ptype = integer())), - "not supported in Arrow" + expect_arrow_eval_error( + case_when(int > 5 ~ 1, .ptype = integer()), + "`case_when\\(\\)` with `.ptype` specified not supported in Arrow", + class = "arrow_not_supported" ) - expect_warning( - tbl %>% - arrow_table() %>% - mutate(cw = case_when(int > 5 ~ 1, .size = 10)), - "not supported in Arrow" + expect_arrow_eval_error( + case_when(int > 5 ~ 1, .size = 10), + "`case_when\\(\\)` with `.size` specified not supported in Arrow", + class = "arrow_not_supported" ) compare_dplyr_binding( @@ -500,9 +475,9 @@ test_that("coalesce()", { ) # no arguments - expect_error( - call_binding("coalesce"), - "At least one argument must be supplied to coalesce()", - fixed = TRUE + expect_arrow_eval_error( + coalesce(), + "At least one argument must be supplied to coalesce\\(\\)", + class = "validation_error" ) }) diff --git a/r/tests/testthat/test-dplyr-funcs-datetime.R b/r/tests/testthat/test-dplyr-funcs-datetime.R index 6f520f6e322..0e4d2f3656a 100644 --- a/r/tests/testthat/test-dplyr-funcs-datetime.R +++ b/r/tests/testthat/test-dplyr-funcs-datetime.R @@ -1886,34 +1886,18 @@ test_that("`as.Date()` and `as_date()`", { ) # we do not support multiple tryFormats - # this is not a simple warning, therefore we cannot use compare_dplyr_binding() - # with `warning = TRUE` - # arrow_table test - expect_warning( - test_df %>% - arrow_table() %>% - mutate( - date_char_ymd = as.Date( - character_ymd_var, - tryFormats = c("%Y-%m-%d", "%Y/%m/%d") - ) - ) %>% - collect(), - regexp = "Consider using the lubridate specialised parsing functions" - ) - - # record batch test - expect_warning( + # Use a dataset to test the alternative suggestion message + expect_snapshot( test_df %>% - record_batch() %>% - mutate( + InMemoryDataset$create() %>% + transmute( date_char_ymd = as.Date( character_ymd_var, tryFormats = c("%Y-%m-%d", "%Y/%m/%d") ) ) %>% collect(), - regexp = "Consider using the lubridate specialised parsing functions" + error = TRUE ) # strptime does not support a partial format - Arrow returns NA, while @@ -3126,11 +3110,9 @@ test_that("timestamp round/floor/ceiling works for a minimal test", { }) test_that("timestamp round/floor/ceiling accepts period unit abbreviation", { - # test helper to ensure standard abbreviations of period names # are understood by arrow and mirror the lubridate behaviour check_period_abbreviation <- function(unit, synonyms) { - # check arrow against lubridate compare_dplyr_binding( .input %>% @@ -3255,7 +3237,6 @@ test_that("timestamp round/floor/ceil works for units: month/quarter/year", { # check helper invoked when we need to avoid the lubridate rounding bug check_date_rounding_1051_bypass <- function(data, unit, ignore_attr = TRUE, ...) { - # directly compare arrow to lubridate for floor and ceiling compare_dplyr_binding( .input %>% @@ -3288,7 +3269,6 @@ check_date_rounding_1051_bypass <- function(data, unit, ignore_attr = TRUE, ...) } test_that("date round/floor/ceil works for units: month/quarter/year", { - # these test cases are affected by lubridate issue 1051 so we bypass # lubridate::round_date() for Date objects with large rounding units # https://github.com/tidyverse/lubridate/issues/1051 @@ -3348,7 +3328,6 @@ test_that("timestamp round/floor/ceil works for week units (non-standard week_st }) check_date_week_rounding <- function(data, week_start, ignore_attr = TRUE, ...) { - # directly compare arrow to lubridate for floor and ceiling compare_dplyr_binding( .input %>% @@ -3395,7 +3374,6 @@ test_that("date round/floor/ceil works for week units (non-standard week_start)" # ceiling_date behaves identically to the lubridate version. It takes # unit as an argument to run tests separately for different rounding units check_boundary_with_unit <- function(unit, ...) { - # timestamps compare_dplyr_binding( .input %>% @@ -3464,7 +3442,6 @@ test_that("temporal round/floor/ceil period unit maxima are enforced", { # results. this test helper runs that test, skipping cases where lubridate # produces incorrect answers check_timezone_rounding_vs_lubridate <- function(data, unit) { - # esoteric lubridate bug: on windows and macOS (not linux), lubridate returns # incorrect ceiling/floor for timezoned POSIXct times (syd, adl, kat zones, # but not mar) but not utc, and not for round, and only for these two @@ -3702,8 +3679,8 @@ test_that("with_tz() and force_tz() works", { mutate(timestamps = force_tz( timestamps, "Europe/Brussels", - roll_dst = "post") - ) %>% + roll_dst = "post" + )) %>% collect(), "roll_dst` value must be 'error' or 'boundary' for nonexistent times" ) @@ -3712,11 +3689,10 @@ test_that("with_tz() and force_tz() works", { tibble::tibble(timestamps = nonexistent) %>% arrow_table() %>% mutate(timestamps = force_tz( - timestamps, - "Europe/Brussels", - roll_dst = c("boundary", "NA") - ) - ) %>% + timestamps, + "Europe/Brussels", + roll_dst = c("boundary", "NA") + )) %>% collect(), "`roll_dst` value must be 'error', 'pre', or 'post' for nonexistent times" ) diff --git a/r/tests/testthat/test-dplyr-funcs-string.R b/r/tests/testthat/test-dplyr-funcs-string.R index 039220b88ee..cb1d4675058 100644 --- a/r/tests/testthat/test-dplyr-funcs-string.R +++ b/r/tests/testthat/test-dplyr-funcs-string.R @@ -172,27 +172,31 @@ test_that("paste, paste0, and str_c", { # expected errors # collapse argument not supported - expect_error( - call_binding("paste", x, y, collapse = ""), - "collapse" + expect_arrow_eval_error( + paste(chr, int, collapse = ""), + "`collapse` argument not supported in Arrow", + class = "arrow_not_supported" ) - expect_error( - call_binding("paste0", x, y, collapse = ""), - "collapse" + expect_arrow_eval_error( + paste0(chr, int, collapse = ""), + "`collapse` argument not supported in Arrow", + class = "arrow_not_supported" ) - expect_error( - call_binding("str_c", x, y, collapse = ""), - "collapse" + expect_arrow_eval_error( + str_c(chr, int, collapse = ""), + "`collapse` argument not supported in Arrow", + class = "arrow_not_supported" ) - # literal vectors of length != 1 not supported - expect_error( - call_binding("paste", x, character(0), y), - "Literal vectors of length != 1 not supported in string concatenation" + expect_arrow_eval_error( + paste(chr, character(0), int), + "Literal vectors of length != 1 in string concatenation not supported in Arrow", + class = "arrow_not_supported" ) - expect_error( - call_binding("paste", x, c(",", ";"), y), - "Literal vectors of length != 1 not supported in string concatenation" + expect_arrow_eval_error( + paste(chr, c(",", ";"), int), + "Literal vectors of length != 1 in string concatenation not supported in Arrow", + class = "arrow_not_supported" ) }) @@ -602,10 +606,15 @@ test_that("str_to_lower, str_to_upper, and str_to_title", { ) # Error checking a single function because they all use the same code path. - expect_error( - call_binding("str_to_lower", "Apache Arrow", locale = "sp"), - "Providing a value for 'locale' other than the default ('en') is not supported in Arrow", - fixed = TRUE + expect_arrow_eval_error( + str_to_lower("Apache Arrow", locale = "sp"), + paste( + "Providing a value for 'locale' other than the default ('en') not supported in Arrow", + "> To change locale, use 'Sys.setlocale()'", + sep = "\n" + ), + fixed = TRUE, + class = "arrow_not_supported" ) }) @@ -1041,14 +1050,15 @@ test_that("substr with string()", { df ) - expect_error( - call_binding("substr", "Apache Arrow", c(1, 2), 3), - "`start` must be length 1 - other lengths are not supported in Arrow" + expect_arrow_eval_error( + substr("Apache Arrow", c(1, 2), 3), + "`start` must be length 1 - other lengths not supported in Arrow", + class = "arrow_not_supported" ) - - expect_error( - call_binding("substr", "Apache Arrow", 1, c(2, 3)), - "`stop` must be length 1 - other lengths are not supported in Arrow" + expect_arrow_eval_error( + substr("Apache Arrow", 1, c(2, 3)), + "`stop` must be length 1 - other lengths not supported in Arrow", + class = "arrow_not_supported" ) }) @@ -1169,14 +1179,15 @@ test_that("str_sub", { df ) - expect_error( - call_binding("str_sub", "Apache Arrow", c(1, 2), 3), - "`start` must be length 1 - other lengths are not supported in Arrow" + expect_arrow_eval_error( + str_sub("Apache Arrow", c(1, 2), 3), + "`start` must be length 1 - other lengths not supported in Arrow", + class = "arrow_not_supported" ) - - expect_error( - call_binding("str_sub", "Apache Arrow", 1, c(2, 3)), - "`end` must be length 1 - other lengths are not supported in Arrow" + expect_arrow_eval_error( + str_sub("Apache Arrow", 1, c(2, 3)), + "`end` must be length 1 - other lengths not supported in Arrow", + class = "arrow_not_supported" ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 71c1e52d33c..fa13c151b14 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -152,16 +152,14 @@ test_that("transmute() with unsupported arguments", { }) test_that("transmute() defuses dots arguments (ARROW-13262)", { - expect_warning( + expect_snapshot( tbl %>% Table$create() %>% transmute( a = stringr::str_c(padded_strings, padded_strings), b = stringr::str_squish(a) ) %>% - collect(), - "Expression stringr::str_squish(a) not supported in Arrow; pulling data into R", - fixed = TRUE + collect() ) }) @@ -202,10 +200,7 @@ test_that("nchar() arguments", { filter(line_lengths > 15) %>% collect(), tbl, - warning = paste0( - "In nchar\\(verses, type = \"bytes\", allowNA = TRUE\\), ", - "allowNA = TRUE not supported in Arrow; pulling data into R" - ) + warning = "allowNA = TRUE not supported in Arrow" ) }) @@ -538,7 +533,7 @@ test_that("Can't just add a vector column with mutate()", { mutate(again = 1:10), tibble::tibble(int = tbl$int, again = 1:10) ), - "In again = 1:10, only values of size one are recycled; pulling data into R" + "Recycling values of length != 1 not supported in Arrow" ) }) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index a61ef95bee7..95212407acf 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -832,28 +832,18 @@ test_that("Expressions on aggregations", { ) # Aggregates on aggregates are not supported - expect_warning( - record_batch(tbl) %>% summarise(any(any(lgl))), - paste( - "In any\\(any\\(lgl\\)\\), aggregate within aggregate expression", - "not supported in Arrow" - ) + expect_snapshot( + record_batch(tbl) %>% summarise(any(any(lgl))) ) # Check aggregates on aggregates with more complex calls expect_warning( record_batch(tbl) %>% summarise(any(any(!lgl))), - paste( - "In any\\(any\\(!lgl\\)\\), aggregate within aggregate expression", - "not supported in Arrow" - ) + "aggregate within aggregate expression not supported in Arrow" ) expect_warning( record_batch(tbl) %>% summarise(!any(any(lgl))), - paste( - "In \\!any\\(any\\(lgl\\)\\), aggregate within aggregate expression", - "not supported in Arrow" - ) + "aggregate within aggregate expression not supported in Arrow" ) }) @@ -965,7 +955,7 @@ test_that("Summarize with 0 arguments", { ) }) -test_that("Not (yet) supported: window functions", { +test_that("Not supported: window functions", { compare_dplyr_binding( .input %>% group_by(some_grouping) %>% @@ -974,10 +964,7 @@ test_that("Not (yet) supported: window functions", { ) %>% collect(), tbl, - warning = paste( - "In sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\), aggregate within", - "aggregate expression not supported in Arrow; pulling data into R" - ) + warning = "aggregate within aggregate expression not supported in Arrow" ) compare_dplyr_binding( .input %>% @@ -987,10 +974,7 @@ test_that("Not (yet) supported: window functions", { ) %>% collect(), tbl, - warning = paste( - "In sum\\(dbl - mean\\(dbl\\)\\), aggregate within aggregate expression", - "not supported in Arrow; pulling data into R" - ) + warning = "aggregate within aggregate expression not supported in Arrow" ) compare_dplyr_binding( .input %>% @@ -1000,10 +984,7 @@ test_that("Not (yet) supported: window functions", { ) %>% collect(), tbl, - warning = paste( - "In sqrt\\(sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\)/\\(n\\(\\) - 1L\\)\\), aggregate within", - "aggregate expression not supported in Arrow; pulling data into R" - ) + warning = "aggregate within aggregate expression not supported in Arrow" ) compare_dplyr_binding( @@ -1012,10 +993,7 @@ test_that("Not (yet) supported: window functions", { summarize(y - mean(y)) %>% collect(), data.frame(x = 1, y = 2), - warning = paste( - "Expression y - mean\\(y\\) is not a valid aggregation expression", - "or is not supported in Arrow; pulling data into R" - ) + warning = "Expression is not a valid aggregation expression or is not supported in Arrow" ) compare_dplyr_binding( @@ -1024,10 +1002,7 @@ test_that("Not (yet) supported: window functions", { summarize(y) %>% collect(), data.frame(x = 1, y = 2), - warning = paste( - "Expression y is not a valid aggregation expression", - "or is not supported in Arrow; pulling data into R" - ) + warning = "Expression is not a valid aggregation expression or is not supported in Arrow" ) # This one could possibly be supported--in mutate() @@ -1037,10 +1012,7 @@ test_that("Not (yet) supported: window functions", { summarize(x - y) %>% collect(), data.frame(x = 1, y = 2, z = 3), - warning = paste( - "Expression x - y is not a valid aggregation expression", - "or is not supported in Arrow; pulling data into R" - ) + warning = "Expression is not a valid aggregation expression or is not supported in Arrow" ) }) @@ -1274,13 +1246,12 @@ test_that("Can use across() within summarise()", { ) # across() doesn't work in summarise when input expressions evaluate to bare field references - expect_warning( + expect_snapshot( data.frame(x = 1, y = 2) %>% arrow_table() %>% group_by(x) %>% summarise(across(everything())) %>% - collect(), - regexp = "Expression y is not a valid aggregation expression or is not supported in Arrow; pulling data into R" + collect() ) }) diff --git a/r/vignettes/developers/matchsubstringoptions.png b/r/vignettes/developers/matchsubstringoptions.png deleted file mode 100644 index 2dff3c5858e..00000000000 Binary files a/r/vignettes/developers/matchsubstringoptions.png and /dev/null differ diff --git a/r/vignettes/developers/starts_with_docs.png b/r/vignettes/developers/starts_with_docs.png deleted file mode 100644 index a55e888128f..00000000000 Binary files a/r/vignettes/developers/starts_with_docs.png and /dev/null differ diff --git a/r/vignettes/developers/startswithdocs.png b/r/vignettes/developers/startswithdocs.png deleted file mode 100644 index 6e1f3df1b9b..00000000000 Binary files a/r/vignettes/developers/startswithdocs.png and /dev/null differ diff --git a/r/vignettes/developers/writing_bindings.Rmd b/r/vignettes/developers/writing_bindings.Rmd deleted file mode 100644 index e1ed92105db..00000000000 --- a/r/vignettes/developers/writing_bindings.Rmd +++ /dev/null @@ -1,253 +0,0 @@ ---- -title: "Writing dplyr bindings" -description: > - Learn how to write bindings that allow arrow to mirror the behavior - of native R functions within dplyr pipelines -output: rmarkdown::html_vignette ---- - -```{r, include=FALSE} -library(arrow, warn.conflicts = FALSE) -library(dplyr, warn.conflicts = FALSE) -``` - -When writing bindings between C++ compute functions and R functions, the aim is -to expose the C++ functionality via the same interface as existing R functions. The syntax and -functionality should match that of the existing R functions -(though there are some exceptions) so that users are able to use existing tidyverse -or base R syntax, whilst taking advantage of the speed and functionality of the -underlying arrow package. - -One of main ways in which users interact with arrow is via -[dplyr](https://dplyr.tidyverse.org/) syntax called on Arrow objects. For -example, when a user calls `dplyr::mutate()` on an Arrow Tabular, -Dataset, or arrow data query object, the Arrow implementation of `mutate()` is -used and under the hood, translates the dplyr code into Arrow C++ code. - -When using `dplyr::mutate()` or `dplyr::filter()`, you may want to use functions -from other packages. The example below uses `stringr::str_detect()`. - -```{r} -library(dplyr) -library(stringr) -starwars %>% - filter(str_detect(name, "Darth")) -``` -This functionality has also been implemented in Arrow, e.g.: - -```{r} -library(arrow) -arrow_table(starwars) %>% - filter(str_detect(name, "Darth")) %>% - collect() -``` - -This is possible as a **binding** has been created between the call to the -stringr function `str_detect()` and the Arrow C++ code, here as a direct mapping -to `match_substring_regex`. You can see this for yourself by inspecting the -arrow data query object without retrieving the results via `collect()`. - - -```{r} -arrow_table(starwars) %>% - filter(str_detect(name, "Darth")) -``` - -In the following sections, we'll walk through how to create a binding between an -R function and an Arrow C++ function. - -# Walkthrough - -Imagine you are writing the bindings for the C++ function -[`starts_with()`](https://arrow.apache.org/docs/cpp/compute.html#containment-tests) -and want to bind it to the (base) R function `startsWith()`. - -First, take a look at the docs for both of those functions. - -## Examining the R function - -Here are the docs for R's `startsWith()` (also available at https://stat.ethz.ch/R-manual/R-devel/library/base/html/startsWith.html) - -```{r, echo=FALSE, out.width="50%"} -knitr::include_graphics("./startswithdocs.png") -``` - -It takes 2 parameters; `x` - the input, and `prefix` - the characters to check -if `x` starts with. - -## Examining the C++ function - -Now, go to -[the compute function documentation](https://arrow.apache.org/docs/cpp/compute.html#containment-tests) -and look for the Arrow C++ library's `starts_with()` function: - -```{r, echo=FALSE, out.width="100%"} -knitr::include_graphics("./starts_with_docs.png") -``` - -The docs show that `starts_with()` is a unary function, which means that it takes a -single data input. The data input must be a string-like class, and the returned -value is boolean, both of which match up to R's `startsWith()`. - -There is an options class associated with `starts_with()` - called [`MatchSubstringOptions`](https://arrow.apache.org/docs/cpp/api/compute.html#_CPPv4N5arrow7compute21MatchSubstringOptionsE) -- so let's take a look at that. - -```{r, echo=FALSE, out.width="100%"} -knitr::include_graphics("./matchsubstringoptions.png") -``` - -Options classes allow the user to control the behaviour of the function. In -this case, there are two possible options which can be supplied - `pattern` and -`ignore_case`, which are described in the docs shown above. - -## Comparing the R and C++ functions - -What conclusions can be drawn from what you've seen so far? - -Base R's `startsWith()` and Arrow's `starts_with()` operate on equivalent data -types, return equivalent data types, and as there are no options implemented in -R that Arrow doesn't have, this should be fairly simple to map without a great -deal of extra work. - -As `starts_with()` has an options class associated with it, we'll need to make -sure that it's linked up with this in the R code. - -In case you're wondering about the difference between arguments in R and options -in Arrow, in R, arguments to functions can include the actual data to be -analysed as well as options governing how the function works, whereas in the -C++ compute functions, the arguments are the data to be analysed and the -options are for specifying how exactly the function works. - -So let's get started. - -## Step 1 - add unit tests - -We recommend a test-driven-development approach - write failing tests first, -then check that they fail, and then write the code needed to make them pass. -Thinking up-front about the behavior which needs testing can make it easier to -reason about the code which needs writing later. - -Look up the R function that you want to bind the compute kernel to, and write a -set of unit tests that use a dplyr pipeline and `compare_dplyr_binding()` (and -perhaps even `compare_dplyr_error()` if necessary. These functions compare the -output of the original function with the dplyr bindings and make sure they match. -We recommend looking at the [documentation next to the source code for these -functions](https://github.com/apache/arrow/blob/main/r/tests/testthat/helper-expectation.R) -to get a better understanding of how they work. - -You should make sure you're testing all parameters of the R function in your -tests. - -Below is a possible example test for `startsWith()`. - -```{r, eval = FALSE} -test_that("startsWith behaves identically in dplyr and Arrow", { - df <- tibble(x = c("Foo", "bar", "baz", "qux")) - compare_dplyr_binding( - .input %>% - filter(startsWith(x, "b")) %>% - collect(), - df - ) -}) -``` - -## Step 2 - Hook up the compute function with options class if necessary - -If the C++ compute function can have options specified, make sure that the -function is linked with its options class in `make_compute_options()` in the -file `arrow/r/src/compute.cpp`. You can find out if a compute function requires -options by looking in the docs here: https://arrow.apache.org/docs/cpp/compute.html - -In the case of `starts_with()`, it looks something like this: - -```cpp - if (func_name == "starts_with") { - using Options = arrow::compute::MatchSubstringOptions; - bool ignore_case = false; - if (!Rf_isNull(options["ignore_case"])) { - ignore_case = cpp11::as_cpp(options["ignore_case"]); - } - return std::make_shared(cpp11::as_cpp(options["pattern"]), - ignore_case); - } -``` - -You can usually copy and paste from a similar existing example. In this case, -as the option `ignore_case` doesn't map to any parameters of `startsWith()`, we -give it a default value of `false` but if it's been set, use the set value -instead. As the `pattern` argument maps directly to `prefix` in `startsWith()` -we can pass it straight through. - -## Step 3 - Map the R function to the C++ kernel - -The next task is writing the code which binds the R function to the C++ kernel. - -### Step 3a - See if direct mapping is appropriate -Compare the C++ function and R function. If they are simple functions with no -options, it might be possible to directly map between the C++ and R in -`unary_function_map`, in the case of compute functions that operate on single -columns of data, or `binary_function_map` for those which operate on 2 columns -of data. - -As `startsWith()` requires options, direct mapping is not appropriate. - -### Step 3b - If direct mapping not possible, try a modified implementation -If the function cannot be mapped directly, some extra work may be needed to -ensure that calling the arrow version of the function results in the same result -as calling the R version of the function. In this case, the function will need -adding to the `.cache$functions` function registry. Here is how this might look for -`startsWith()`: - -```{r, eval = FALSE} -register_binding("base::startsWith", function(x, prefix) { - Expression$create( - "starts_with", - x, - options = list(pattern = prefix) - ) -}) -``` - -In the source files, all the `register_binding()` calls are wrapped in functions -that are called on package load. These are separated into files based on -subject matter (e.g., `R/dplyr-funcs-math.R`, `R/dplyr-funcs-string.R`): find the -closest analog to the function whose binding is being defined and define the -new binding in a similar location. For example, the binding for `startsWith()` -is registered in `dplyr-funcs-string.R` next to the binding for `endsWith()`. - -Note: we use the namespace-qualified name (i.e. `"base::startsWith"`) for a -binding. This will register the same binding both as `startsWith()` and as -`base::startsWith()`, which will allow us to use the `pkg::` prefix in a call. - -```{r} -arrow_table(starwars) %>% - filter(stringr::str_detect(name, "Darth")) -``` - -Hint: you can use `call_function()` to call a compute function directly from R. -This might be useful if you want to experiment with a compute function while -you're writing bindings for it, e.g. - -```{r} -call_function( - "starts_with", - Array$create(c("Apache", "Arrow", "R", "package")), - options = list(pattern = "A") -) -``` - -## Step 4 - Run (and potentially add to) your tests. - -In the process of implementing the function, you will need at least one test -to make sure that your binding works and that future changes to the Arrow R -package don't break it! Bindings are tested in files that correspond to -the file in which they were defined (e.g., `startsWith()` is tested in -`tests/testthat/test-dplyr-funcs-string.R`) next to the tests for `endsWith()`. - -You may end up implementing more tests, for example if you discover unusual -edge cases. This is fine - add them to the ones you wrote originally, -and run them all. If they pass, you're done and you can submit a PR. -If you've modified the C++ code in the -R package (for example, when hooking up a binding to its options class), you -should make sure to run `arrow/r/lint.sh` to lint the code.