diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 08cb5d44038..89fd656daa4 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -90,8 +90,13 @@ Collate: 'dplyr-distinct.R' 'dplyr-eval.R' 'dplyr-filter.R' + 'dplyr-funcs-conditional.R' + 'dplyr-funcs-datetime.R' + 'dplyr-funcs-math.R' + 'dplyr-funcs-string.R' + 'dplyr-funcs-type.R' 'expression.R' - 'dplyr-functions.R' + 'dplyr-funcs.R' 'dplyr-group-by.R' 'dplyr-join.R' 'dplyr-mutate.R' diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index a72483a0d6d..386e3dcaa77 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -56,21 +56,10 @@ s3_register("reticulate::r_to_py", cl) } - # Create these once, at package build time - if (arrow_available()) { - # Also include all available Arrow Compute functions, - # namespaced as arrow_fun. - # We can't do this at install time because list_compute_functions() may error - all_arrow_funs <- list_compute_functions() - arrow_funcs <- set_names( - lapply(all_arrow_funs, function(fun) { - force(fun) - function(...) build_expr(fun, ...) - }), - paste0("arrow_", all_arrow_funs) - ) - .cache$functions <- c(nse_funcs, arrow_funcs) - } + # Create the .cache$functions list at package load time. + # We can't do this at build time because list_compute_functions() may error + # if arrow_available() is FALSE + create_binding_cache() if (tolower(Sys.info()[["sysname"]]) == "windows") { # Disable multithreading on Windows diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 13e68f3f484..c62f2559310 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -113,7 +113,7 @@ implicit_schema <- function(.data) { hash <- length(.data$group_by_vars) > 0 agg_fields <- imap( new_fields[setdiff(names(new_fields), .data$group_by_vars)], - ~ output_type(.data$aggregations[[.y]][["fun"]], .x, hash) + ~ agg_fun_output_type(.data$aggregations[[.y]][["fun"]], .x, hash) ) new_fields <- c(group_fields, agg_fields) } diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R new file mode 100644 index 00000000000..493031d2f57 --- /dev/null +++ b/r/R/dplyr-funcs-conditional.R @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +register_bindings_conditional <- function() { + register_binding("coalesce", function(...) { + args <- list2(...) + if (length(args) < 1) { + abort("At least one argument must be supplied to coalesce()") + } + + # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all* + # the values are NaN, we should return NaN, not NA, so don't replace + # NaN with NA in the final (or only) argument + # TODO: if an option is added to the coalesce kernel to treat NaN as NA, + # use that to simplify the code here (ARROW-13389) + attr(args[[length(args)]], "last") <- TRUE + args <- lapply(args, function(arg) { + last_arg <- is.null(attr(arg, "last")) + attr(arg, "last") <- NULL + + if (!inherits(arg, "Expression")) { + arg <- Expression$scalar(arg) + } + + if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) { + # store the NA_real_ in the same type as arg to avoid avoid casting + # smaller float types to larger float types + NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type())) + Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg) + } else { + arg + } + }) + Expression$create("coalesce", args = args) + }) + + if_else_binding <- function(condition, true, false, missing = NULL) { + if (!is.null(missing)) { + return(if_else_binding( + call_binding("is.na", (condition)), + missing, + if_else_binding(condition, true, false) + )) + } + + build_expr("if_else", condition, true, false) + } + + register_binding("if_else", if_else_binding) + + # Although base R ifelse allows `yes` and `no` to be different classes + register_binding("ifelse", function(test, yes, no) { + if_else_binding(condition = test, true = yes, false = no) + }) + + register_binding("case_when", function(...) { + formulas <- list2(...) + n <- length(formulas) + if (n == 0) { + abort("No cases provided in case_when()") + } + query <- vector("list", n) + value <- vector("list", n) + mask <- caller_env() + for (i in seq_len(n)) { + f <- formulas[[i]] + if (!inherits(f, "formula")) { + abort("Each argument to case_when() must be a two-sided formula") + } + query[[i]] <- arrow_eval(f[[2]], mask) + value[[i]] <- arrow_eval(f[[3]], mask) + if (!call_binding("is.logical", query[[i]])) { + abort("Left side of each formula in case_when() must be a logical expression") + } + if (inherits(value[[i]], "try-error")) { + abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]]))) + } + } + build_expr( + "case_when", + args = c( + build_expr( + "make_struct", + args = query, + options = list(field_names = as.character(seq_along(query))) + ), + value + ) + ) + }) +} diff --git a/r/R/dplyr-funcs-datetime.R b/r/R/dplyr-funcs-datetime.R new file mode 100644 index 00000000000..2eedc03ad83 --- /dev/null +++ b/r/R/dplyr-funcs-datetime.R @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +register_bindings_datetime <- function() { + register_binding("strptime", function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL, + unit = "ms") { + # Arrow uses unit for time parsing, strptime() does not. + # Arrow has no default option for strptime (format, unit), + # we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms", + # (ARROW-12809) + + # ParseTimestampStrptime currently ignores the timezone information (ARROW-12820). + # Stop if tz is provided. + if (is.character(tz)) { + arrow_not_supported("Time zone argument") + } + + unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units)) + + Expression$create("strptime", x, options = list(format = format, unit = unit)) + }) + + register_binding("strftime", function(x, format = "", tz = "", usetz = FALSE) { + if (usetz) { + format <- paste(format, "%Z") + } + if (tz == "") { + tz <- Sys.timezone() + } + # Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first + # cast the timestamp to desired timezone. This is a metadata only change. + if (call_binding("is.POSIXct", x)) { + ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz))) + } else { + ts <- x + } + Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME"))) + }) + + register_binding("format_ISO8601", function(x, usetz = FALSE, precision = NULL, ...) { + ISO8601_precision_map <- + list( + y = "%Y", + ym = "%Y-%m", + ymd = "%Y-%m-%d", + ymdh = "%Y-%m-%dT%H", + ymdhm = "%Y-%m-%dT%H:%M", + ymdhms = "%Y-%m-%dT%H:%M:%S" + ) + + if (is.null(precision)) { + precision <- "ymdhms" + } + if (!precision %in% names(ISO8601_precision_map)) { + abort( + paste( + "`precision` must be one of the following values:", + paste(names(ISO8601_precision_map), collapse = ", "), + "\nValue supplied was: ", + precision + ) + ) + } + format <- ISO8601_precision_map[[precision]] + if (usetz) { + format <- paste0(format, "%z") + } + Expression$create("strftime", x, options = list(format = format, locale = "C")) + }) + + register_binding("second", function(x) { + Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x)) + }) + + register_binding("wday", function(x, label = FALSE, abbr = TRUE, + week_start = getOption("lubridate.week.start", 7), + locale = Sys.getlocale("LC_TIME")) { + if (label) { + if (abbr) { + format <- "%a" + } else { + format <- "%A" + } + return(Expression$create("strftime", x, options = list(format = format, locale = locale))) + } + + Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start)) + }) + + register_binding("month", function(x, label = FALSE, abbr = TRUE, locale = Sys.getlocale("LC_TIME")) { + if (label) { + if (abbr) { + format <- "%b" + } else { + format <- "%B" + } + return(Expression$create("strftime", x, options = list(format = format, locale = locale))) + } + + Expression$create("month", x) + }) + + register_binding("is.Date", function(x) { + inherits(x, "Date") || + (inherits(x, "Expression") && x$type_id() %in% Type[c("DATE32", "DATE64")]) + }) + + is_instant_binding <- function(x) { + inherits(x, c("POSIXt", "POSIXct", "POSIXlt", "Date")) || + (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP", "DATE32", "DATE64")]) + } + register_binding("is.instant", is_instant_binding) + register_binding("is.timepoint", is_instant_binding) + + register_binding("is.POSIXct", function(x) { + inherits(x, "POSIXct") || + (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")]) + }) +} diff --git a/r/R/dplyr-funcs-math.R b/r/R/dplyr-funcs-math.R new file mode 100644 index 00000000000..b92c202d048 --- /dev/null +++ b/r/R/dplyr-funcs-math.R @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +register_bindings_math <- function() { + log_binding <- function(x, base = exp(1)) { + # like other binary functions, either `x` or `base` can be Expression or double(1) + if (is.numeric(x) && length(x) == 1) { + x <- Expression$scalar(x) + } else if (!inherits(x, "Expression")) { + arrow_not_supported("x must be a column or a length-1 numeric; other values") + } + + # handle `base` differently because we use the simpler ln, log2, and log10 + # functions for specific scalar base values + if (inherits(base, "Expression")) { + return(Expression$create("logb_checked", x, base)) + } + + if (!is.numeric(base) || length(base) != 1) { + arrow_not_supported("base must be a column or a length-1 numeric; other values") + } + + if (base == exp(1)) { + return(Expression$create("ln_checked", x)) + } + + if (base == 2) { + return(Expression$create("log2_checked", x)) + } + + if (base == 10) { + return(Expression$create("log10_checked", x)) + } + + Expression$create("logb_checked", x, Expression$scalar(base)) + } + + register_binding("log", log_binding) + register_binding("logb", log_binding) + + register_binding("pmin", function(..., na.rm = FALSE) { + build_expr( + "min_element_wise", + ..., + options = list(skip_nulls = na.rm) + ) + }) + + register_binding("pmax", function(..., na.rm = FALSE) { + build_expr( + "max_element_wise", + ..., + options = list(skip_nulls = na.rm) + ) + }) + + register_binding("trunc", function(x, ...) { + # accepts and ignores ... for consistency with base::trunc() + build_expr("trunc", x) + }) + + register_binding("round", function(x, digits = 0) { + build_expr( + "round", + x, + options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN) + ) + }) +} diff --git a/r/R/dplyr-funcs-string.R b/r/R/dplyr-funcs-string.R new file mode 100644 index 00000000000..0d28ec858c0 --- /dev/null +++ b/r/R/dplyr-funcs-string.R @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# String function helpers + +#' Get `stringr` pattern options +#' +#' This function assigns definitions for the `stringr` pattern modifier +#' functions (`fixed()`, `regex()`, etc.) inside itself, and uses them to +#' evaluate the quoted expression `pattern`, returning a list that is used +#' to control pattern matching behavior in internal `arrow` functions. +#' +#' @param pattern Unevaluated expression containing a call to a `stringr` +#' pattern modifier function +#' +#' @return List containing elements `pattern`, `fixed`, and `ignore_case` +#' @keywords internal +get_stringr_pattern_options <- function(pattern) { + fixed <- function(pattern, ignore_case = FALSE, ...) { + check_dots(...) + list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case) + } + regex <- function(pattern, ignore_case = FALSE, ...) { + check_dots(...) + list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case) + } + coll <- function(...) { + arrow_not_supported("Pattern modifier `coll()`") + } + boundary <- function(...) { + arrow_not_supported("Pattern modifier `boundary()`") + } + check_dots <- function(...) { + dots <- list(...) + if (length(dots)) { + warning( + "Ignoring pattern modifier ", + ngettext(length(dots), "argument ", "arguments "), + "not supported in Arrow: ", + oxford_paste(names(dots)), + call. = FALSE + ) + } + } + ensure_opts <- function(opts) { + if (is.character(opts)) { + opts <- list(pattern = opts, fixed = FALSE, ignore_case = FALSE) + } + opts + } + ensure_opts(eval(pattern)) +} + +#' Does this string contain regex metacharacters? +#' +#' @param string String to be tested +#' @keywords internal +#' @return Logical: does `string` contain regex metacharacters? +contains_regex <- function(string) { + grepl("[.\\|()[{^$*+?]", string) +} + +# format `pattern` as needed for case insensitivity and literal matching by RE2 +format_string_pattern <- function(pattern, ignore.case, fixed) { + # Arrow lacks native support for case-insensitive literal string matching and + # replacement, so we use the regular expression engine (RE2) to do this. + # https://github.com/google/re2/wiki/Syntax + if (ignore.case) { + if (fixed) { + # Everything between "\Q" and "\E" is treated as literal text. + # If the search text contains any literal "\E" strings, make them + # lowercase so they won't signal the end of the literal text: + pattern <- gsub("\\E", "\\e", pattern, fixed = TRUE) + pattern <- paste0("\\Q", pattern, "\\E") + } + # Prepend "(?i)" for case-insensitive matching + pattern <- paste0("(?i)", pattern) + } + pattern +} + +# format `replacement` as needed for literal replacement by RE2 +format_string_replacement <- function(replacement, ignore.case, fixed) { + # Arrow lacks native support for case-insensitive literal string + # replacement, so we use the regular expression engine (RE2) to do this. + # https://github.com/google/re2/wiki/Syntax + if (ignore.case && fixed) { + # Escape single backslashes in the regex replacement text so they are + # interpreted as literal backslashes: + replacement <- gsub("\\", "\\\\", replacement, fixed = TRUE) + } + replacement +} + +# Currently, Arrow does not supports a locale option for string case conversion +# functions, contrast to stringr's API, so the 'locale' argument is only valid +# for stringr's default value ("en"). The following are string functions that +# take a 'locale' option as its second argument: +# str_to_lower +# str_to_upper +# str_to_title +# +# Arrow locale will be supported with ARROW-14126 +stop_if_locale_provided <- function(locale) { + if (!identical(locale, "en")) { + stop("Providing a value for 'locale' other than the default ('en') is not supported in Arrow. ", + "To change locale, use 'Sys.setlocale()'", + call. = FALSE + ) + } +} + + +# Split up into several register functions by category to satisfy the linter +register_bindings_string <- function() { + register_bindings_string_join() + register_bindings_string_regex() + register_bindings_string_other() +} + +register_bindings_string_join <- function() { + + arrow_string_join_function <- function(null_handling, null_replacement = NULL) { + # the `binary_join_element_wise` Arrow C++ compute kernel takes the separator + # as the last argument, so pass `sep` as the last dots arg to this function + function(...) { + args <- lapply(list(...), function(arg) { + # handle scalar literal args, and cast all args to string for + # consistency with base::paste(), base::paste0(), and stringr::str_c() + if (!inherits(arg, "Expression")) { + assert_that( + length(arg) == 1, + msg = "Literal vectors of length != 1 not supported in string concatenation" + ) + Expression$scalar(as.character(arg)) + } else { + call_binding("as.character", arg) + } + }) + Expression$create( + "binary_join_element_wise", + args = args, + options = list( + null_handling = null_handling, + null_replacement = null_replacement + ) + ) + } + } + + register_binding("paste", function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste() with the collapse argument is not yet supported in Arrow" + ) + if (!inherits(sep, "Expression")) { + assert_that(!is.na(sep), msg = "Invalid separator") + } + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) + }) + + register_binding("paste0", function(..., collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste0() with the collapse argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") + }) + + register_binding("str_c", function(..., sep = "", collapse = NULL) { + assert_that( + is.null(collapse), + msg = "str_c() with the collapse argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) + }) +} + +register_bindings_string_regex <- function() { + + register_binding("grepl", function(pattern, x, ignore.case = FALSE, fixed = FALSE) { + arrow_fun <- ifelse(fixed, "match_substring", "match_substring_regex") + Expression$create( + arrow_fun, + x, + options = list(pattern = pattern, ignore_case = ignore.case) + ) + }) + + register_binding("str_detect", function(string, pattern, negate = FALSE) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + out <- call_binding("grepl", + pattern = opts$pattern, + x = string, + ignore.case = opts$ignore_case, + fixed = opts$fixed + ) + if (negate) { + out <- !out + } + out + }) + + register_binding("str_like", function(string, pattern, ignore_case = TRUE) { + Expression$create( + "match_like", + string, + options = list(pattern = pattern, ignore_case = ignore_case) + ) + }) + + register_binding("str_count", function(string, pattern) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + if (!is.string(pattern)) { + arrow_not_supported("`pattern` must be a length 1 character vector; other values") + } + arrow_fun <- ifelse(opts$fixed, "count_substring", "count_substring_regex") + Expression$create( + arrow_fun, + string, + options = list(pattern = opts$pattern, ignore_case = opts$ignore_case) + ) + }) + + register_binding("startsWith", function(x, prefix) { + Expression$create( + "starts_with", + x, + options = list(pattern = prefix) + ) + }) + + register_binding("endsWith", function(x, suffix) { + Expression$create( + "ends_with", + x, + options = list(pattern = suffix) + ) + }) + + register_binding("str_starts", function(string, pattern, negate = FALSE) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + if (opts$fixed) { + out <- call_binding("startsWith", x = string, prefix = opts$pattern) + } else { + out <- call_binding("grepl", pattern = paste0("^", opts$pattern), x = string, fixed = FALSE) + } + + if (negate) { + out <- !out + } + out + }) + + register_binding("str_ends", function(string, pattern, negate = FALSE) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + if (opts$fixed) { + out <- call_binding("endsWith", x = string, suffix = opts$pattern) + } else { + out <- call_binding("grepl", pattern = paste0(opts$pattern, "$"), x = string, fixed = FALSE) + } + + if (negate) { + out <- !out + } + out + }) + + # Encapsulate some common logic for sub/gsub/str_replace/str_replace_all + arrow_r_string_replace_function <- function(max_replacements) { + function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) { + Expression$create( + ifelse(fixed && !ignore.case, "replace_substring", "replace_substring_regex"), + x, + options = list( + pattern = format_string_pattern(pattern, ignore.case, fixed), + replacement = format_string_replacement(replacement, ignore.case, fixed), + max_replacements = max_replacements + ) + ) + } + } + + arrow_stringr_string_replace_function <- function(max_replacements) { + force(max_replacements) + function(string, pattern, replacement) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + arrow_r_string_replace_function(max_replacements)( + pattern = opts$pattern, + replacement = replacement, + x = string, + ignore.case = opts$ignore_case, + fixed = opts$fixed + ) + } + } + + register_binding("sub", arrow_r_string_replace_function(1L)) + register_binding("gsub", arrow_r_string_replace_function(-1L)) + register_binding("str_replace", arrow_stringr_string_replace_function(1L)) + register_binding("str_replace_all", arrow_stringr_string_replace_function(-1L)) + + register_binding("strsplit", function(x, split, fixed = FALSE, perl = FALSE, + useBytes = FALSE) { + assert_that(is.string(split)) + + arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex") + # warn when the user specifies both fixed = TRUE and perl = TRUE, for + # consistency with the behavior of base::strsplit() + if (fixed && perl) { + warning("Argument 'perl = TRUE' will be ignored", call. = FALSE) + } + # since split is not a regex, proceed without any warnings or errors regardless + # of the value of perl, for consistency with the behavior of base::strsplit() + Expression$create( + arrow_fun, + x, + options = list(pattern = split, reverse = FALSE, max_splits = -1L) + ) + }) + + register_binding("str_split", function(string, pattern, n = Inf, simplify = FALSE) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + arrow_fun <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex") + if (opts$ignore_case) { + arrow_not_supported("Case-insensitive string splitting") + } + if (n == 0) { + arrow_not_supported("Splitting strings into zero parts") + } + if (identical(n, Inf)) { + n <- 0L + } + if (simplify) { + warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE) + } + # The max_splits option in the Arrow C++ library controls the maximum number + # of places at which the string is split, whereas the argument n to + # str_split() controls the maximum number of pieces to return. So we must + # subtract 1 from n to get max_splits. + Expression$create( + arrow_fun, + string, + options = list( + pattern = opts$pattern, + reverse = FALSE, + max_splits = n - 1L + ) + ) + }) +} + +register_bindings_string_other <- function() { + + register_binding("nchar", function(x, type = "chars", allowNA = FALSE, keepNA = NA) { + if (allowNA) { + arrow_not_supported("allowNA = TRUE") + } + if (is.na(keepNA)) { + keepNA <- !identical(type, "width") + } + if (!keepNA) { + # TODO: I think there is a fill_null kernel we could use, set null to 2 + arrow_not_supported("keepNA = TRUE") + } + if (identical(type, "bytes")) { + Expression$create("binary_length", x) + } else { + Expression$create("utf8_length", x) + } + }) + + register_binding("str_to_lower", function(string, locale = "en") { + stop_if_locale_provided(locale) + Expression$create("utf8_lower", string) + }) + + register_binding("str_to_upper", function(string, locale = "en") { + stop_if_locale_provided(locale) + Expression$create("utf8_upper", string) + }) + + register_binding("str_to_title", function(string, locale = "en") { + stop_if_locale_provided(locale) + Expression$create("utf8_title", string) + }) + + register_binding("str_trim", function(string, side = c("both", "left", "right")) { + side <- match.arg(side) + trim_fun <- switch(side, + left = "utf8_ltrim_whitespace", + right = "utf8_rtrim_whitespace", + both = "utf8_trim_whitespace" + ) + Expression$create(trim_fun, string) + }) + + register_binding("substr", function(x, start, stop) { + assert_that( + length(start) == 1, + msg = "`start` must be length 1 - other lengths are not supported in Arrow" + ) + assert_that( + length(stop) == 1, + msg = "`stop` must be length 1 - other lengths are not supported in Arrow" + ) + + # substr treats values as if they're on a continous number line, so values + # 0 are effectively blank characters - set `start` to 1 here so Arrow mimics + # this behavior + if (start <= 0) { + start <- 1 + } + + # if `stop` is lower than `start`, this is invalid, so set `stop` to + # 0 so that an empty string will be returned (consistent with base::substr()) + if (stop < start) { + stop <- 0 + } + + Expression$create( + "utf8_slice_codeunits", + x, + # we don't need to subtract 1 from `stop` as C++ counts exclusively + # which effectively cancels out the difference in indexing between R & C++ + options = list(start = start - 1L, stop = stop) + ) + }) + + register_binding("substring", function(text, first, last) { + call_binding("substr", x = text, start = first, stop = last) + }) + + register_binding("str_sub", function(string, start = 1L, end = -1L) { + assert_that( + length(start) == 1, + msg = "`start` must be length 1 - other lengths are not supported in Arrow" + ) + assert_that( + length(end) == 1, + msg = "`end` must be length 1 - other lengths are not supported in Arrow" + ) + + # In stringr::str_sub, an `end` value of -1 means the end of the string, so + # set it to the maximum integer to match this behavior + if (end == -1) { + end <- .Machine$integer.max + } + + # An end value lower than a start value returns an empty string in + # stringr::str_sub so set end to 0 here to match this behavior + if (end < start) { + end <- 0 + } + + # subtract 1 from `start` because C++ is 0-based and R is 1-based + # str_sub treats a `start` value of 0 or 1 as the same thing so don't subtract 1 when `start` == 0 + # when `start` < 0, both str_sub and utf8_slice_codeunits count backwards from the end + if (start > 0) { + start <- start - 1L + } + + Expression$create( + "utf8_slice_codeunits", + string, + options = list(start = start, stop = end) + ) + }) + + + register_binding("str_pad", function(string, width, side = c("left", "right", "both"), pad = " ") { + assert_that(is_integerish(width)) + side <- match.arg(side) + assert_that(is.string(pad)) + + if (side == "left") { + pad_func <- "utf8_lpad" + } else if (side == "right") { + pad_func <- "utf8_rpad" + } else if (side == "both") { + pad_func <- "utf8_center" + } + + Expression$create( + pad_func, + string, + options = list(width = width, padding = pad) + ) + }) +} diff --git a/r/R/dplyr-funcs-type.R b/r/R/dplyr-funcs-type.R new file mode 100644 index 00000000000..2f1fa96b835 --- /dev/null +++ b/r/R/dplyr-funcs-type.R @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Split up into several register functions by category to satisfy the linter +register_bindings_type <- function() { + register_bindings_type_cast() + register_bindings_type_inspect() + register_bindings_type_elementwise() +} + +register_bindings_type_cast <- function() { + register_binding("cast", function(x, target_type, safe = TRUE, ...) { + opts <- cast_options(safe, ...) + opts$to_type <- as_type(target_type) + Expression$create("cast", x, options = opts) + }) + + register_binding("dictionary_encode", function(x, + null_encoding_behavior = c("mask", "encode")) { + behavior <- toupper(match.arg(null_encoding_behavior)) + null_encoding_behavior <- NullEncodingBehavior[[behavior]] + Expression$create( + "dictionary_encode", + x, + options = list(null_encoding_behavior = null_encoding_behavior) + ) + }) + + # as.* type casting functions + # as.factor() is mapped in expression.R + register_binding("as.character", function(x) { + Expression$create("cast", x, options = cast_options(to_type = string())) + }) + register_binding("as.double", function(x) { + Expression$create("cast", x, options = cast_options(to_type = float64())) + }) + register_binding("as.integer", function(x) { + Expression$create( + "cast", + x, + options = cast_options( + to_type = int32(), + allow_float_truncate = TRUE, + allow_decimal_truncate = TRUE + ) + ) + }) + register_binding("as.integer64", function(x) { + Expression$create( + "cast", + x, + options = cast_options( + to_type = int64(), + allow_float_truncate = TRUE, + allow_decimal_truncate = TRUE + ) + ) + }) + register_binding("as.logical", function(x) { + Expression$create("cast", x, options = cast_options(to_type = boolean())) + }) + register_binding("as.numeric", function(x) { + Expression$create("cast", x, options = cast_options(to_type = float64())) + }) + + register_binding("is", function(object, class2) { + if (is.string(class2)) { + switch(class2, + # for R data types, pass off to is.*() functions + character = call_binding("is.character", object), + numeric = call_binding("is.numeric", object), + integer = call_binding("is.integer", object), + integer64 = call_binding("is.integer64", object), + logical = call_binding("is.logical", object), + factor = call_binding("is.factor", object), + list = call_binding("is.list", object), + # for Arrow data types, compare class2 with object$type()$ToString(), + # but first strip off any parameters to only compare the top-level data + # type, and canonicalize class2 + sub("^([^([<]+).*$", "\\1", object$type()$ToString()) == + canonical_type_str(class2) + ) + } else if (inherits(class2, "DataType")) { + object$type() == as_type(class2) + } else { + stop("Second argument to is() is not a string or DataType", call. = FALSE) + } + }) + + # Create a data frame/tibble/struct column + register_binding("tibble", function(..., .rows = NULL, .name_repair = NULL) { + if (!is.null(.rows)) arrow_not_supported(".rows") + if (!is.null(.name_repair)) arrow_not_supported(".name_repair") + + # use dots_list() because this is what tibble() uses to allow the + # useful shorthand of tibble(col1, col2) -> tibble(col1 = col1, col2 = col2) + # we have a stronger enforcement of unique names for arguments because + # it is difficult to replicate the .name_repair semantics and expanding of + # unnamed data frame arguments in the same way that the tibble() constructor + # does. + args <- rlang::dots_list(..., .named = TRUE, .homonyms = "error") + + build_expr( + "make_struct", + args = unname(args), + options = list(field_names = names(args)) + ) + }) + + register_binding("data.frame", function(..., row.names = NULL, + check.rows = NULL, check.names = TRUE, fix.empty.names = TRUE, + stringsAsFactors = FALSE) { + # we need a specific value of stringsAsFactors because the default was + # TRUE in R <= 3.6 + if (!identical(stringsAsFactors, FALSE)) { + arrow_not_supported("stringsAsFactors = TRUE") + } + + # ignore row.names and check.rows with a warning + if (!is.null(row.names)) arrow_not_supported("row.names") + if (!is.null(check.rows)) arrow_not_supported("check.rows") + + args <- rlang::dots_list(..., .named = fix.empty.names) + if (is.null(names(args))) { + names(args) <- rep("", length(args)) + } + + if (identical(check.names, TRUE)) { + if (identical(fix.empty.names, TRUE)) { + names(args) <- make.names(names(args), unique = TRUE) + } else { + name_emtpy <- names(args) == "" + names(args)[!name_emtpy] <- make.names(names(args)[!name_emtpy], unique = TRUE) + } + } + + build_expr( + "make_struct", + args = unname(args), + options = list(field_names = names(args)) + ) + }) +} + +register_bindings_type_inspect <- function() { + # is.* type functions + register_binding("is.character", function(x) { + is.character(x) || (inherits(x, "Expression") && + x$type_id() %in% Type[c("STRING", "LARGE_STRING")]) + }) + register_binding("is.numeric", function(x) { + is.numeric(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( + "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32", + "UINT64", "INT64", "HALF_FLOAT", "FLOAT", "DOUBLE", + "DECIMAL128", "DECIMAL256" + )]) + }) + register_binding("is.double", function(x) { + is.double(x) || (inherits(x, "Expression") && x$type_id() == Type["DOUBLE"]) + }) + register_binding("is.integer", function(x) { + is.integer(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( + "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32", + "UINT64", "INT64" + )]) + }) + register_binding("is.integer64", function(x) { + inherits(x, "integer64") || (inherits(x, "Expression") && x$type_id() == Type["INT64"]) + }) + register_binding("is.logical", function(x) { + is.logical(x) || (inherits(x, "Expression") && x$type_id() == Type["BOOL"]) + }) + register_binding("is.factor", function(x) { + is.factor(x) || (inherits(x, "Expression") && x$type_id() == Type["DICTIONARY"]) + }) + register_binding("is.list", function(x) { + is.list(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( + "LIST", "FIXED_SIZE_LIST", "LARGE_LIST" + )]) + }) + + # rlang::is_* type functions + register_binding("is_character", function(x, n = NULL) { + assert_that(is.null(n)) + call_binding("is.character", x) + }) + register_binding("is_double", function(x, n = NULL, finite = NULL) { + assert_that(is.null(n) && is.null(finite)) + call_binding("is.double", x) + }) + register_binding("is_integer", function(x, n = NULL) { + assert_that(is.null(n)) + call_binding("is.integer", x) + }) + register_binding("is_list", function(x, n = NULL) { + assert_that(is.null(n)) + call_binding("is.list", x) + }) + register_binding("is_logical", function(x, n = NULL) { + assert_that(is.null(n)) + call_binding("is.logical", x) + }) +} + +register_bindings_type_elementwise <- function() { + register_binding("is.na", function(x) { + build_expr("is_null", x, options = list(nan_is_null = TRUE)) + }) + + register_binding("is.nan", function(x) { + if (is.double(x) || (inherits(x, "Expression") && + x$type_id() %in% TYPES_WITH_NAN)) { + # TODO: if an option is added to the is_nan kernel to treat NA as NaN, + # use that to simplify the code here (ARROW-13366) + build_expr("is_nan", x) & build_expr("is_valid", x) + } else { + Expression$scalar(FALSE) + } + }) + + register_binding("between", function(x, left, right) { + x >= left & x <= right + }) + + register_binding("is.finite", function(x) { + is_fin <- Expression$create("is_finite", x) + # for compatibility with base::is.finite(), return FALSE for NA_real_ + is_fin & !call_binding("is.na", is_fin) + }) + + register_binding("is.infinite", function(x) { + is_inf <- Expression$create("is_inf", x) + # for compatibility with base::is.infinite(), return FALSE for NA_real_ + is_inf & !call_binding("is.na", is_inf) + }) +} diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R new file mode 100644 index 00000000000..4d7cb3bc63d --- /dev/null +++ b/r/R/dplyr-funcs.R @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +#' @include expression.R +NULL + + +#' Register compute bindings +#' +#' The `register_binding()` and `register_binding_agg()` functions +#' are used to populate a list of functions that operate on (and return) +#' Expressions. These are the basis for the `.data` mask inside dplyr methods. +#' +#' @section Writing bindings: +#' When to use `build_expr()` vs. `Expression$create()`? +#' +#' Use `build_expr()` if you need to +#' - map R function names to Arrow C++ functions +#' - wrap R inputs (vectors) as Array/Scalar +#' +#' `Expression$create()` is lower level. Most of the bindings use it +#' because they manage the preparation of the user-provided inputs +#' and don't need or don't want to the automatic conversion of R objects +#' to [Scalar]. +#' +#' @param fun_name A string containing a function name in the form `"function"` or +#' `"package::function"`. The package name is currently not used but +#' may be used in the future to allow these types of function calls. +#' @param fun A function or `NULL` to un-register a previous function. +#' This function must accept `Expression` objects as arguments and return +#' `Expression` objects instead of regular R objects. +#' @param agg_fun An aggregate function or `NULL` to un-register a previous +#' aggregate function. This function must accept `Expression` objects as +#' arguments and return a `list()` with components: +#' - `fun`: string function name +#' - `data`: `Expression` (these are all currently a single field) +#' - `options`: list of function options, as passed to call_function +#' @param registry An environment in which the functions should be +#' assigned. +#' +#' @return The previously registered binding or `NULL` if no previously +#' registered function existed. +#' @keywords internal +#' +register_binding <- function(fun_name, fun, registry = nse_funcs) { + name <- gsub("^.*?::", "", fun_name) + namespace <- gsub("::.*$", "", fun_name) + + previous_fun <- if (name %in% names(registry)) registry[[name]] else NULL + + if (is.null(fun) && !is.null(previous_fun)) { + rm(list = name, envir = registry, inherits = FALSE) + } else { + registry[[name]] <- fun + } + + invisible(previous_fun) +} + +register_binding_agg <- function(fun_name, agg_fun, registry = agg_funcs) { + register_binding(fun_name, agg_fun, registry = registry) +} + +# Supports functions and tests that call previously-defined bindings +call_binding <- function(fun_name, ...) { + nse_funcs[[fun_name]](...) +} + +call_binding_agg <- function(fun_name, ...) { + agg_funcs[[fun_name]](...) +} + +# Called in .onLoad() +create_binding_cache <- function() { + arrow_funcs <- list() + + # Register all available Arrow Compute functions, namespaced as arrow_fun. + if (arrow_available()) { + all_arrow_funs <- list_compute_functions() + arrow_funcs <- set_names( + lapply(all_arrow_funs, function(fun) { + force(fun) + function(...) build_expr(fun, ...) + }), + paste0("arrow_", all_arrow_funs) + ) + } + + # Register bindings into nse_funcs and agg_funcs + register_bindings_array_function_map() + register_bindings_aggregate() + register_bindings_conditional() + register_bindings_datetime() + register_bindings_math() + register_bindings_string() + register_bindings_type() + + # We only create the cache for nse_funcs and not agg_funcs + .cache$functions <- c(as.list(nse_funcs), arrow_funcs) +} + +# environments in the arrow namespace used in the above functions +nse_funcs <- new.env(parent = emptyenv()) +agg_funcs <- new.env(parent = emptyenv()) +.cache <- new.env(parent = emptyenv()) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R deleted file mode 100644 index ccd7ded3cca..00000000000 --- a/r/R/dplyr-functions.R +++ /dev/null @@ -1,1148 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -#' @include expression.R -NULL - -# This environment is an internal cache for things including data mask functions -# We'll populate it at package load time. -.cache <- NULL -init_env <- function() { - .cache <<- new.env(hash = TRUE) -} -init_env() - -# nse_funcs is a list of functions that operated on (and return) Expressions -# These will be the basis for a data_mask inside dplyr methods -# and will be added to .cache at package load time - -# Start with mappings from R function name spellings -nse_funcs <- lapply(set_names(names(.array_function_map)), function(operator) { - force(operator) - function(...) build_expr(operator, ...) -}) - -# Now add functions to that list where the mapping from R to Arrow isn't 1:1 -# Each of these functions should have the same signature as the R function -# they're replacing. -# -# When to use `build_expr()` vs. `Expression$create()`? -# -# Use `build_expr()` if you need to -# (1) map R function names to Arrow C++ functions -# (2) wrap R inputs (vectors) as Array/Scalar -# -# `Expression$create()` is lower level. Most of the functions below use it -# because they manage the preparation of the user-provided inputs -# and don't need to wrap scalars - -nse_funcs$cast <- function(x, target_type, safe = TRUE, ...) { - opts <- cast_options(safe, ...) - opts$to_type <- as_type(target_type) - Expression$create("cast", x, options = opts) -} - -nse_funcs$coalesce <- function(...) { - args <- list2(...) - if (length(args) < 1) { - abort("At least one argument must be supplied to coalesce()") - } - - # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all* - # the values are NaN, we should return NaN, not NA, so don't replace - # NaN with NA in the final (or only) argument - # TODO: if an option is added to the coalesce kernel to treat NaN as NA, - # use that to simplify the code here (ARROW-13389) - attr(args[[length(args)]], "last") <- TRUE - args <- lapply(args, function(arg) { - last_arg <- is.null(attr(arg, "last")) - attr(arg, "last") <- NULL - - if (!inherits(arg, "Expression")) { - arg <- Expression$scalar(arg) - } - - if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) { - # store the NA_real_ in the same type as arg to avoid avoid casting - # smaller float types to larger float types - NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type())) - Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg) - } else { - arg - } - }) - Expression$create("coalesce", args = args) -} - -nse_funcs$is.na <- function(x) { - build_expr("is_null", x, options = list(nan_is_null = TRUE)) -} - -nse_funcs$is.nan <- function(x) { - if (is.double(x) || (inherits(x, "Expression") && - x$type_id() %in% TYPES_WITH_NAN)) { - # TODO: if an option is added to the is_nan kernel to treat NA as NaN, - # use that to simplify the code here (ARROW-13366) - build_expr("is_nan", x) & build_expr("is_valid", x) - } else { - Expression$scalar(FALSE) - } -} - -nse_funcs$is <- function(object, class2) { - if (is.string(class2)) { - switch(class2, - # for R data types, pass off to is.*() functions - character = nse_funcs$is.character(object), - numeric = nse_funcs$is.numeric(object), - integer = nse_funcs$is.integer(object), - integer64 = nse_funcs$is.integer64(object), - logical = nse_funcs$is.logical(object), - factor = nse_funcs$is.factor(object), - list = nse_funcs$is.list(object), - # for Arrow data types, compare class2 with object$type()$ToString(), - # but first strip off any parameters to only compare the top-level data - # type, and canonicalize class2 - sub("^([^([<]+).*$", "\\1", object$type()$ToString()) == - canonical_type_str(class2) - ) - } else if (inherits(class2, "DataType")) { - object$type() == as_type(class2) - } else { - stop("Second argument to is() is not a string or DataType", call. = FALSE) - } -} - -nse_funcs$dictionary_encode <- function(x, - null_encoding_behavior = c("mask", "encode")) { - behavior <- toupper(match.arg(null_encoding_behavior)) - null_encoding_behavior <- NullEncodingBehavior[[behavior]] - Expression$create( - "dictionary_encode", - x, - options = list(null_encoding_behavior = null_encoding_behavior) - ) -} - -nse_funcs$between <- function(x, left, right) { - x >= left & x <= right -} - -nse_funcs$is.finite <- function(x) { - is_fin <- Expression$create("is_finite", x) - # for compatibility with base::is.finite(), return FALSE for NA_real_ - is_fin & !nse_funcs$is.na(is_fin) -} - -nse_funcs$is.infinite <- function(x) { - is_inf <- Expression$create("is_inf", x) - # for compatibility with base::is.infinite(), return FALSE for NA_real_ - is_inf & !nse_funcs$is.na(is_inf) -} - -# as.* type casting functions -# as.factor() is mapped in expression.R -nse_funcs$as.character <- function(x) { - Expression$create("cast", x, options = cast_options(to_type = string())) -} -nse_funcs$as.double <- function(x) { - Expression$create("cast", x, options = cast_options(to_type = float64())) -} -nse_funcs$as.integer <- function(x) { - Expression$create( - "cast", - x, - options = cast_options( - to_type = int32(), - allow_float_truncate = TRUE, - allow_decimal_truncate = TRUE - ) - ) -} -nse_funcs$as.integer64 <- function(x) { - Expression$create( - "cast", - x, - options = cast_options( - to_type = int64(), - allow_float_truncate = TRUE, - allow_decimal_truncate = TRUE - ) - ) -} -nse_funcs$as.logical <- function(x) { - Expression$create("cast", x, options = cast_options(to_type = boolean())) -} -nse_funcs$as.numeric <- function(x) { - Expression$create("cast", x, options = cast_options(to_type = float64())) -} - -# is.* type functions -nse_funcs$is.character <- function(x) { - is.character(x) || (inherits(x, "Expression") && - x$type_id() %in% Type[c("STRING", "LARGE_STRING")]) -} -nse_funcs$is.numeric <- function(x) { - is.numeric(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( - "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32", - "UINT64", "INT64", "HALF_FLOAT", "FLOAT", "DOUBLE", - "DECIMAL128", "DECIMAL256" - )]) -} -nse_funcs$is.double <- function(x) { - is.double(x) || (inherits(x, "Expression") && x$type_id() == Type["DOUBLE"]) -} -nse_funcs$is.integer <- function(x) { - is.integer(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( - "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32", - "UINT64", "INT64" - )]) -} -nse_funcs$is.integer64 <- function(x) { - is.integer64(x) || (inherits(x, "Expression") && x$type_id() == Type["INT64"]) -} -nse_funcs$is.logical <- function(x) { - is.logical(x) || (inherits(x, "Expression") && x$type_id() == Type["BOOL"]) -} -nse_funcs$is.factor <- function(x) { - is.factor(x) || (inherits(x, "Expression") && x$type_id() == Type["DICTIONARY"]) -} -nse_funcs$is.list <- function(x) { - is.list(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c( - "LIST", "FIXED_SIZE_LIST", "LARGE_LIST" - )]) -} - -# rlang::is_* type functions -nse_funcs$is_character <- function(x, n = NULL) { - assert_that(is.null(n)) - nse_funcs$is.character(x) -} -nse_funcs$is_double <- function(x, n = NULL, finite = NULL) { - assert_that(is.null(n) && is.null(finite)) - nse_funcs$is.double(x) -} -nse_funcs$is_integer <- function(x, n = NULL) { - assert_that(is.null(n)) - nse_funcs$is.integer(x) -} -nse_funcs$is_list <- function(x, n = NULL) { - assert_that(is.null(n)) - nse_funcs$is.list(x) -} -nse_funcs$is_logical <- function(x, n = NULL) { - assert_that(is.null(n)) - nse_funcs$is.logical(x) -} - -# Create a data frame/tibble/struct column -nse_funcs$tibble <- function(..., .rows = NULL, .name_repair = NULL) { - if (!is.null(.rows)) arrow_not_supported(".rows") - if (!is.null(.name_repair)) arrow_not_supported(".name_repair") - - # use dots_list() because this is what tibble() uses to allow the - # useful shorthand of tibble(col1, col2) -> tibble(col1 = col1, col2 = col2) - # we have a stronger enforcement of unique names for arguments because - # it is difficult to replicate the .name_repair semantics and expanding of - # unnamed data frame arguments in the same way that the tibble() constructor - # does. - args <- rlang::dots_list(..., .named = TRUE, .homonyms = "error") - - build_expr( - "make_struct", - args = unname(args), - options = list(field_names = names(args)) - ) -} - -nse_funcs$data.frame <- function(..., row.names = NULL, - check.rows = NULL, check.names = TRUE, fix.empty.names = TRUE, - stringsAsFactors = FALSE) { - # we need a specific value of stringsAsFactors because the default was - # TRUE in R <= 3.6 - if (!identical(stringsAsFactors, FALSE)) { - arrow_not_supported("stringsAsFactors = TRUE") - } - - # ignore row.names and check.rows with a warning - if (!is.null(row.names)) arrow_not_supported("row.names") - if (!is.null(check.rows)) arrow_not_supported("check.rows") - - args <- rlang::dots_list(..., .named = fix.empty.names) - if (is.null(names(args))) { - names(args) <- rep("", length(args)) - } - - if (identical(check.names, TRUE)) { - if (identical(fix.empty.names, TRUE)) { - names(args) <- make.names(names(args), unique = TRUE) - } else { - name_emtpy <- names(args) == "" - names(args)[!name_emtpy] <- make.names(names(args)[!name_emtpy], unique = TRUE) - } - } - - build_expr( - "make_struct", - args = unname(args), - options = list(field_names = names(args)) - ) -} - -# String functions -nse_funcs$nchar <- function(x, type = "chars", allowNA = FALSE, keepNA = NA) { - if (allowNA) { - arrow_not_supported("allowNA = TRUE") - } - if (is.na(keepNA)) { - keepNA <- !identical(type, "width") - } - if (!keepNA) { - # TODO: I think there is a fill_null kernel we could use, set null to 2 - arrow_not_supported("keepNA = TRUE") - } - if (identical(type, "bytes")) { - Expression$create("binary_length", x) - } else { - Expression$create("utf8_length", x) - } -} - -nse_funcs$paste <- function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste() with the collapse argument is not yet supported in Arrow" - ) - if (!inherits(sep, "Expression")) { - assert_that(!is.na(sep), msg = "Invalid separator") - } - arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) -} - -nse_funcs$paste0 <- function(..., collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste0() with the collapse argument is not yet supported in Arrow" - ) - arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") -} - -nse_funcs$str_c <- function(..., sep = "", collapse = NULL) { - assert_that( - is.null(collapse), - msg = "str_c() with the collapse argument is not yet supported in Arrow" - ) - arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) -} - -arrow_string_join_function <- function(null_handling, null_replacement = NULL) { - # the `binary_join_element_wise` Arrow C++ compute kernel takes the separator - # as the last argument, so pass `sep` as the last dots arg to this function - function(...) { - args <- lapply(list(...), function(arg) { - # handle scalar literal args, and cast all args to string for - # consistency with base::paste(), base::paste0(), and stringr::str_c() - if (!inherits(arg, "Expression")) { - assert_that( - length(arg) == 1, - msg = "Literal vectors of length != 1 not supported in string concatenation" - ) - Expression$scalar(as.character(arg)) - } else { - nse_funcs$as.character(arg) - } - }) - Expression$create( - "binary_join_element_wise", - args = args, - options = list( - null_handling = null_handling, - null_replacement = null_replacement - ) - ) - } -} - -# Currently, Arrow does not supports a locale option for string case conversion -# functions, contrast to stringr's API, so the 'locale' argument is only valid -# for stringr's default value ("en"). The following are string functions that -# take a 'locale' option as its second argument: -# str_to_lower -# str_to_upper -# str_to_title -# -# Arrow locale will be supported with ARROW-14126 -stop_if_locale_provided <- function(locale) { - if (!identical(locale, "en")) { - stop("Providing a value for 'locale' other than the default ('en') is not supported in Arrow. ", - "To change locale, use 'Sys.setlocale()'", - call. = FALSE - ) - } -} - -nse_funcs$str_to_lower <- function(string, locale = "en") { - stop_if_locale_provided(locale) - Expression$create("utf8_lower", string) -} - -nse_funcs$str_to_upper <- function(string, locale = "en") { - stop_if_locale_provided(locale) - Expression$create("utf8_upper", string) -} - -nse_funcs$str_to_title <- function(string, locale = "en") { - stop_if_locale_provided(locale) - Expression$create("utf8_title", string) -} - -nse_funcs$str_trim <- function(string, side = c("both", "left", "right")) { - side <- match.arg(side) - trim_fun <- switch(side, - left = "utf8_ltrim_whitespace", - right = "utf8_rtrim_whitespace", - both = "utf8_trim_whitespace" - ) - Expression$create(trim_fun, string) -} - -nse_funcs$substr <- function(x, start, stop) { - assert_that( - length(start) == 1, - msg = "`start` must be length 1 - other lengths are not supported in Arrow" - ) - assert_that( - length(stop) == 1, - msg = "`stop` must be length 1 - other lengths are not supported in Arrow" - ) - - # substr treats values as if they're on a continous number line, so values - # 0 are effectively blank characters - set `start` to 1 here so Arrow mimics - # this behavior - if (start <= 0) { - start <- 1 - } - - # if `stop` is lower than `start`, this is invalid, so set `stop` to - # 0 so that an empty string will be returned (consistent with base::substr()) - if (stop < start) { - stop <- 0 - } - - Expression$create( - "utf8_slice_codeunits", - x, - # we don't need to subtract 1 from `stop` as C++ counts exclusively - # which effectively cancels out the difference in indexing between R & C++ - options = list(start = start - 1L, stop = stop) - ) -} - -nse_funcs$substring <- function(text, first, last) { - nse_funcs$substr(x = text, start = first, stop = last) -} - -nse_funcs$str_sub <- function(string, start = 1L, end = -1L) { - assert_that( - length(start) == 1, - msg = "`start` must be length 1 - other lengths are not supported in Arrow" - ) - assert_that( - length(end) == 1, - msg = "`end` must be length 1 - other lengths are not supported in Arrow" - ) - - # In stringr::str_sub, an `end` value of -1 means the end of the string, so - # set it to the maximum integer to match this behavior - if (end == -1) { - end <- .Machine$integer.max - } - - # An end value lower than a start value returns an empty string in - # stringr::str_sub so set end to 0 here to match this behavior - if (end < start) { - end <- 0 - } - - # subtract 1 from `start` because C++ is 0-based and R is 1-based - # str_sub treats a `start` value of 0 or 1 as the same thing so don't subtract 1 when `start` == 0 - # when `start` < 0, both str_sub and utf8_slice_codeunits count backwards from the end - if (start > 0) { - start <- start - 1L - } - - Expression$create( - "utf8_slice_codeunits", - string, - options = list(start = start, stop = end) - ) -} - -nse_funcs$grepl <- function(pattern, x, ignore.case = FALSE, fixed = FALSE) { - arrow_fun <- ifelse(fixed, "match_substring", "match_substring_regex") - Expression$create( - arrow_fun, - x, - options = list(pattern = pattern, ignore_case = ignore.case) - ) -} - -nse_funcs$str_detect <- function(string, pattern, negate = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - out <- nse_funcs$grepl( - pattern = opts$pattern, - x = string, - ignore.case = opts$ignore_case, - fixed = opts$fixed - ) - if (negate) { - out <- !out - } - out -} - -nse_funcs$str_like <- function(string, pattern, ignore_case = TRUE) { - Expression$create( - "match_like", - string, - options = list(pattern = pattern, ignore_case = ignore_case) - ) -} - -# Encapsulate some common logic for sub/gsub/str_replace/str_replace_all -arrow_r_string_replace_function <- function(max_replacements) { - function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) { - Expression$create( - ifelse(fixed && !ignore.case, "replace_substring", "replace_substring_regex"), - x, - options = list( - pattern = format_string_pattern(pattern, ignore.case, fixed), - replacement = format_string_replacement(replacement, ignore.case, fixed), - max_replacements = max_replacements - ) - ) - } -} - -arrow_stringr_string_replace_function <- function(max_replacements) { - function(string, pattern, replacement) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - arrow_r_string_replace_function(max_replacements)( - pattern = opts$pattern, - replacement = replacement, - x = string, - ignore.case = opts$ignore_case, - fixed = opts$fixed - ) - } -} - -nse_funcs$sub <- arrow_r_string_replace_function(1L) -nse_funcs$gsub <- arrow_r_string_replace_function(-1L) -nse_funcs$str_replace <- arrow_stringr_string_replace_function(1L) -nse_funcs$str_replace_all <- arrow_stringr_string_replace_function(-1L) - -nse_funcs$strsplit <- function(x, - split, - fixed = FALSE, - perl = FALSE, - useBytes = FALSE) { - assert_that(is.string(split)) - - arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex") - # warn when the user specifies both fixed = TRUE and perl = TRUE, for - # consistency with the behavior of base::strsplit() - if (fixed && perl) { - warning("Argument 'perl = TRUE' will be ignored", call. = FALSE) - } - # since split is not a regex, proceed without any warnings or errors regardless - # of the value of perl, for consistency with the behavior of base::strsplit() - Expression$create( - arrow_fun, - x, - options = list(pattern = split, reverse = FALSE, max_splits = -1L) - ) -} - -nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - arrow_fun <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex") - if (opts$ignore_case) { - arrow_not_supported("Case-insensitive string splitting") - } - if (n == 0) { - arrow_not_supported("Splitting strings into zero parts") - } - if (identical(n, Inf)) { - n <- 0L - } - if (simplify) { - warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE) - } - # The max_splits option in the Arrow C++ library controls the maximum number - # of places at which the string is split, whereas the argument n to - # str_split() controls the maximum number of pieces to return. So we must - # subtract 1 from n to get max_splits. - Expression$create( - arrow_fun, - string, - options = list( - pattern = opts$pattern, - reverse = FALSE, - max_splits = n - 1L - ) - ) -} - -nse_funcs$pmin <- function(..., na.rm = FALSE) { - build_expr( - "min_element_wise", - ..., - options = list(skip_nulls = na.rm) - ) -} - -nse_funcs$pmax <- function(..., na.rm = FALSE) { - build_expr( - "max_element_wise", - ..., - options = list(skip_nulls = na.rm) - ) -} - -nse_funcs$str_pad <- function(string, width, side = c("left", "right", "both"), pad = " ") { - assert_that(is_integerish(width)) - side <- match.arg(side) - assert_that(is.string(pad)) - - if (side == "left") { - pad_func <- "utf8_lpad" - } else if (side == "right") { - pad_func <- "utf8_rpad" - } else if (side == "both") { - pad_func <- "utf8_center" - } - - Expression$create( - pad_func, - string, - options = list(width = width, padding = pad) - ) -} - -nse_funcs$startsWith <- function(x, prefix) { - Expression$create( - "starts_with", - x, - options = list(pattern = prefix) - ) -} - -nse_funcs$endsWith <- function(x, suffix) { - Expression$create( - "ends_with", - x, - options = list(pattern = suffix) - ) -} - -nse_funcs$str_starts <- function(string, pattern, negate = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - if (opts$fixed) { - out <- nse_funcs$startsWith(x = string, prefix = opts$pattern) - } else { - out <- nse_funcs$grepl(pattern = paste0("^", opts$pattern), x = string, fixed = FALSE) - } - - if (negate) { - out <- !out - } - out -} - -nse_funcs$str_ends <- function(string, pattern, negate = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - if (opts$fixed) { - out <- nse_funcs$endsWith(x = string, suffix = opts$pattern) - } else { - out <- nse_funcs$grepl(pattern = paste0(opts$pattern, "$"), x = string, fixed = FALSE) - } - - if (negate) { - out <- !out - } - out -} - -nse_funcs$str_count <- function(string, pattern) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - if (!is.string(pattern)) { - arrow_not_supported("`pattern` must be a length 1 character vector; other values") - } - arrow_fun <- ifelse(opts$fixed, "count_substring", "count_substring_regex") - Expression$create( - arrow_fun, - string, - options = list(pattern = opts$pattern, ignore_case = opts$ignore_case) - ) -} - -# String function helpers - -# format `pattern` as needed for case insensitivity and literal matching by RE2 -format_string_pattern <- function(pattern, ignore.case, fixed) { - # Arrow lacks native support for case-insensitive literal string matching and - # replacement, so we use the regular expression engine (RE2) to do this. - # https://github.com/google/re2/wiki/Syntax - if (ignore.case) { - if (fixed) { - # Everything between "\Q" and "\E" is treated as literal text. - # If the search text contains any literal "\E" strings, make them - # lowercase so they won't signal the end of the literal text: - pattern <- gsub("\\E", "\\e", pattern, fixed = TRUE) - pattern <- paste0("\\Q", pattern, "\\E") - } - # Prepend "(?i)" for case-insensitive matching - pattern <- paste0("(?i)", pattern) - } - pattern -} - -# format `replacement` as needed for literal replacement by RE2 -format_string_replacement <- function(replacement, ignore.case, fixed) { - # Arrow lacks native support for case-insensitive literal string - # replacement, so we use the regular expression engine (RE2) to do this. - # https://github.com/google/re2/wiki/Syntax - if (ignore.case && fixed) { - # Escape single backslashes in the regex replacement text so they are - # interpreted as literal backslashes: - replacement <- gsub("\\", "\\\\", replacement, fixed = TRUE) - } - replacement -} - -#' Get `stringr` pattern options -#' -#' This function assigns definitions for the `stringr` pattern modifier -#' functions (`fixed()`, `regex()`, etc.) inside itself, and uses them to -#' evaluate the quoted expression `pattern`, returning a list that is used -#' to control pattern matching behavior in internal `arrow` functions. -#' -#' @param pattern Unevaluated expression containing a call to a `stringr` -#' pattern modifier function -#' -#' @return List containing elements `pattern`, `fixed`, and `ignore_case` -#' @keywords internal -get_stringr_pattern_options <- function(pattern) { - fixed <- function(pattern, ignore_case = FALSE, ...) { - check_dots(...) - list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case) - } - regex <- function(pattern, ignore_case = FALSE, ...) { - check_dots(...) - list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case) - } - coll <- function(...) { - arrow_not_supported("Pattern modifier `coll()`") - } - boundary <- function(...) { - arrow_not_supported("Pattern modifier `boundary()`") - } - check_dots <- function(...) { - dots <- list(...) - if (length(dots)) { - warning( - "Ignoring pattern modifier ", - ngettext(length(dots), "argument ", "arguments "), - "not supported in Arrow: ", - oxford_paste(names(dots)), - call. = FALSE - ) - } - } - ensure_opts <- function(opts) { - if (is.character(opts)) { - opts <- list(pattern = opts, fixed = FALSE, ignore_case = FALSE) - } - opts - } - ensure_opts(eval(pattern)) -} - -#' Does this string contain regex metacharacters? -#' -#' @param string String to be tested -#' @keywords internal -#' @return Logical: does `string` contain regex metacharacters? -contains_regex <- function(string) { - grepl("[.\\|()[{^$*+?]", string) -} - -nse_funcs$trunc <- function(x, ...) { - # accepts and ignores ... for consistency with base::trunc() - build_expr("trunc", x) -} - -nse_funcs$round <- function(x, digits = 0) { - build_expr( - "round", - x, - options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN) - ) -} - -nse_funcs$strptime <- function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL, unit = "ms") { - # Arrow uses unit for time parsing, strptime() does not. - # Arrow has no default option for strptime (format, unit), - # we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms", - # (ARROW-12809) - - # ParseTimestampStrptime currently ignores the timezone information (ARROW-12820). - # Stop if tz is provided. - if (is.character(tz)) { - arrow_not_supported("Time zone argument") - } - - unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units)) - - Expression$create("strptime", x, options = list(format = format, unit = unit)) -} - -nse_funcs$strftime <- function(x, format = "", tz = "", usetz = FALSE) { - if (usetz) { - format <- paste(format, "%Z") - } - if (tz == "") { - tz <- Sys.timezone() - } - # Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first - # cast the timestamp to desired timezone. This is a metadata only change. - if (nse_funcs$is.POSIXct(x)) { - ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz))) - } else { - ts <- x - } - Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME"))) -} - -nse_funcs$format_ISO8601 <- function(x, usetz = FALSE, precision = NULL, ...) { - ISO8601_precision_map <- - list( - y = "%Y", - ym = "%Y-%m", - ymd = "%Y-%m-%d", - ymdh = "%Y-%m-%dT%H", - ymdhm = "%Y-%m-%dT%H:%M", - ymdhms = "%Y-%m-%dT%H:%M:%S" - ) - - if (is.null(precision)) { - precision <- "ymdhms" - } - if (!precision %in% names(ISO8601_precision_map)) { - abort( - paste( - "`precision` must be one of the following values:", - paste(names(ISO8601_precision_map), collapse = ", "), - "\nValue supplied was: ", - precision - ) - ) - } - format <- ISO8601_precision_map[[precision]] - if (usetz) { - format <- paste0(format, "%z") - } - Expression$create("strftime", x, options = list(format = format, locale = "C")) -} - -nse_funcs$second <- function(x) { - Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x)) -} - -nse_funcs$wday <- function(x, - label = FALSE, - abbr = TRUE, - week_start = getOption("lubridate.week.start", 7), - locale = Sys.getlocale("LC_TIME")) { - if (label) { - if (abbr) { - format <- "%a" - } else { - format <- "%A" - } - return(Expression$create("strftime", x, options = list(format = format, locale = locale))) - } - - Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start)) -} - -nse_funcs$month <- function(x, label = FALSE, abbr = TRUE, locale = Sys.getlocale("LC_TIME")) { - if (label) { - if (abbr) { - format <- "%b" - } else { - format <- "%B" - } - return(Expression$create("strftime", x, options = list(format = format, locale = locale))) - } - - Expression$create("month", x) -} - -nse_funcs$is.Date <- function(x) { - inherits(x, "Date") || - (inherits(x, "Expression") && x$type_id() %in% Type[c("DATE32", "DATE64")]) -} - -nse_funcs$is.instant <- nse_funcs$is.timepoint <- function(x) { - inherits(x, c("POSIXt", "POSIXct", "POSIXlt", "Date")) || - (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP", "DATE32", "DATE64")]) -} - -nse_funcs$is.POSIXct <- function(x) { - inherits(x, "POSIXct") || - (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")]) -} - -nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) { - # like other binary functions, either `x` or `base` can be Expression or double(1) - if (is.numeric(x) && length(x) == 1) { - x <- Expression$scalar(x) - } else if (!inherits(x, "Expression")) { - arrow_not_supported("x must be a column or a length-1 numeric; other values") - } - - # handle `base` differently because we use the simpler ln, log2, and log10 - # functions for specific scalar base values - if (inherits(base, "Expression")) { - return(Expression$create("logb_checked", x, base)) - } - - if (!is.numeric(base) || length(base) != 1) { - arrow_not_supported("base must be a column or a length-1 numeric; other values") - } - - if (base == exp(1)) { - return(Expression$create("ln_checked", x)) - } - - if (base == 2) { - return(Expression$create("log2_checked", x)) - } - - if (base == 10) { - return(Expression$create("log10_checked", x)) - } - - Expression$create("logb_checked", x, Expression$scalar(base)) -} - -nse_funcs$if_else <- function(condition, true, false, missing = NULL) { - if (!is.null(missing)) { - return(nse_funcs$if_else( - nse_funcs$is.na(condition), - missing, - nse_funcs$if_else(condition, true, false) - )) - } - - build_expr("if_else", condition, true, false) -} - -# Although base R ifelse allows `yes` and `no` to be different classes -nse_funcs$ifelse <- function(test, yes, no) { - nse_funcs$if_else(condition = test, true = yes, false = no) -} - -nse_funcs$case_when <- function(...) { - formulas <- list2(...) - n <- length(formulas) - if (n == 0) { - abort("No cases provided in case_when()") - } - query <- vector("list", n) - value <- vector("list", n) - mask <- caller_env() - for (i in seq_len(n)) { - f <- formulas[[i]] - if (!inherits(f, "formula")) { - abort("Each argument to case_when() must be a two-sided formula") - } - query[[i]] <- arrow_eval(f[[2]], mask) - value[[i]] <- arrow_eval(f[[3]], mask) - if (!nse_funcs$is.logical(query[[i]])) { - abort("Left side of each formula in case_when() must be a logical expression") - } - if (inherits(value[[i]], "try-error")) { - abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]]))) - } - } - build_expr( - "case_when", - args = c( - build_expr( - "make_struct", - args = query, - options = list(field_names = as.character(seq_along(query))) - ), - value - ) - ) -} - -# Aggregation functions -# These all return a list of: -# @param fun string function name -# @param data Expression (these are all currently a single field) -# @param options list of function options, as passed to call_function -# For group-by aggregation, `hash_` gets prepended to the function name. -# So to see a list of available hash aggregation functions, -# you can use list_compute_functions("^hash_") -agg_funcs <- list() -agg_funcs$sum <- function(..., na.rm = FALSE) { - list( - fun = "sum", - data = ensure_one_arg(list2(...), "sum"), - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} -agg_funcs$any <- function(..., na.rm = FALSE) { - list( - fun = "any", - data = ensure_one_arg(list2(...), "any"), - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} -agg_funcs$all <- function(..., na.rm = FALSE) { - list( - fun = "all", - data = ensure_one_arg(list2(...), "all"), - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} -agg_funcs$mean <- function(x, na.rm = FALSE) { - list( - fun = "mean", - data = x, - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} -agg_funcs$sd <- function(x, na.rm = FALSE, ddof = 1) { - list( - fun = "stddev", - data = x, - options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) - ) -} -agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { - list( - fun = "variance", - data = x, - options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) - ) -} -agg_funcs$quantile <- function(x, probs, na.rm = FALSE) { - if (length(probs) != 1) { - arrow_not_supported("quantile() with length(probs) != 1") - } - # TODO: Bind to the Arrow function that returns an exact quantile and remove - # this warning (ARROW-14021) - warn( - "quantile() currently returns an approximate quantile in Arrow", - .frequency = ifelse(is_interactive(), "once", "always"), - .frequency_id = "arrow.quantile.approximate" - ) - list( - fun = "tdigest", - data = x, - options = list(skip_nulls = na.rm, q = probs) - ) -} -agg_funcs$median <- function(x, na.rm = FALSE) { - # TODO: Bind to the Arrow function that returns an exact median and remove - # this warning (ARROW-14021) - warn( - "median() currently returns an approximate median in Arrow", - .frequency = ifelse(is_interactive(), "once", "always"), - .frequency_id = "arrow.median.approximate" - ) - list( - fun = "approximate_median", - data = x, - options = list(skip_nulls = na.rm) - ) -} -agg_funcs$n_distinct <- function(..., na.rm = FALSE) { - list( - fun = "count_distinct", - data = ensure_one_arg(list2(...), "n_distinct"), - options = list(na.rm = na.rm) - ) -} -agg_funcs$n <- function() { - list( - fun = "sum", - data = Expression$scalar(1L), - options = list() - ) -} -agg_funcs$min <- function(..., na.rm = FALSE) { - list( - fun = "min", - data = ensure_one_arg(list2(...), "min"), - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} -agg_funcs$max <- function(..., na.rm = FALSE) { - list( - fun = "max", - data = ensure_one_arg(list2(...), "max"), - options = list(skip_nulls = na.rm, min_count = 0L) - ) -} - -ensure_one_arg <- function(args, fun) { - if (length(args) == 0) { - arrow_not_supported(paste0(fun, "() with 0 arguments")) - } else if (length(args) > 1) { - arrow_not_supported(paste0("Multiple arguments to ", fun, "()")) - } - args[[1]] -} - -output_type <- function(fun, input_type, hash) { - # These are quick and dirty heuristics. - if (fun %in% c("any", "all")) { - bool() - } else if (fun %in% "sum") { - # It may upcast to a bigger type but this is close enough - input_type - } else if (fun %in% c("mean", "stddev", "variance", "approximate_median")) { - float64() - } else if (fun %in% "tdigest") { - if (hash) { - fixed_size_list_of(float64(), 1L) - } else { - float64() - } - } else { - # Just so things don't error, assume the resulting type is the same - input_type - } -} diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 3b7d51ed47e..7cb9a3483d5 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -15,6 +15,149 @@ # specific language governing permissions and limitations # under the License. +# Aggregation functions +# These all return a list of: +# @param fun string function name +# @param data Expression (these are all currently a single field) +# @param options list of function options, as passed to call_function +# For group-by aggregation, `hash_` gets prepended to the function name. +# So to see a list of available hash aggregation functions, +# you can use list_compute_functions("^hash_") + + +ensure_one_arg <- function(args, fun) { + if (length(args) == 0) { + arrow_not_supported(paste0(fun, "() with 0 arguments")) + } else if (length(args) > 1) { + arrow_not_supported(paste0("Multiple arguments to ", fun, "()")) + } + args[[1]] +} + +agg_fun_output_type <- function(fun, input_type, hash) { + # These are quick and dirty heuristics. + if (fun %in% c("any", "all")) { + bool() + } else if (fun %in% "sum") { + # It may upcast to a bigger type but this is close enough + input_type + } else if (fun %in% c("mean", "stddev", "variance", "approximate_median")) { + float64() + } else if (fun %in% "tdigest") { + if (hash) { + fixed_size_list_of(float64(), 1L) + } else { + float64() + } + } else { + # Just so things don't error, assume the resulting type is the same + input_type + } +} + +register_bindings_aggregate <- function() { + register_binding_agg("sum", function(..., na.rm = FALSE) { + list( + fun = "sum", + data = ensure_one_arg(list2(...), "sum"), + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) + register_binding_agg("any", function(..., na.rm = FALSE) { + list( + fun = "any", + data = ensure_one_arg(list2(...), "any"), + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) + register_binding_agg("all", function(..., na.rm = FALSE) { + list( + fun = "all", + data = ensure_one_arg(list2(...), "all"), + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) + register_binding_agg("mean", function(x, na.rm = FALSE) { + list( + fun = "mean", + data = x, + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) + register_binding_agg("sd", function(x, na.rm = FALSE, ddof = 1) { + list( + fun = "stddev", + data = x, + options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) + ) + }) + register_binding_agg("var", function(x, na.rm = FALSE, ddof = 1) { + list( + fun = "variance", + data = x, + options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) + ) + }) + register_binding_agg("quantile", function(x, probs, na.rm = FALSE) { + if (length(probs) != 1) { + arrow_not_supported("quantile() with length(probs) != 1") + } + # TODO: Bind to the Arrow function that returns an exact quantile and remove + # this warning (ARROW-14021) + warn( + "quantile() currently returns an approximate quantile in Arrow", + .frequency = ifelse(is_interactive(), "once", "always"), + .frequency_id = "arrow.quantile.approximate" + ) + list( + fun = "tdigest", + data = x, + options = list(skip_nulls = na.rm, q = probs) + ) + }) + register_binding_agg("median", function(x, na.rm = FALSE) { + # TODO: Bind to the Arrow function that returns an exact median and remove + # this warning (ARROW-14021) + warn( + "median() currently returns an approximate median in Arrow", + .frequency = ifelse(is_interactive(), "once", "always"), + .frequency_id = "arrow.median.approximate" + ) + list( + fun = "approximate_median", + data = x, + options = list(skip_nulls = na.rm) + ) + }) + register_binding_agg("n_distinct", function(..., na.rm = FALSE) { + list( + fun = "count_distinct", + data = ensure_one_arg(list2(...), "n_distinct"), + options = list(na.rm = na.rm) + ) + }) + register_binding_agg("n", function() { + list( + fun = "sum", + data = Expression$scalar(1L), + options = list() + ) + }) + register_binding_agg("min", function(..., na.rm = FALSE) { + list( + fun = "min", + data = ensure_one_arg(list2(...), "min"), + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) + register_binding_agg("max", function(..., na.rm = FALSE) { + list( + fun = "max", + data = ensure_one_arg(list2(...), "max"), + options = list(skip_nulls = na.rm, min_count = 0L) + ) + }) +} # The following S3 methods are registered on load if dplyr is present diff --git a/r/R/expression.R b/r/R/expression.R index a76e16185cb..37fc21c25c4 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -106,6 +106,20 @@ .array_function_map <- c(.unary_function_map, .binary_function_map) +register_bindings_array_function_map <- function() { + # use a function to generate the binding so that `operator` persists + # beyond execution time (another option would be to use quasiquotation + # and unquote `operator` directly into the function expression) + array_function_map_factory <- function(operator) { + force(operator) + function(...) build_expr(operator, ...) + } + + for (name in names(.array_function_map)) { + register_binding(name, array_function_map_factory(name)) + } +} + #' Arrow expressions #' #' @description diff --git a/r/man/contains_regex.Rd b/r/man/contains_regex.Rd index f05f11d0279..338e62aa964 100644 --- a/r/man/contains_regex.Rd +++ b/r/man/contains_regex.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/dplyr-functions.R +% Please edit documentation in R/dplyr-funcs-string.R \name{contains_regex} \alias{contains_regex} \title{Does this string contain regex metacharacters?} diff --git a/r/man/get_stringr_pattern_options.Rd b/r/man/get_stringr_pattern_options.Rd index 7107b906024..6fff9796159 100644 --- a/r/man/get_stringr_pattern_options.Rd +++ b/r/man/get_stringr_pattern_options.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/dplyr-functions.R +% Please edit documentation in R/dplyr-funcs-string.R \name{get_stringr_pattern_options} \alias{get_stringr_pattern_options} \title{Get \code{stringr} pattern options} diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd new file mode 100644 index 00000000000..e776e7b3f5b --- /dev/null +++ b/r/man/register_binding.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dplyr-funcs.R +\name{register_binding} +\alias{register_binding} +\title{Register compute bindings} +\usage{ +register_binding(fun_name, fun, registry = nse_funcs) +} +\arguments{ +\item{fun_name}{A string containing a function name in the form \code{"function"} or +\code{"package::function"}. The package name is currently not used but +may be used in the future to allow these types of function calls.} + +\item{fun}{A function or \code{NULL} to un-register a previous function. +This function must accept \code{Expression} objects as arguments and return +\code{Expression} objects instead of regular R objects.} + +\item{registry}{An environment in which the functions should be +assigned.} + +\item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous +aggregate function. This function must accept \code{Expression} objects as +arguments and return a \code{list()} with components: +\itemize{ +\item \code{fun}: string function name +\item \code{data}: \code{Expression} (these are all currently a single field) +\item \code{options}: list of function options, as passed to call_function +}} +} +\value{ +The previously registered binding or \code{NULL} if no previously +registered function existed. +} +\description{ +The \code{register_binding()} and \code{register_binding_agg()} functions +are used to populate a list of functions that operate on (and return) +Expressions. These are the basis for the \code{.data} mask inside dplyr methods. +} +\section{Writing bindings}{ + +When to use \code{build_expr()} vs. \code{Expression$create()}? + +Use \code{build_expr()} if you need to +\itemize{ +\item map R function names to Arrow C++ functions +\item wrap R inputs (vectors) as Array/Scalar +} + +\code{Expression$create()} is lower level. Most of the bindings use it +because they manage the preparation of the user-provided inputs +and don't need or don't want to the automatic conversion of R objects +to \link{Scalar}. +} + +\keyword{internal} diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index a5e6ce6ea09..80a2357c0cd 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -407,7 +407,7 @@ test_that("coalesce()", { # no arguments expect_error( - nse_funcs$coalesce(), + call_binding("coalesce"), "At least one argument must be supplied to coalesce()", fixed = TRUE ) diff --git a/r/tests/testthat/test-dplyr-funcs-datetime.R b/r/tests/testthat/test-dplyr-funcs-datetime.R index 0b1395680cb..577770a5af4 100644 --- a/r/tests/testthat/test-dplyr-funcs-datetime.R +++ b/r/tests/testthat/test-dplyr-funcs-datetime.R @@ -45,6 +45,227 @@ test_df <- tibble::tibble( integer = 1:2 ) + +test_that("strptime", { + t_string <- tibble(x = c("2018-10-07 19:04:05", NA)) + t_stamp <- tibble(x = c(lubridate::ymd_hms("2018-10-07 19:04:05"), NA)) + + expect_equal( + t_string %>% + Table$create() %>% + mutate( + x = strptime(x) + ) %>% + collect(), + t_stamp, + ignore_attr = "tzone" + ) + + expect_equal( + t_string %>% + Table$create() %>% + mutate( + x = strptime(x, format = "%Y-%m-%d %H:%M:%S") + ) %>% + collect(), + t_stamp, + ignore_attr = "tzone" + ) + + expect_equal( + t_string %>% + Table$create() %>% + mutate( + x = strptime(x, format = "%Y-%m-%d %H:%M:%S", unit = "ns") + ) %>% + collect(), + t_stamp, + ignore_attr = "tzone" + ) + + expect_equal( + t_string %>% + Table$create() %>% + mutate( + x = strptime(x, format = "%Y-%m-%d %H:%M:%S", unit = "s") + ) %>% + collect(), + t_stamp, + ignore_attr = "tzone" + ) + + tstring <- tibble(x = c("08-05-2008", NA)) + tstamp <- strptime(c("08-05-2008", NA), format = "%m-%d-%Y") + + expect_equal( + tstring %>% + Table$create() %>% + mutate( + x = strptime(x, format = "%m-%d-%Y") + ) %>% + pull(), + # R's strptime returns POSIXlt (list type) + as.POSIXct(tstamp), + ignore_attr = "tzone" + ) +}) + +test_that("errors in strptime", { + # Error when tz is passed + x <- Expression$field_ref("x") + expect_error( + call_binding("strptime", x, tz = "PDT"), + "Time zone argument not supported in Arrow" + ) +}) + +test_that("strftime", { + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-13168 + + times <- tibble( + datetime = c(lubridate::ymd_hms("2018-10-07 19:04:05", tz = "Etc/GMT+6"), NA), + date = c(as.Date("2021-01-01"), NA) + ) + formats <- "%a %A %w %d %b %B %m %y %Y %H %I %p %M %z %Z %j %U %W %x %X %% %G %V %u" + formats_date <- "%a %A %w %d %b %B %m %y %Y %H %I %p %M %j %U %W %x %X %% %G %V %u" + + compare_dplyr_binding( + .input %>% + mutate(x = strftime(datetime, format = formats)) %>% + collect(), + times + ) + + compare_dplyr_binding( + .input %>% + mutate(x = strftime(date, format = formats_date)) %>% + collect(), + times + ) + + compare_dplyr_binding( + .input %>% + mutate(x = strftime(datetime, format = formats, tz = "Pacific/Marquesas")) %>% + collect(), + times + ) + + compare_dplyr_binding( + .input %>% + mutate(x = strftime(datetime, format = formats, tz = "EST", usetz = TRUE)) %>% + collect(), + times + ) + + withr::with_timezone( + "Pacific/Marquesas", + { + compare_dplyr_binding( + .input %>% + mutate( + x = strftime(datetime, format = formats, tz = "EST"), + x_date = strftime(date, format = formats_date, tz = "EST") + ) %>% + collect(), + times + ) + + compare_dplyr_binding( + .input %>% + mutate( + x = strftime(datetime, format = formats), + x_date = strftime(date, format = formats_date) + ) %>% + collect(), + times + ) + } + ) + + # This check is due to differences in the way %c currently works in Arrow and R's strftime. + # We can revisit after https://github.com/HowardHinnant/date/issues/704 is resolved. + expect_error( + times %>% + Table$create() %>% + mutate(x = strftime(datetime, format = "%c")) %>% + collect(), + "%c flag is not supported in non-C locales." + ) + + # Output precision of %S depends on the input timestamp precision. + # Timestamps with second precision are represented as integers while + # milliseconds, microsecond and nanoseconds are represented as fixed floating + # point numbers with 3, 6 and 9 decimal places respectively. + compare_dplyr_binding( + .input %>% + mutate(x = strftime(datetime, format = "%S")) %>% + transmute(as.double(substr(x, 1, 2))) %>% + collect(), + times, + tolerance = 1e-6 + ) +}) + +test_that("format_ISO8601", { + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-13168 + times <- tibble(x = c(lubridate::ymd_hms("2018-10-07 19:04:05", tz = "Etc/GMT+6"), NA)) + + compare_dplyr_binding( + .input %>% + mutate(x = format_ISO8601(x, precision = "ymd", usetz = FALSE)) %>% + collect(), + times + ) + + if (getRversion() < "3.5") { + # before 3.5, times$x will have no timezone attribute, so Arrow faithfully + # errors that there is no timezone to format: + expect_error( + times %>% + Table$create() %>% + mutate(x = format_ISO8601(x, precision = "ymd", usetz = TRUE)) %>% + collect(), + "Timezone not present, cannot convert to string with timezone: %Y-%m-%d%z" + ) + + # See comment regarding %S flag in strftime tests + expect_error( + times %>% + Table$create() %>% + mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = TRUE)) %>% + mutate(x = gsub("\\.0*", "", x)) %>% + collect(), + "Timezone not present, cannot convert to string with timezone: %Y-%m-%dT%H:%M:%S%z" + ) + } else { + compare_dplyr_binding( + .input %>% + mutate(x = format_ISO8601(x, precision = "ymd", usetz = TRUE)) %>% + collect(), + times + ) + + # See comment regarding %S flag in strftime tests + compare_dplyr_binding( + .input %>% + mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = TRUE)) %>% + mutate(x = gsub("\\.0*", "", x)) %>% + collect(), + times + ) + } + + + # See comment regarding %S flag in strftime tests + compare_dplyr_binding( + .input %>% + mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = FALSE)) %>% + mutate(x = gsub("\\.0*", "", x)) %>% + collect(), + times + ) +}) + # These tests test detection of dates and times test_that("is.* functions from lubridate", { diff --git a/r/tests/testthat/test-dplyr-funcs-math.R b/r/tests/testthat/test-dplyr-funcs-math.R index b5321945dcc..88d39140715 100644 --- a/r/tests/testthat/test-dplyr-funcs-math.R +++ b/r/tests/testthat/test-dplyr-funcs-math.R @@ -177,14 +177,14 @@ test_that("log functions", { # test log(, base = (length != 1)) expect_error( - nse_funcs$log(10, base = 5:6), + call_binding("log", 10, base = 5:6), "base must be a column or a length-1 numeric; other values not supported in Arrow", fixed = TRUE ) # test log(x = (length != 1)) expect_error( - nse_funcs$log(10:11), + call_binding("log", 10:11), "x must be a column or a length-1 numeric; other values not supported in Arrow", fixed = TRUE ) diff --git a/r/tests/testthat/test-dplyr-funcs-string.R b/r/tests/testthat/test-dplyr-funcs-string.R index f0965926f29..571b26332f8 100644 --- a/r/tests/testthat/test-dplyr-funcs-string.R +++ b/r/tests/testthat/test-dplyr-funcs-string.R @@ -121,7 +121,7 @@ test_that("paste, paste0, and str_c", { # sep is literal NA # errors in paste() (consistent with base::paste()) expect_error( - nse_funcs$paste(x, y, sep = NA_character_), + call_binding("paste", x, y, sep = NA_character_), "Invalid separator" ) # emits null in str_c() (consistent with stringr::str_c()) @@ -156,25 +156,25 @@ test_that("paste, paste0, and str_c", { # collapse argument not supported expect_error( - nse_funcs$paste(x, y, collapse = ""), + call_binding("paste", x, y, collapse = ""), "collapse" ) expect_error( - nse_funcs$paste0(x, y, collapse = ""), + call_binding("paste0", x, y, collapse = ""), "collapse" ) expect_error( - nse_funcs$str_c(x, y, collapse = ""), + call_binding("str_c", x, y, collapse = ""), "collapse" ) # literal vectors of length != 1 not supported expect_error( - nse_funcs$paste(x, character(0), y), + call_binding("paste", x, character(0), y), "Literal vectors of length != 1 not supported in string concatenation" ) expect_error( - nse_funcs$paste(x, c(",", ";"), y), + call_binding("paste", x, c(",", ";"), y), "Literal vectors of length != 1 not supported in string concatenation" ) }) @@ -501,7 +501,7 @@ test_that("str_to_lower, str_to_upper, and str_to_title", { # Error checking a single function because they all use the same code path. expect_error( - nse_funcs$str_to_lower("Apache Arrow", locale = "sp"), + call_binding("str_to_lower", "Apache Arrow", locale = "sp"), "Providing a value for 'locale' other than the default ('en') is not supported in Arrow", fixed = TRUE ) @@ -565,27 +565,27 @@ test_that("errors and warnings in string splitting", { x <- Expression$field_ref("x") expect_error( - nse_funcs$str_split(x, fixed("and", ignore_case = TRUE)), + call_binding("str_split", x, fixed("and", ignore_case = TRUE)), "Case-insensitive string splitting not supported in Arrow" ) expect_error( - nse_funcs$str_split(x, coll("and.?")), + call_binding("str_split", x, coll("and.?")), "Pattern modifier `coll()` not supported in Arrow", fixed = TRUE ) expect_error( - nse_funcs$str_split(x, boundary(type = "word")), + call_binding("str_split", x, boundary(type = "word")), "Pattern modifier `boundary()` not supported in Arrow", fixed = TRUE ) expect_error( - nse_funcs$str_split(x, "and", n = 0), + call_binding("str_split", x, "and", n = 0), "Splitting strings into zero parts not supported in Arrow" ) # This condition generates a warning expect_warning( - nse_funcs$str_split(x, fixed("and"), simplify = TRUE), + call_binding("str_split", x, fixed("and"), simplify = TRUE), "Argument 'simplify = TRUE' will be ignored" ) }) @@ -594,19 +594,19 @@ test_that("errors and warnings in string detection and replacement", { x <- Expression$field_ref("x") expect_error( - nse_funcs$str_detect(x, boundary(type = "character")), + call_binding("str_detect", x, boundary(type = "character")), "Pattern modifier `boundary()` not supported in Arrow", fixed = TRUE ) expect_error( - nse_funcs$str_replace_all(x, coll("o", locale = "en"), "ó"), + call_binding("str_replace_all", x, coll("o", locale = "en"), "ó"), "Pattern modifier `coll()` not supported in Arrow", fixed = TRUE ) # This condition generates a warning expect_warning( - nse_funcs$str_replace_all(x, regex("o", multiline = TRUE), "u"), + call_binding("str_replace_all", x, regex("o", multiline = TRUE), "u"), "Ignoring pattern modifier argument not supported in Arrow: \"multiline\"" ) }) @@ -692,232 +692,6 @@ test_that("edge cases in string detection and replacement", { ) }) -test_that("strptime", { - # base::strptime() defaults to local timezone - # but arrow's strptime defaults to UTC. - # So that tests are consistent, set the local timezone to UTC - # TODO: consider reevaluating this workaround after ARROW-12980 - withr::local_timezone("UTC") - - t_string <- tibble(x = c("2018-10-07 19:04:05", NA)) - t_stamp <- tibble(x = c(lubridate::ymd_hms("2018-10-07 19:04:05"), NA)) - - expect_equal( - t_string %>% - Table$create() %>% - mutate( - x = strptime(x) - ) %>% - collect(), - t_stamp, - ignore_attr = "tzone" - ) - - expect_equal( - t_string %>% - Table$create() %>% - mutate( - x = strptime(x, format = "%Y-%m-%d %H:%M:%S") - ) %>% - collect(), - t_stamp, - ignore_attr = "tzone" - ) - - expect_equal( - t_string %>% - Table$create() %>% - mutate( - x = strptime(x, format = "%Y-%m-%d %H:%M:%S", unit = "ns") - ) %>% - collect(), - t_stamp, - ignore_attr = "tzone" - ) - - expect_equal( - t_string %>% - Table$create() %>% - mutate( - x = strptime(x, format = "%Y-%m-%d %H:%M:%S", unit = "s") - ) %>% - collect(), - t_stamp, - ignore_attr = "tzone" - ) - - tstring <- tibble(x = c("08-05-2008", NA)) - tstamp <- strptime(c("08-05-2008", NA), format = "%m-%d-%Y") - - expect_equal( - tstring %>% - Table$create() %>% - mutate( - x = strptime(x, format = "%m-%d-%Y") - ) %>% - pull(), - # R's strptime returns POSIXlt (list type) - as.POSIXct(tstamp), - ignore_attr = "tzone" - ) -}) - -test_that("errors in strptime", { - # Error when tz is passed - x <- Expression$field_ref("x") - expect_error( - nse_funcs$strptime(x, tz = "PDT"), - "Time zone argument not supported in Arrow" - ) -}) - -test_that("strftime", { - skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-13168 - - times <- tibble( - datetime = c(lubridate::ymd_hms("2018-10-07 19:04:05", tz = "Etc/GMT+6"), NA), - date = c(as.Date("2021-01-01"), NA) - ) - formats <- "%a %A %w %d %b %B %m %y %Y %H %I %p %M %z %Z %j %U %W %x %X %% %G %V %u" - formats_date <- "%a %A %w %d %b %B %m %y %Y %H %I %p %M %j %U %W %x %X %% %G %V %u" - - compare_dplyr_binding( - .input %>% - mutate(x = strftime(datetime, format = formats)) %>% - collect(), - times - ) - - compare_dplyr_binding( - .input %>% - mutate(x = strftime(date, format = formats_date)) %>% - collect(), - times - ) - - compare_dplyr_binding( - .input %>% - mutate(x = strftime(datetime, format = formats, tz = "Pacific/Marquesas")) %>% - collect(), - times - ) - - compare_dplyr_binding( - .input %>% - mutate(x = strftime(datetime, format = formats, tz = "EST", usetz = TRUE)) %>% - collect(), - times - ) - - withr::with_timezone( - "Pacific/Marquesas", - { - compare_dplyr_binding( - .input %>% - mutate( - x = strftime(datetime, format = formats, tz = "EST"), - x_date = strftime(date, format = formats_date, tz = "EST") - ) %>% - collect(), - times - ) - - compare_dplyr_binding( - .input %>% - mutate( - x = strftime(datetime, format = formats), - x_date = strftime(date, format = formats_date) - ) %>% - collect(), - times - ) - } - ) - - # This check is due to differences in the way %c currently works in Arrow and R's strftime. - # We can revisit after https://github.com/HowardHinnant/date/issues/704 is resolved. - expect_error( - times %>% - Table$create() %>% - mutate(x = strftime(datetime, format = "%c")) %>% - collect(), - "%c flag is not supported in non-C locales." - ) - - # Output precision of %S depends on the input timestamp precision. - # Timestamps with second precision are represented as integers while - # milliseconds, microsecond and nanoseconds are represented as fixed floating - # point numbers with 3, 6 and 9 decimal places respectively. - compare_dplyr_binding( - .input %>% - mutate(x = strftime(datetime, format = "%S")) %>% - transmute(as.double(substr(x, 1, 2))) %>% - collect(), - times, - tolerance = 1e-6 - ) -}) - -test_that("format_ISO8601", { - skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-13168 - times <- tibble(x = c(lubridate::ymd_hms("2018-10-07 19:04:05", tz = "Etc/GMT+6"), NA)) - - compare_dplyr_binding( - .input %>% - mutate(x = format_ISO8601(x, precision = "ymd", usetz = FALSE)) %>% - collect(), - times - ) - - if (getRversion() < "3.5") { - # before 3.5, times$x will have no timezone attribute, so Arrow faithfully - # errors that there is no timezone to format: - expect_error( - times %>% - Table$create() %>% - mutate(x = format_ISO8601(x, precision = "ymd", usetz = TRUE)) %>% - collect(), - "Timezone not present, cannot convert to string with timezone: %Y-%m-%d%z" - ) - - # See comment regarding %S flag in strftime tests - expect_error( - times %>% - Table$create() %>% - mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = TRUE)) %>% - mutate(x = gsub("\\.0*", "", x)) %>% - collect(), - "Timezone not present, cannot convert to string with timezone: %Y-%m-%dT%H:%M:%S%z" - ) - } else { - compare_dplyr_binding( - .input %>% - mutate(x = format_ISO8601(x, precision = "ymd", usetz = TRUE)) %>% - collect(), - times - ) - - # See comment regarding %S flag in strftime tests - compare_dplyr_binding( - .input %>% - mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = TRUE)) %>% - mutate(x = gsub("\\.0*", "", x)) %>% - collect(), - times - ) - } - - - # See comment regarding %S flag in strftime tests - compare_dplyr_binding( - .input %>% - mutate(x = format_ISO8601(x, precision = "ymdhms", usetz = FALSE)) %>% - mutate(x = gsub("\\.0*", "", x)) %>% - collect(), - times - ) -}) - test_that("arrow_find_substring and arrow_find_substring_regex", { df <- tibble(x = c("Foo and Bar", "baz and qux and quux")) @@ -1163,18 +937,19 @@ test_that("substr", { ) expect_error( - nse_funcs$substr("Apache Arrow", c(1, 2), 3), + call_binding("substr", "Apache Arrow", c(1, 2), 3), "`start` must be length 1 - other lengths are not supported in Arrow" ) expect_error( - nse_funcs$substr("Apache Arrow", 1, c(2, 3)), + call_binding("substr", "Apache Arrow", 1, c(2, 3)), "`stop` must be length 1 - other lengths are not supported in Arrow" ) }) test_that("substring", { - # nse_funcs$substring just calls nse_funcs$substr, tested extensively above + # binding for substring just calls call_binding("substr", ...), + # tested extensively above df <- tibble(x = "Apache Arrow") compare_dplyr_binding( @@ -1259,12 +1034,12 @@ test_that("str_sub", { ) expect_error( - nse_funcs$str_sub("Apache Arrow", c(1, 2), 3), + call_binding("str_sub", "Apache Arrow", c(1, 2), 3), "`start` must be length 1 - other lengths are not supported in Arrow" ) expect_error( - nse_funcs$str_sub("Apache Arrow", 1, c(2, 3)), + call_binding("str_sub", "Apache Arrow", 1, c(2, 3)), "`end` must be length 1 - other lengths are not supported in Arrow" ) }) @@ -1393,7 +1168,7 @@ test_that("str_count", { df ) - # nse_funcs$str_count() is not vectorised over pattern + # call_binding("str_count", ) is not vectorised over pattern compare_dplyr_binding( .input %>% mutate(let_count = str_count(cities, pattern = c("a", "b", "e", "g", "p", "n", "s"))) %>% diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R new file mode 100644 index 00000000000..d96b4b2cf87 --- /dev/null +++ b/r/tests/testthat/test-dplyr-funcs.R @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +test_that("register_binding() works", { + fake_registry <- new.env(parent = emptyenv()) + fun1 <- function() NULL + + expect_null(register_binding("some_fun", fun1, fake_registry)) + expect_identical(fake_registry$some_fun, fun1) + + expect_identical(register_binding("some_fun", NULL, fake_registry), fun1) + expect_false("some_fun" %in% names(fake_registry)) + expect_silent(expect_null(register_binding("some_fun", NULL, fake_registry))) + + expect_null(register_binding("some_pkg::some_fun", fun1, fake_registry)) + expect_identical(fake_registry$some_fun, fun1) +}) + +test_that("register_binding_agg() works", { + fake_registry <- new.env(parent = emptyenv()) + fun1 <- function() NULL + + expect_null(register_binding_agg("some_fun", fun1, fake_registry)) + expect_identical(fake_registry$some_fun, fun1) +}) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 46a5e98c4c5..7b2bdc517bb 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -276,18 +276,18 @@ test_that("Functions that take ... but we only accept a single arg", { ) # Now that we've demonstrated that the whole machinery works, let's test # the agg_funcs directly - expect_error(agg_funcs$n_distinct(), "n_distinct() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$sum(), "sum() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$any(), "any() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$all(), "all() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$min(), "min() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$max(), "max() with 0 arguments", fixed = TRUE) - expect_error(agg_funcs$n_distinct(1, 2), "Multiple arguments to n_distinct()") - expect_error(agg_funcs$sum(1, 2), "Multiple arguments to sum") - expect_error(agg_funcs$any(1, 2), "Multiple arguments to any()") - expect_error(agg_funcs$all(1, 2), "Multiple arguments to all()") - expect_error(agg_funcs$min(1, 2), "Multiple arguments to min()") - expect_error(agg_funcs$max(1, 2), "Multiple arguments to max()") + expect_error(call_binding_agg("n_distinct"), "n_distinct() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("sum"), "sum() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("any"), "any() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("all"), "all() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("min"), "min() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("max"), "max() with 0 arguments", fixed = TRUE) + expect_error(call_binding_agg("n_distinct", 1, 2), "Multiple arguments to n_distinct()") + expect_error(call_binding_agg("sum", 1, 2), "Multiple arguments to sum") + expect_error(call_binding_agg("any", 1, 2), "Multiple arguments to any()") + expect_error(call_binding_agg("all", 1, 2), "Multiple arguments to all()") + expect_error(call_binding_agg("min", 1, 2), "Multiple arguments to min()") + expect_error(call_binding_agg("max", 1, 2), "Multiple arguments to max()") }) test_that("median()", {