diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7ed2e36d59531..2f7b876f0ec33 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -199,9 +199,13 @@ exportMethods("%<=>%", "approx_count_distinct", "approxCountDistinct", "approxQuantile", + "array_aggregate", "array_contains", "array_distinct", "array_except", + "array_exists", + "array_filter", + "array_forall", "array_intersect", "array_join", "array_max", @@ -210,9 +214,11 @@ exportMethods("%<=>%", "array_remove", "array_repeat", "array_sort", + "array_transform", "arrays_overlap", "array_union", "arrays_zip", + "arrays_zip_with", "asc", "ascii", "asin", @@ -314,10 +320,12 @@ exportMethods("%<=>%", "ltrim", "map_concat", "map_entries", + "map_filter", "map_from_arrays", "map_from_entries", "map_keys", "map_values", + "map_zip_with", "max", "md5", "mean", @@ -396,6 +404,8 @@ exportMethods("%<=>%", "to_timestamp", "to_utc_timestamp", "translate", + "transform_keys", + "transform_values", "trim", "trunc", "unbase64", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 48f69d5769620..0ecf688a636d1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -219,6 +219,34 @@ NULL #' the DDL-formatted string literal can also be accepted. #' \item \code{from_csv}: a structType object, DDL-formatted string or \code{schema_of_csv} #' } +#' +#' @param f a \code{function} mapping from \code{Column(s)} to \code{Column}. +#' \itemize{ +#' \item \code{array_exists} +#' \item \code{array_filter} the Boolean \code{function} used to filter the data. +#' Either unary or binary. In the latter case the second argument +#' is the index in the array (0-based). +#' \item \code{array_forall} the Boolean unary \code{function} used to filter the data. +#' \item \code{array_transform} a \code{function} used to transform the data. +#' Either unary or binary. In the latter case the second argument +#' is the index in the array (0-based). +#' \item \code{arrays_zip_with} +#' \item \code{map_zip_with} +#' \item \code{map_filter} the Boolean binary \code{function} used to filter the data. +#' The first argument is the key, the second argument is the value. +#' \item \code{transform_keys} a binary \code{function} +#' used to transform the data. The first argument is the key, the second argument +#' is the value. +#' \item \code{transform_values} a binary \code{function} +#' used to transform the data. The first argument is the key, the second argument +#' is the value. +#' } +#' @param zero a \code{Column} used as the initial value in \code{array_aggregate} +#' @param merge a \code{function} a binary function \code{(Column, Column) -> Column} +#' used in \code{array_aggregate}to merge values (the second argument) +#' into accumulator (the first argument). +#' @param finish an unary \code{function} \code{(Column) -> Column} used to +#' apply final transformation on the accumulated data in \code{array_aggregate}. #' @param ... additional argument(s). #' \itemize{ #' \item \code{to_json}, \code{from_json} and \code{schema_of_json}: this contains @@ -244,6 +272,14 @@ NULL #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21))) +#' head(select(tmp, array_transform("v1", function(x) x * 10))) +#' head(select(tmp, array_exists("v1", function(x) x > 120))) +#' head(select(tmp, array_forall("v1", function(x) x >= 8.0))) +#' head(select(tmp, array_filter("v1", function(x) x < 10))) +#' head(select(tmp, array_aggregate("v1", lit(0), function(acc, y) acc + y))) +#' head(select( +#' tmp, +#' array_aggregate("v1", lit(0), function(acc, y) acc + y, function(acc) acc / 10))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -253,17 +289,22 @@ NULL #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"), map_concat(tmp3$v3, tmp3$v3))) +#' head(select(tmp3, transform_keys("v3", function(k, v) upper(k)))) +#' head(select(tmp3, transform_values("v3", function(k, v) v * 10))) +#' head(select(tmp3, map_filter("v3", function(k, v) v < 42))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) #' head(select(tmp4, array_except(tmp4$v4, tmp4$v5), array_intersect(tmp4$v4, tmp4$v5))) #' head(select(tmp4, array_union(tmp4$v4, tmp4$v5))) #' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' head(select(tmp4, arrays_zip_with(tmp4$v4, tmp4$v5, function(x, y) x * y))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL"))) #' tmp6 <- mutate(df, v7 = create_array(create_array(df$model, df$model))) #' head(select(tmp6, flatten(tmp6$v7))) #' tmp7 <- mutate(df, v8 = create_array(df$model, df$cyl), v9 = create_array(df$model, df$hp)) +#' head(select(tmp7, arrays_zip_with("v8", "v9", function(x, y) (x * y) %% 3))) #' head(select(tmp7, map_from_arrays(tmp7$v8, tmp7$v9))) #' tmp8 <- mutate(df, v10 = create_array(struct(df$model, df$cyl))) #' head(select(tmp8, map_from_entries(tmp8$v10)))} @@ -3281,6 +3322,121 @@ setMethod("row_number", ###################### Collection functions###################### +#' Create o.a.s.sql.expressions.UnresolvedNamedLambdaVariable, +#' convert it to o.s.sql.Column and wrap with R Column. +#' Used by higher order functions. +#' +#' @param ... character of length = 1 +#' if length(...) > 1 then argument is interpreted as a nested +#' Column, for example \code{unresolved_named_lambda_var("a", "b", "c")} +#' yields unresolved \code{a.b.c} +#' @return Column object wrapping JVM UnresolvedNamedLambdaVariable +unresolved_named_lambda_var <- function(...) { + jc <- newJObject( + "org.apache.spark.sql.Column", + newJObject( + "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", + list(...) + ) + ) + column(jc) +} + +#' Create o.a.s.sql.expressions.LambdaFunction corresponding +#' to transformation described by func. +#' Used by higher order functions. +#' +#' @param fun R \code{function} (unary, binary or ternary) +#' that transforms \code{Columns} into a \code{Column} +#' @return JVM \code{LambdaFunction} object +create_lambda <- function(fun) { + as_jexpr <- function(x) callJMethod(x@jc, "expr") + + # Process function arguments + parameters <- formals(fun) + nparameters <- length(parameters) + + stopifnot( + nparameters >= 1 & + nparameters <= 3 & + !"..." %in% names(parameters) + ) + + args <- lapply(c("x", "y", "z")[seq_along(parameters)], function(p) { + unresolved_named_lambda_var(p) + }) + + # Invoke function and validate return type + result <- do.call(fun, args) + stopifnot(class(result) == "Column") + + # Convert both Columns to Scala expressions + jexpr <- as_jexpr(result) + + jargs <- handledCallJStatic( + "org.apache.spark.api.python.PythonUtils", + "toSeq", + handledCallJStatic( + "java.util.Arrays", "asList", lapply(args, as_jexpr) + ) + ) + + # Create Scala LambdaFunction + newJObject( + "org.apache.spark.sql.catalyst.expressions.LambdaFunction", + jexpr, + jargs, + FALSE + ) +} + +#' Invokes higher order function expression identified by name, +#' (relative to o.a.s.sql.catalyst.expressions) +#' +#' @param name character +#' @param cols list of character or Column objects +#' @param funs list of named list(fun = ..., expected_narg = ...) +#' @return a \code{Column} representing name applied to cols with funs +invoke_higher_order_function <- function(name, cols, funs) { + as_jexpr <- function(x) { + if (class(x) == "character") { + x <- column(x) + } + callJMethod(x@jc, "expr") + } + + jexpr <- do.call(newJObject, c( + paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."), + lapply(cols, as_jexpr), + lapply(funs, create_lambda) + )) + + column(newJObject("org.apache.spark.sql.Column", jexpr)) +} + +#' @details +#' \code{array_aggregate} Applies a binary operator to an initial state +#' and all elements in the array, and reduces this to a single state. +#' The final state is converted into the final result by applying +#' a finish function. +#' +#' @rdname column_collection_functions +#' @aliases array_aggregate array_aggregate,characterOrColumn,Column,function-method +#' @note array_aggregate since 3.1.0 +setMethod("array_aggregate", + signature(x = "characterOrColumn", zero = "Column", merge = "function"), + function(x, zero, merge, finish = NULL) { + invoke_higher_order_function( + "ArrayAggregate", + cols = list(x, zero), + funs = if (is.null(finish)) { + list(merge) + } else { + list(merge, finish) + } + ) + }) + #' @details #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. @@ -3322,6 +3478,54 @@ setMethod("array_except", column(jc) }) +#' @details +#' \code{array_exists} Returns whether a predicate holds for one or more elements in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_exists array_exists,characterOrColumn,function-method +#' @note array_exists since 3.1.0 +setMethod("array_exists", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "ArrayExists", + cols = list(x), + funs = list(f) + ) + }) + +#' @details +#' \code{array_filter} Returns an array of elements for which a predicate holds in a given array. +#' +#' @rdname column_collection_functions +#' @aliases array_filter array_filter,characterOrColumn,function-method +#' @note array_filter since 3.1.0 +setMethod("array_filter", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "ArrayFilter", + cols = list(x), + funs = list(f) + ) + }) + +#' @details +#' \code{array_forall} Returns whether a predicate holds for every element in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_forall array_forall,characterOrColumn,function-method +#' @note array_forall since 3.1.0 +setMethod("array_forall", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "ArrayForAll", + cols = list(x), + funs = list(f) + ) + }) + #' @details #' \code{array_intersect}: Returns an array of the elements in the intersection of the given two #' arrays, without duplicates. @@ -3446,6 +3650,23 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{array_transform} Returns an array of elements after applying +#' a transformation to each element in the input array. +#' +#' @rdname column_collection_functions +#' @aliases array_transform array_transform,characterOrColumn,characterOrColumn,function-method +#' @note array_transform since 3.1.0 +setMethod("array_transform", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "ArrayTransform", + cols = list(x), + funs = list(f) + ) + }) + #' @details #' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in #' common. If not and both arrays are non-empty and any of them contains a null, it returns null. @@ -3493,6 +3714,24 @@ setMethod("arrays_zip", column(jc) }) +#' @details +#' \code{arrays_zip_with} Merge two given arrays, element-wise, into a single array +#' using a function. If one array is shorter, nulls are appended at the end +#' to match the length of the longer array, before applying the function. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip_with arrays_zip_with,characterOrColumn,characterOrColumn,function-method +#' @note zip_with since 3.1.0 +setMethod("arrays_zip_with", + signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), + function(x, y, f) { + invoke_higher_order_function( + "ZipWith", + cols = list(x, y), + funs = list(f) + ) + }) + #' @details #' \code{shuffle}: Returns a random permutation of the given array. #' @@ -3550,6 +3789,21 @@ setMethod("map_entries", column(jc) }) +#' @details +#' \code{map_filter} Returns a map whose key-value pairs satisfy a predicate. +#' +#' @rdname column_collection_functions +#' @aliases map_filter map_filter,characterOrColumn,function-method +#' @note map_filter since 3.1.0 +setMethod("map_filter", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "MapFilter", + cols = list(x), + funs = list(f)) + }) + #' @details #' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for #' keys. The array in the second column is used for values. All elements in the array for key @@ -3591,6 +3845,41 @@ setMethod("map_keys", column(jc) }) +#' @details +#' \code{transform_keys} Applies a function to every key-value pair in a map and returns +#' a map with the results of those applications as the new keys for the pairs. +#' +#' @rdname column_collection_functions +#' @aliases transform_keys transform_keys,characterOrColumn,function-method +#' @note transform_keys since 3.1.0 +setMethod("transform_keys", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "TransformKeys", + cols = list(x), + funs = list(f) + ) + }) + +#' @details +#' \code{transform_values} Applies a function to every key-value pair in a map and returns +#' a map with the results of those applications as the new values for the pairs. +#' +#' @rdname column_collection_functions +#' @aliases transform_values transform_values,characterOrColumn,function-method +#' @note transform_values since 3.1.0 +setMethod("transform_values", + signature(x = "characterOrColumn", f = "function"), + function(x, f) { + invoke_higher_order_function( + "TransformValues", + cols = list(x), + funs = list(f) + ) + }) + + #' @details #' \code{map_values}: Returns an unordered array containing the values of the map. #' @@ -3604,6 +3893,24 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{map_zip} Merge two given maps, key-wise into a single map using a function. +#' +#' @rdname column_collection_functions +#' @aliases map_zip_with map_zip_with,characterOrColumn,characterOrColumn,function-method +#' +#' @examples +#' @note map_zip_with since 3.1.0 +setMethod("map_zip_with", + signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), + function(x, y, f) { + invoke_higher_order_function( + "MapZipWith", + cols = list(x, y), + funs = list(f) + ) + }) + #' @details #' \code{element_at}: Returns element of array at given index in \code{extraction} if #' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4134d5cecc888..a52ec7a4a27c1 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approx_count_distinct", function(x, ...) { standardGeneric("approx_c #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_aggregate", function(x, zero, merge, ...) { standardGeneric("array_aggregate") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) @@ -769,6 +773,18 @@ setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) #' @name NULL setGeneric("array_except", function(x, y) { standardGeneric("array_except") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_exists", function(x, f) { standardGeneric("array_exists") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_forall", function(x, f) { standardGeneric("array_forall") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_filter", function(x, f) { standardGeneric("array_filter") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_intersect", function(x, y) { standardGeneric("array_intersect") }) @@ -801,6 +817,10 @@ setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_transform", function(x, f) { standardGeneric("array_transform") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) @@ -813,6 +833,10 @@ setGeneric("array_union", function(x, y) { standardGeneric("array_union") }) #' @name NULL setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip_with", function(x, y, f) { standardGeneric("arrays_zip_with") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1086,6 +1110,10 @@ setGeneric("map_concat", function(x, ...) { standardGeneric("map_concat") }) #' @name NULL setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_filter", function(x, f) { standardGeneric("map_filter") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) @@ -1102,6 +1130,10 @@ setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_zip_with", function(x, y, f) { standardGeneric("map_zip_with") }) + #' @rdname column_misc_functions #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) @@ -1314,6 +1346,14 @@ setGeneric("substring_index", function(x, delim, count) { standardGeneric("subst #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("transform_keys", function(x, f) { standardGeneric("transform_keys") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("transform_values", function(x, f) { standardGeneric("transform_values") }) + #' @rdname column_math_functions #' @name NULL setGeneric("degrees", function(x) { standardGeneric("degrees") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3b3768f7e2715..ff310757bfafb 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1993,6 +1993,70 @@ test_that("when(), otherwise() and ifelse() with column on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) }) +test_that("higher order functions", { + df <- select( + createDataFrame(data.frame(id = 1)), + expr("CAST(array(1.0, 2.0, -3.0, -4.0) AS array) xs"), + expr("CAST(array(0.0, 3.0, 48.0) AS array) ys"), + expr("array('FAILED', 'SUCCEDED') as vs"), + expr("map('foo', 1, 'bar', 2) as mx"), + expr("map('foo', 42, 'bar', -1, 'baz', 0) as my") + ) + + map_to_sorted_array <- function(x) { + sort_array(arrays_zip(map_keys(x), map_values(x))) + } + + result <- collect(select( + df, + array_transform("xs", function(x) x + 1) == expr("transform(xs, x -> x + 1)"), + array_transform("xs", function(x, i) otherwise(when(i %% 2 == 0, x), -x)) == + expr("transform(xs, (x, i) -> CASE WHEN ((i % 2.0) = 0.0) THEN x ELSE (- x) END)"), + array_exists("vs", function(v) rlike(v, "FAILED")) == + expr("exists(vs, v -> (v RLIKE 'FAILED'))"), + array_forall("xs", function(x) x > 0) == + expr("forall(xs, x -> x > 0)"), + array_filter("xs", function(x, i) x > 0 | i %% 2 == 0) == + expr("filter(xs, (x, i) -> x > 0 OR i % 2 == 0)"), + array_filter("xs", function(x) signum(x) > 0) == + expr("filter(xs, x -> signum(x) > 0)"), + array_aggregate("xs", lit(0.0), function(x, y) otherwise(when(x > y, x), y)) == + expr("aggregate(xs, CAST(0.0 AS double), (x, y) -> CASE WHEN x > y THEN x ELSE y END)"), + array_aggregate( + "xs", + struct( + alias(lit(0.0), "count"), + alias(lit(0.0), "sum") + ), + function(acc, x) { + count <- getItem(acc, "count") + sum <- getItem(acc, "sum") + struct(alias(count + 1.0, "count"), alias(sum + x, "sum")) + }, + function(acc) getItem(acc, "sum") / getItem(acc, "count") + ) == expr(paste0( + "aggregate(xs, struct(CAST(0.0 AS double) count, CAST(0.0 AS double) sum), ", + "(acc, x) -> ", + "struct(cast(acc.count + 1.0 AS double) count, CAST(acc.sum + x AS double) sum), ", + "acc -> acc.sum / acc.count)" + )), + arrays_zip_with("xs", "ys", function(x, y) x + y) == + expr("zip_with(xs, ys, (x, y) -> x + y)"), + map_to_sorted_array(transform_keys("mx", function(k, v) upper(k))) == + map_to_sorted_array(expr("transform_keys(mx, (k, v) -> upper(k))")), + map_to_sorted_array(transform_values("mx", function(k, v) v * 2)) == + map_to_sorted_array(expr("transform_values(mx, (k, v) -> v * 2)")), + map_to_sorted_array(map_filter(column("my"), function(k, v) lower(v) != "foo")) == + map_to_sorted_array(expr("map_filter(my, (k, v) -> lower(v) != 'foo')")), + map_to_sorted_array(map_zip_with("mx", "my", function(k, vx, vy) vx * vy)) == + map_to_sorted_array(expr("map_zip_with(mx, my, (k, vx, vy) -> vx * vy)")) + )) + + expect_true(all(unlist(result))) + + expect_error(array_transform("xs", function(...) 42)) +}) + test_that("group by, agg functions", { df <- read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum")