diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4ea05b25ecc9e..25162f3e23b38 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -427,6 +427,7 @@ exportMethods("%<=>%",
               "variance",
               "var_pop",
               "var_samp",
+              "vector_to_array",
               "weekofyear",
               "when",
               "window",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index b216f404a3ca5..61ea90efb348d 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -345,6 +345,17 @@ NULL
 #' head(tmp)}
 NULL
 
+#' ML functions for Column operations
+#'
+#' ML functions defined for \code{Column}.
+#'
+#' @param x Column to compute on.
+#' @param ... additional argument(s).
+#' @name column_ml_functions
+#' @rdname column_ml_functions
+#' @family ml functions
+NULL
+
 #' @details
 #' \code{lit}: A new Column is created to represent the literal value.
 #' If the parameter is a Column, it is returned unchanged.
@@ -4458,3 +4469,26 @@ setMethod("timestamp_seconds",
             )
             column(jc)
           })
+
+#' @details
+#' \code{vector_to_array}: Converts a column of MLlib sparse/dense vectors into
+#' a column of dense arrays.
+#'
+#' @param dtype The data type of the output array. Valid values: "float64" (default) or "float32".
+#'
+#' @rdname column_ml_functions
+#' @aliases vector_to_array vector_to_array,Column-method
+#' @note vector_to_array since 3.1.0
+setMethod("vector_to_array",
+          signature(x = "Column"),
+          function(x, dtype = c("float64", "float32")) {
+            # match.arg() picks the first choice when dtype is omitted; keep "float64" first.
+            dtype <- match.arg(dtype)
+            jc <- callJStatic(
+              "org.apache.spark.ml.functions",
+              "vector_to_array",
+              x@jc,
+              dtype
+            )
+            column(jc)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 985678679dec8..993fc758adbe5 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1449,6 +1449,10 @@ setGeneric("var_pop", function(x) { standardGeneric("var_pop") })
 #' @name NULL
 setGeneric("var_samp", function(x) { standardGeneric("var_samp") })
 
+#' @rdname column_ml_functions
+#' @name NULL
+setGeneric("vector_to_array", function(x, ...) { standardGeneric("vector_to_array") })
+
 #' @rdname column_datetime_functions
 #' @name NULL
 setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index c36620227593d..c3b271b1205c5 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1424,7 +1424,8 @@ test_that("column functions", {
     date_trunc("quarter", c) + current_date() + current_timestamp()
   c25 <- overlay(c1, c2, c3, c3) + overlay(c1, c2, c3) + overlay(c1, c2, 1) +
     overlay(c1, c2, 3, 4)
-  c26 <- timestamp_seconds(c1)
+  c26 <- timestamp_seconds(c1) + vector_to_array(c) +
+    vector_to_array(c, "float32") + vector_to_array(c, "float64")
   c27 <- nth_value("x", 1L) + nth_value("y", 2, TRUE) +
     nth_value(column("v"), 3) + nth_value(column("z"), 4L, FALSE)
 