diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 36a55c05b26..a5fb1ee9a4e 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -108,6 +108,7 @@ Collate: 'table.R' 'dplyr.R' 'duckdb.R' + 'extension.R' 'feather.R' 'field.R' 'filesystem.R' diff --git a/r/NAMESPACE b/r/NAMESPACE index 7cb89b0a53a..f8aece152c0 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -134,6 +134,8 @@ export(DictionaryArray) export(DirectoryPartitioning) export(DirectoryPartitioningFactory) export(Expression) +export(ExtensionArray) +export(ExtensionType) export(FeatherReader) export(Field) export(FileFormat) @@ -267,6 +269,8 @@ export(match_arrow) export(matches) export(mmap_create) export(mmap_open) +export(new_extension_array) +export(new_extension_type) export(null) export(num_range) export(one_of) @@ -282,6 +286,8 @@ export(read_parquet) export(read_schema) export(read_tsv_arrow) export(record_batch) +export(register_extension_type) +export(reregister_extension_type) export(s3_bucket) export(schema) export(set_cpu_count) @@ -300,8 +306,11 @@ export(uint32) export(uint64) export(uint8) export(unify_schemas) +export(unregister_extension_type) export(utf8) export(value_counts) +export(vctrs_extension_array) +export(vctrs_extension_type) export(write_arrow) export(write_csv_arrow) export(write_dataset) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 2fab03d08c3..427fda0165b 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -80,6 +80,11 @@ } } + if (arrow_available()) { + # register extension types that we use internally + reregister_extension_type(vctrs_extension_type(vctrs::unspecified())) + } + invisible() } diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 5ef6312196d..7bf77f1e66c 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1068,6 +1068,42 @@ compute___expr__type_id <- function(x, schema) { .Call(`_arrow_compute___expr__type_id`, x, schema) } +ExtensionType__initialize <- function(storage_type, extension_name, extension_metadata, r6_class) { + .Call(`_arrow_ExtensionType__initialize`, storage_type, extension_name, extension_metadata, r6_class) +} + +ExtensionType__extension_name <- function(type) { + .Call(`_arrow_ExtensionType__extension_name`, type) +} + +ExtensionType__Serialize <- function(type) { + .Call(`_arrow_ExtensionType__Serialize`, type) +} + +ExtensionType__storage_type <- function(type) { + .Call(`_arrow_ExtensionType__storage_type`, type) +} + +ExtensionType__MakeArray <- function(type, data) { + .Call(`_arrow_ExtensionType__MakeArray`, type, data) +} + +ExtensionType__r6_class <- function(type) { + .Call(`_arrow_ExtensionType__r6_class`, type) +} + +ExtensionArray__storage <- function(array) { + .Call(`_arrow_ExtensionArray__storage`, array) +} + +arrow__RegisterRExtensionType <- function(type) { + invisible(.Call(`_arrow_arrow__RegisterRExtensionType`, type)) +} + +arrow__UnregisterRExtensionType <- function(type_name) { + invisible(.Call(`_arrow_arrow__UnregisterRExtensionType`, type_name)) +} + ipc___WriteFeather__Table <- function(stream, table, version, chunk_size, compression, compression_level) { invisible(.Call(`_arrow_ipc___WriteFeather__Table`, stream, table, version, chunk_size, compression, compression_level)) } diff --git a/r/R/extension.R b/r/R/extension.R new file mode 100644 index 00000000000..111a0e86203 --- /dev/null +++ b/r/R/extension.R @@ -0,0 +1,545 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +#' @include arrow-package.R + + +#' @title class arrow::ExtensionArray +#' +#' @usage NULL +#' @format NULL +#' @docType class +#' +#' @section Methods: +#' +#' The `ExtensionArray` class inherits from `Array`, but also provides +#' access to the underlying storage of the extension. +#' +#' - `$storage()`: Returns the underlying [Array] used to store +#' values. +#' +#' The `ExtensionArray` is not intended to be subclassed for extension +#' types. +#' +#' @rdname ExtensionArray +#' @name ExtensionArray +#' @export +ExtensionArray <- R6Class("ExtensionArray", + inherit = Array, + public = list( + storage = function() { + ExtensionArray__storage(self) + }, + + as_vector = function() { + self$type$as_vector(self) + } + ) +) + +ExtensionArray$create <- function(x, type) { + assert_is(type, "ExtensionType") + if (inherits(x, "ExtensionArray") && type$Equals(x$type)) { + return(x) + } + + storage <- Array$create(x, type = type$storage_type()) + type$WrapArray(storage) +} + +#' @title class arrow::ExtensionType +#' +#' @usage NULL +#' @format NULL +#' @docType class +#' +#' @section Methods: +#' +#' The `ExtensionType` class inherits from `DataType`, but also defines +#' extra methods specific to extension types: +#' +#' - `$storage_type()`: Returns the underlying [DataType] used to store +#' values. +#' - `$storage_id()`: Returns the [Type] identifier corresponding to the +#' `$storage_type()`. +#' - `$extension_name()`: Returns the extension name. +#' - `$extension_metadata()`: Returns the serialized version of the extension +#' metadata as a [raw()] vector. +#' - `$extension_metadata_utf8()`: Returns the serialized version of the +#' extension metadata as a UTF-8 encoded string. +#' - `$WrapArray(array)`: Wraps a storage [Array] into an [ExtensionArray] +#' with this extension type. +#' +#' In addition, subclasses may override the following methos to customize +#' the behaviour of extension classes. +#' +#' - `$deserialize_instance()`: This method is called when a new [ExtensionType] +#' is initialized and is responsible for parsing and validating +#' the serialized extension_metadata (a [raw()] vector) +#' such that its contents can be inspected by fields and/or methods +#' of the R6 ExtensionType subclass. Implementations must also check the +#' `storage_type` to make sure it is compatible with the extension type. +#' - `$as_vector(extension_array)`: Convert an [Array] or [ChunkedArray] to an R +#' vector. This method is called by [as.vector()] on [ExtensionArray] +#' objects, when a [RecordBatch] containing an [ExtensionArray] is +#' converted to a [data.frame()], or when a [ChunkedArray] (e.g., a column +#' in a [Table]) is converted to an R vector. The default method returns the +#' converted storage array. +#' - `$ToString()` Return a string representation that will be printed +#' to the console when this type or an Array of this type is printed. +#' +#' @rdname ExtensionType +#' @name ExtensionType +#' @export +ExtensionType <- R6Class("ExtensionType", + inherit = DataType, + public = list( + + # In addition to the initialization that occurs for all + # ArrowObject instances, we call deserialize_instance(), which can + # be overridden to populate custom fields + initialize = function(xp) { + super$initialize(xp) + self$deserialize_instance() + }, + + # Because of how C++ shared_ptr<> objects are converted to R objects, + # the initial object that is instantiated will be of this class + # (ExtensionType), but the R6Class object that was registered is + # available from C++. We need this in order to produce the correct + # R6 subclass when a shared_ptr is returned to R. + r6_class = function() { + ExtensionType__r6_class(self) + }, + + storage_type = function() { + ExtensionType__storage_type(self) + }, + + storage_id = function() { + self$storage_type()$id + }, + + extension_name = function() { + ExtensionType__extension_name(self) + }, + + extension_metadata = function() { + ExtensionType__Serialize(self) + }, + + # To make sure this conversion is done properly + extension_metadata_utf8 = function() { + metadata_utf8 <- rawToChar(self$extension_metadata()) + Encoding(metadata_utf8) <- "UTF-8" + metadata_utf8 + }, + + WrapArray = function(array) { + assert_is(array, "Array") + ExtensionType__MakeArray(self, array$data()) + }, + + deserialize_instance = function() { + # Do nothing by default but allow other classes to override this method + # to populate R6 class members. + }, + + ExtensionEquals = function(other) { + inherits(other, "ExtensionType") && + identical(other$extension_name(), self$extension_name()) && + identical(other$extension_metadata(), self$extension_metadata()) + }, + + as_vector = function(extension_array) { + if (inherits(extension_array, "ChunkedArray")) { + # Converting one array at a time so that users don't have to remember + # to implement two methods. Converting all the storage arrays to + # a ChunkedArray and then converting is probably faster + # (VctrsExtensionType does this). + storage_vectors <- lapply( + seq_len(extension_array$num_chunks) - 1L, + function(i) self$as_vector(extension_array$chunk(i)) + ) + + vctrs::vec_c(!!! storage_vectors) + } else if (inherits(extension_array, "ExtensionArray")) { + extension_array$storage()$as_vector() + } else { + classes <- paste(class(extension_array), collapse = " / ") + abort( + c( + "`extension_array` must be a ChunkedArray or ExtensionArray", + i = glue::glue("Got object of type {classes}") + ) + ) + } + }, + + ToString = function() { + # metadata is probably valid UTF-8 (e.g., JSON), but might not be + # and it's confusing to error when printing the object. This herustic + # isn't perfect (but subclasses should override this method anyway) + metadata_raw <- self$extension_metadata() + + if (as.raw(0x00) %in% metadata_raw) { + if (length(metadata_raw) > 20) { + sprintf( + "<%s %s...>", + class(self)[1], + paste(format(utils::head(metadata_raw, 20)), collapse = " ") + ) + } else { + sprintf( + "<%s %s>", + class(self)[1], + paste(format(metadata_raw), collapse = " ") + ) + } + + } else { + paste0(class(self)[1], " <", self$extension_metadata_utf8(), ">") + } + } + ) +) + +# ExtensionType$new() is what gets used by the generated wrapper code to +# create an R6 object when a shared_ptr is returned to R and +# that object has type_id() EXTENSION_TYPE. Rather than add complexity +# to the wrapper code, we modify ExtensionType$new() to do what we need +# it to do here (which is to return an instance of a custom R6 +# type whose .deserialize_instance method is called to populate custom fields). +ExtensionType$.default_new <- ExtensionType$new +ExtensionType$new <- function(xp) { + super <- ExtensionType$.default_new(xp) + r6_class <- super$r6_class() + if (identical(r6_class$classname, "ExtensionType")) { + super + } else { + r6_class$new(xp) + } +} + +ExtensionType$create <- function(storage_type, + extension_name, + extension_metadata = raw(), + type_class = ExtensionType) { + if (is.string(extension_metadata)) { + extension_metadata <- charToRaw(enc2utf8(extension_metadata)) + } + + assert_that(is.string(extension_name), is.raw(extension_metadata)) + assert_is(storage_type, "DataType") + assert_is(type_class, "R6ClassGenerator") + + ExtensionType__initialize( + storage_type, + extension_name, + extension_metadata, + type_class + ) +} + +#' Extension types +#' +#' Extension arrays are wrappers around regular Arrow [Array] objects +#' that provide some customized behaviour and/or storage. A common use-case +#' for extension types is to define a customized conversion between an +#' an Arrow [Array] and an R object when the default conversion is slow +#' or looses metadata important to the interpretation of values in the array. +#' For most types, the built-in +#' [vctrs extension type][vctrs_extension_type] is probably sufficient. +#' +#' These functions create, register, and unregister [ExtensionType] +#' and [ExtensionArray] objects. To use an extension type you will have to: +#' +#' - Define an [R6::R6Class] that inherits from [ExtensionType] and reimplement +#' one or more methods (e.g., `deserialize_instance()`). +#' - Make a type constructor function (e.g., `my_extension_type()`) that calls +#' [new_extension_type()] to create an R6 instance that can be used as a +#' [data type][data-type] elsewhere in the package. +#' - Make an array constructor function (e.g., `my_extension_array()`) that +#' calls [new_extension_array()] to create an [Array] instance of your +#' extension type. +#' - Register a dummy instance of your extension type created using +#' you constructor function using [register_extension_type()]. +#' +#' If defining an extension type in an R package, you will probably want to +#' use [reregister_extension_type()] in that package's [.onLoad()] hook +#' since your package will probably get reloaded in the same R session +#' during its development and [register_extension_type()] will error if +#' called twice for the same `extension_name`. For an example of an +#' extension type that uses most of these features, see +#' [vctrs_extension_type()]. +#' +#' @param storage_type The [data type][data-type] of the underlying storage +#' array. +#' @param storage_array An [Array] object of the underlying storage. +#' @param extension_type An [ExtensionType] instance. +#' @param extension_name The extension name. This should be namespaced using +#' "dot" syntax (i.e., "some_package.some_type"). The namespace "arrow" +#' is reserved for extension types defined by the Apache Arrow libraries. +#' @param extension_metadata A [raw()] or [character()] vector containing the +#' serialized version of the type. Chatacter vectors must be length 1 and +#' are converted to UTF-8 before converting to [raw()]. +#' @param type_class An [R6::R6Class] whose `$new()` class method will be +#' used to construct a new instance of the type. +#' +#' @return +#' - `new_extension_type()` returns an [ExtensionType] instance according +#' to the `type_class` specified. +#' - `new_extension_array()` returns an [ExtensionArray] whose `$type` +#' corresponds to `extension_type`. +#' - `register_extension_type()`, `unregister_extension_type()` +#' and `reregister_extension_type()` return `NULL`, invisibly. +#' @export +#' +#' @examplesIf arrow_available() +#' # Create the R6 type whose methods control how Array objects are +#' # converted to R objects, how equality between types is computed, +#' # and how types are printed. +#' QuantizedType <- R6::R6Class( +#' "QuantizedType", +#' inherit = ExtensionType, +#' public = list( +#' # methods to access the custom metadata fields +#' center = function() private$.center, +#' scale = function() private$.scale, +#' +#' # called when an Array of this type is converted to an R vector +#' as_vector = function(extension_array) { +#' if (inherits(extension_array, "ExtensionArray")) { +#' unquantized_arrow <- +#' (extension_array$storage()$cast(float64()) / private$.scale) + +#' private$.center +#' +#' as.vector(unquantized_arrow) +#' } else { +#' super$as_vector(extension_array) +#' } +#' }, +#' +#' # populate the custom metadata fields from the serialized metadata +#' deserialize_instance = function() { +#' vals <- as.numeric(strsplit(self$extension_metadata_utf8(), ";")[[1]]) +#' private$.center <- vals[1] +#' private$.scale <- vals[2] +#' } +#' ), +#' +#' private = list( +#' .center = NULL, +#' .scale = NULL +#' ) +#' ) +#' +#' # Create a helper type constructor that calls new_extension_type() +#' quantized <- function(center = 0, scale = 1, storage_type = int32()) { +#' new_extension_type( +#' storage_type = storage_type, +#' extension_name = "arrow.example.quantized", +#' extension_metadata = paste(center, scale, sep = ";"), +#' type_class = QuantizedType +#' ) +#' } +#' +#' # Create a helper array constructor that calls new_extension_array() +#' quantized_array <- function(x, center = 0, scale = 1, +#' storage_type = int32()) { +#' type <- quantized(center, scale, storage_type) +#' new_extension_array( +#' Array$create((x - center) * scale, type = storage_type), +#' type +#' ) +#' } +#' +#' # Register the extension type so that Arrow knows what to do when +#' # it encounters this extension type +#' reregister_extension_type(quantized()) +#' +#' # Create Array objects and use them! +#' (vals <- runif(5, min = 19, max = 21)) +#' +#' (array <- quantized_array( +#' vals, +#' center = 20, +#' scale = 2 ^ 15 - 1, +#' storage_type = int16()) +#' ) +#' +#' array$type$center() +#' array$type$scale() +#' +#' as.vector(array) +new_extension_type <- function(storage_type, + extension_name, + extension_metadata = raw(), + type_class = ExtensionType) { + ExtensionType$create( + storage_type, + extension_name, + extension_metadata, + type_class + ) +} + +#' @rdname new_extension_type +#' @export +new_extension_array <- function(storage_array, extension_type) { + ExtensionArray$create(storage_array, extension_type) +} + +#' @rdname new_extension_type +#' @export +register_extension_type <- function(extension_type) { + assert_is(extension_type, "ExtensionType") + arrow__RegisterRExtensionType(extension_type) +} + +#' @rdname new_extension_type +#' @export +reregister_extension_type <- function(extension_type) { + tryCatch( + register_extension_type(extension_type), + error = function(e) { + unregister_extension_type(extension_type$extension_name()) + register_extension_type(extension_type) + } + ) +} + +#' @rdname new_extension_type +#' @export +unregister_extension_type <- function(extension_name) { + arrow__UnregisterRExtensionType(extension_name) +} + +VctrsExtensionType <- R6Class("VctrsExtensionType", + inherit = ExtensionType, + public = list( + ptype = function() { + private$.ptype + }, + + ToString = function() { + tf <- tempfile() + sink(tf) + on.exit({ + sink(NULL) + unlink(tf) + }) + print(self$ptype()) + paste0(readLines(tf), collapse = "\n") + }, + + deserialize_instance = function() { + private$.ptype <- unserialize(self$extension_metadata()) + }, + + ExtensionEquals = function(other) { + if (!inherits(other, "VctrsExtensionType")) { + return(FALSE) + } + + identical(self$ptype(), other$ptype()) + }, + + as_vector = function(extension_array) { + if (inherits(extension_array, "ChunkedArray")) { + # rather than convert one array at a time, use more Arrow + # machinery to convert the whole ChunkedArray at once + storage_arrays <- lapply( + seq_len(extension_array$num_chunks) - 1L, + function(i) extension_array$chunk(i)$storage() + ) + storage <- chunked_array(!!! storage_arrays, type = self$storage_type()) + + vctrs::vec_restore(storage$as_vector(), self$ptype()) + } else if (inherits(extension_array, "Array")) { + vctrs::vec_restore( + super$as_vector(extension_array), + self$ptype() + ) + } else { + super$as_vector(extension_array) + } + } + ), + private = list( + .ptype = NULL + ) +) + + +#' Extension type for generic typed vectors +#' +#' Most common R vector types are converted automatically to a suitable +#' Arrow [data type][data-type] without the need for an extension type. For +#' vector types whose conversion is not suitably handled by default, you can +#' create a [vctrs_extension_array()], which passes [vctrs::vec_data()] to +#' `Array$create()` and calls [vctrs::vec_restore()] when the [Array] is +#' converted back into an R vector. +#' +#' @param x A vctr (i.e., [vctrs::vec_is()] returns `TRUE`). +#' @param ptype A [vctrs::vec_ptype()], which is usually a zero-length +#' version of the object with the appropriate attributes set. This value +#' will be serialized using [serialize()], so it should not refer to any +#' R object that can't be saved/reloaded. +#' @inheritParams new_extension_type +#' +#' @return +#' - `vctrs_extension_array()` returns an [ExtensionArray] instance with a +#' `vctrs_extension_type()`. +#' - `vctrs_extension_type()` returns an [ExtensionType] instance for the +#' extension name "arrow.r.vctrs". +#' @export +#' +#' @examplesIf arrow_available() +#' (array <- vctrs_extension_array(as.POSIXlt("2022-01-02 03:45", tz = "UTC"))) +#' array$type +#' as.vector(array) +#' +#' temp_feather <- tempfile() +#' write_feather(arrow_table(col = array), temp_feather) +#' read_feather(temp_feather) +#' unlink(temp_feather) +vctrs_extension_array <- function(x, ptype = vctrs::vec_ptype(x), + storage_type = NULL) { + if (inherits(x, "ExtensionArray") && inherits(x$type, "VctrsExtensionType")) { + return(x) + } + + vctrs::vec_assert(x) + storage <- Array$create(vctrs::vec_data(x), type = storage_type) + type <- vctrs_extension_type(ptype, storage$type) + new_extension_array(storage, type) +} + +#' @rdname vctrs_extension_array +#' @export +vctrs_extension_type <- function(ptype, + storage_type = type(vctrs::vec_data(ptype))) { + ptype <- vctrs::vec_ptype(ptype) + + new_extension_type( + storage_type = storage_type, + extension_name = "arrow.r.vctrs", + extension_metadata = serialize(ptype, NULL), + type_class = VctrsExtensionType + ) +} diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index fcb7b2016ac..c3810cdf099 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -144,6 +144,8 @@ reference: - buffer - read_message - concat_arrays + - ExtensionArray + - vctrs_extension_array - title: Arrow data types and schema contents: - Schema @@ -156,6 +158,9 @@ reference: - DataType - DictionaryType - FixedWidthType + - new_extension_type + - vctrs_extension_type + - ExtensionType - title: Flight contents: - load_flight_server diff --git a/r/man/ExtensionArray.Rd b/r/man/ExtensionArray.Rd new file mode 100644 index 00000000000..84a63c9bb94 --- /dev/null +++ b/r/man/ExtensionArray.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extension.R +\docType{class} +\name{ExtensionArray} +\alias{ExtensionArray} +\title{class arrow::ExtensionArray} +\description{ +class arrow::ExtensionArray +} +\section{Methods}{ + + +The \code{ExtensionArray} class inherits from \code{Array}, but also provides +access to the underlying storage of the extension. +\itemize{ +\item \verb{$storage()}: Returns the underlying \link{Array} used to store +values. +} + +The \code{ExtensionArray} is not intended to be subclassed for extension +types. +} + diff --git a/r/man/ExtensionType.Rd b/r/man/ExtensionType.Rd new file mode 100644 index 00000000000..6b05f3490d2 --- /dev/null +++ b/r/man/ExtensionType.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extension.R +\docType{class} +\name{ExtensionType} +\alias{ExtensionType} +\title{class arrow::ExtensionType} +\description{ +class arrow::ExtensionType +} +\section{Methods}{ + + +The \code{ExtensionType} class inherits from \code{DataType}, but also defines +extra methods specific to extension types: +\itemize{ +\item \verb{$storage_type()}: Returns the underlying \link{DataType} used to store +values. +\item \verb{$storage_id()}: Returns the \link{Type} identifier corresponding to the +\verb{$storage_type()}. +\item \verb{$extension_name()}: Returns the extension name. +\item \verb{$extension_metadata()}: Returns the serialized version of the extension +metadata as a \code{\link[=raw]{raw()}} vector. +\item \verb{$extension_metadata_utf8()}: Returns the serialized version of the +extension metadata as a UTF-8 encoded string. +\item \verb{$WrapArray(array)}: Wraps a storage \link{Array} into an \link{ExtensionArray} +with this extension type. +} + +In addition, subclasses may override the following methos to customize +the behaviour of extension classes. +\itemize{ +\item \verb{$deserialize_instance()}: This method is called when a new \link{ExtensionType} +is initialized and is responsible for parsing and validating +the serialized extension_metadata (a \code{\link[=raw]{raw()}} vector) +such that its contents can be inspected by fields and/or methods +of the R6 ExtensionType subclass. Implementations must also check the +\code{storage_type} to make sure it is compatible with the extension type. +\item \verb{$as_vector(extension_array)}: Convert an \link{Array} or \link{ChunkedArray} to an R +vector. This method is called by \code{\link[=as.vector]{as.vector()}} on \link{ExtensionArray} +objects, when a \link{RecordBatch} containing an \link{ExtensionArray} is +converted to a \code{\link[=data.frame]{data.frame()}}, or when a \link{ChunkedArray} (e.g., a column +in a \link{Table}) is converted to an R vector. The default method returns the +converted storage array. +\item \verb{$ToString()} Return a string representation that will be printed +to the console when this type or an Array of this type is printed. +} +} + diff --git a/r/man/new_extension_type.Rd b/r/man/new_extension_type.Rd new file mode 100644 index 00000000000..96d5c10c935 --- /dev/null +++ b/r/man/new_extension_type.Rd @@ -0,0 +1,167 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extension.R +\name{new_extension_type} +\alias{new_extension_type} +\alias{new_extension_array} +\alias{register_extension_type} +\alias{reregister_extension_type} +\alias{unregister_extension_type} +\title{Extension types} +\usage{ +new_extension_type( + storage_type, + extension_name, + extension_metadata = raw(), + type_class = ExtensionType +) + +new_extension_array(storage_array, extension_type) + +register_extension_type(extension_type) + +reregister_extension_type(extension_type) + +unregister_extension_type(extension_name) +} +\arguments{ +\item{storage_type}{The \link[=data-type]{data type} of the underlying storage +array.} + +\item{extension_name}{The extension name. This should be namespaced using +"dot" syntax (i.e., "some_package.some_type"). The namespace "arrow" +is reserved for extension types defined by the Apache Arrow libraries.} + +\item{extension_metadata}{A \code{\link[=raw]{raw()}} or \code{\link[=character]{character()}} vector containing the +serialized version of the type. Chatacter vectors must be length 1 and +are converted to UTF-8 before converting to \code{\link[=raw]{raw()}}.} + +\item{type_class}{An \link[R6:R6Class]{R6::R6Class} whose \verb{$new()} class method will be +used to construct a new instance of the type.} + +\item{storage_array}{An \link{Array} object of the underlying storage.} + +\item{extension_type}{An \link{ExtensionType} instance.} +} +\value{ +\itemize{ +\item \code{new_extension_type()} returns an \link{ExtensionType} instance according +to the \code{type_class} specified. +\item \code{new_extension_array()} returns an \link{ExtensionArray} whose \verb{$type} +corresponds to \code{extension_type}. +\item \code{register_extension_type()}, \code{unregister_extension_type()} +and \code{reregister_extension_type()} return \code{NULL}, invisibly. +} +} +\description{ +Extension arrays are wrappers around regular Arrow \link{Array} objects +that provide some customized behaviour and/or storage. A common use-case +for extension types is to define a customized conversion between an +an Arrow \link{Array} and an R object when the default conversion is slow +or looses metadata important to the interpretation of values in the array. +For most types, the built-in +\link[=vctrs_extension_type]{vctrs extension type} is probably sufficient. +} +\details{ +These functions create, register, and unregister \link{ExtensionType} +and \link{ExtensionArray} objects. To use an extension type you will have to: +\itemize{ +\item Define an \link[R6:R6Class]{R6::R6Class} that inherits from \link{ExtensionType} and reimplement +one or more methods (e.g., \code{deserialize_instance()}). +\item Make a type constructor function (e.g., \code{my_extension_type()}) that calls +\code{\link[=new_extension_type]{new_extension_type()}} to create an R6 instance that can be used as a +\link[=data-type]{data type} elsewhere in the package. +\item Make an array constructor function (e.g., \code{my_extension_array()}) that +calls \code{\link[=new_extension_array]{new_extension_array()}} to create an \link{Array} instance of your +extension type. +\item Register a dummy instance of your extension type created using +you constructor function using \code{\link[=register_extension_type]{register_extension_type()}}. +} + +If defining an extension type in an R package, you will probably want to +use \code{\link[=reregister_extension_type]{reregister_extension_type()}} in that package's \code{\link[=.onLoad]{.onLoad()}} hook +since your package will probably get reloaded in the same R session +during its development and \code{\link[=register_extension_type]{register_extension_type()}} will error if +called twice for the same \code{extension_name}. For an example of an +extension type that uses most of these features, see +\code{\link[=vctrs_extension_type]{vctrs_extension_type()}}. +} +\examples{ +\dontshow{if (arrow_available()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +# Create the R6 type whose methods control how Array objects are +# converted to R objects, how equality between types is computed, +# and how types are printed. +QuantizedType <- R6::R6Class( + "QuantizedType", + inherit = ExtensionType, + public = list( + # methods to access the custom metadata fields + center = function() private$.center, + scale = function() private$.scale, + + # called when an Array of this type is converted to an R vector + as_vector = function(extension_array) { + if (inherits(extension_array, "ExtensionArray")) { + unquantized_arrow <- + (extension_array$storage()$cast(float64()) / private$.scale) + + private$.center + + as.vector(unquantized_arrow) + } else { + super$as_vector(extension_array) + } + }, + + # populate the custom metadata fields from the serialized metadata + deserialize_instance = function() { + vals <- as.numeric(strsplit(self$extension_metadata_utf8(), ";")[[1]]) + private$.center <- vals[1] + private$.scale <- vals[2] + } + ), + + private = list( + .center = NULL, + .scale = NULL + ) +) + +# Create a helper type constructor that calls new_extension_type() +quantized <- function(center = 0, scale = 1, storage_type = int32()) { + new_extension_type( + storage_type = storage_type, + extension_name = "arrow.example.quantized", + extension_metadata = paste(center, scale, sep = ";"), + type_class = QuantizedType + ) +} + +# Create a helper array constructor that calls new_extension_array() +quantized_array <- function(x, center = 0, scale = 1, + storage_type = int32()) { + type <- quantized(center, scale, storage_type) + new_extension_array( + Array$create((x - center) * scale, type = storage_type), + type + ) +} + +# Register the extension type so that Arrow knows what to do when +# it encounters this extension type +reregister_extension_type(quantized()) + +# Create Array objects and use them! +(vals <- runif(5, min = 19, max = 21)) + +(array <- quantized_array( + vals, + center = 20, + scale = 2 ^ 15 - 1, + storage_type = int16()) +) + +array$type$center() +array$type$scale() + +as.vector(array) +\dontshow{\}) # examplesIf} +} diff --git a/r/man/vctrs_extension_array.Rd b/r/man/vctrs_extension_array.Rd new file mode 100644 index 00000000000..b80ce48dc2a --- /dev/null +++ b/r/man/vctrs_extension_array.Rd @@ -0,0 +1,50 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extension.R +\name{vctrs_extension_array} +\alias{vctrs_extension_array} +\alias{vctrs_extension_type} +\title{Extension type for generic typed vectors} +\usage{ +vctrs_extension_array(x, ptype = vctrs::vec_ptype(x), storage_type = NULL) + +vctrs_extension_type(ptype, storage_type = type(vctrs::vec_data(ptype))) +} +\arguments{ +\item{x}{A vctr (i.e., \code{\link[vctrs:vec_assert]{vctrs::vec_is()}} returns \code{TRUE}).} + +\item{ptype}{A \code{\link[vctrs:vec_ptype]{vctrs::vec_ptype()}}, which is usually a zero-length +version of the object with the appropriate attributes set. This value +will be serialized using \code{\link[=serialize]{serialize()}}, so it should not refer to any +R object that can't be saved/reloaded.} + +\item{storage_type}{The \link[=data-type]{data type} of the underlying storage +array.} +} +\value{ +\itemize{ +\item \code{vctrs_extension_array()} returns an \link{ExtensionArray} instance with a +\code{vctrs_extension_type()}. +\item \code{vctrs_extension_type()} returns an \link{ExtensionType} instance for the +extension name "arrow.r.vctrs". +} +} +\description{ +Most common R vector types are converted automatically to a suitable +Arrow \link[=data-type]{data type} without the need for an extension type. For +vector types whose conversion is not suitably handled by default, you can +create a \code{\link[=vctrs_extension_array]{vctrs_extension_array()}}, which passes \code{\link[vctrs:vec_data]{vctrs::vec_data()}} to +\code{Array$create()} and calls \code{\link[vctrs:vec_proxy]{vctrs::vec_restore()}} when the \link{Array} is +converted back into an R vector. +} +\examples{ +\dontshow{if (arrow_available()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +(array <- vctrs_extension_array(as.POSIXlt("2022-01-02 03:45", tz = "UTC"))) +array$type +as.vector(array) + +temp_feather <- tempfile() +write_feather(arrow_table(col = array), temp_feather) +read_feather(temp_feather) +unlink(temp_feather) +\dontshow{\}) # examplesIf} +} diff --git a/r/src/array.cpp b/r/src/array.cpp index 8fcc96e0d42..16490bbaeca 100644 --- a/r/src/array.cpp +++ b/r/src/array.cpp @@ -41,6 +41,8 @@ const char* r6_class_name::get(const std::shared_ptr return "FixedSizeListArray"; case arrow::Type::MAP: return "MapArray"; + case arrow::Type::EXTENSION: + return "ExtensionArray"; default: return "Array"; diff --git a/r/src/array_to_vector.cpp b/r/src/array_to_vector.cpp index 06d0a87a101..b89738d6c65 100644 --- a/r/src/array_to_vector.cpp +++ b/r/src/array_to_vector.cpp @@ -29,6 +29,7 @@ #include #include +#include "./extension.h" #include "./r_task_group.h" namespace arrow { @@ -1154,6 +1155,35 @@ class Converter_Null : public Converter { } }; +// Unlike other types, conversion of ExtensionType (chunked) arrays occurs at +// R level via the ExtensionType (or subclass) R6 instance. We do this via Allocate, +// since it is called once per ChunkedArray. +class Converter_Extension : public Converter { + public: + explicit Converter_Extension(const std::shared_ptr& chunked_array) + : Converter(chunked_array) {} + + SEXP Allocate(R_xlen_t n) const { + auto extension_type = + dynamic_cast(chunked_array_->type().get()); + if (extension_type == nullptr) { + Rf_error("Converter_Extension can't be used with a non-R extension type"); + } + + return extension_type->Convert(chunked_array_); + } + + // At this point we have already done the conversion + Status Ingest_all_nulls(SEXP data, R_xlen_t start, R_xlen_t n) const { + return Status::OK(); + } + + Status Ingest_some_nulls(SEXP data, const std::shared_ptr& array, + R_xlen_t start, R_xlen_t n, size_t chunk_index) const { + return Status::OK(); + } +}; + bool ArraysCanFitInteger(ArrayVector arrays) { bool all_can_fit = true; auto i32 = arrow::int32(); @@ -1316,6 +1346,9 @@ std::shared_ptr Converter::Make( case Type::NA: return std::make_shared(chunked_array); + case Type::EXTENSION: + return std::make_shared(chunked_array); + default: break; } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 0a29ed0872d..c4271a19aaf 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -4167,6 +4167,147 @@ extern "C" SEXP _arrow_compute___expr__type_id(SEXP x_sexp, SEXP schema_sexp){ } #endif +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +cpp11::environment ExtensionType__initialize(const std::shared_ptr& storage_type, std::string extension_name, cpp11::raws extension_metadata, cpp11::environment r6_class); +extern "C" SEXP _arrow_ExtensionType__initialize(SEXP storage_type_sexp, SEXP extension_name_sexp, SEXP extension_metadata_sexp, SEXP r6_class_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type storage_type(storage_type_sexp); + arrow::r::Input::type extension_name(extension_name_sexp); + arrow::r::Input::type extension_metadata(extension_metadata_sexp); + arrow::r::Input::type r6_class(r6_class_sexp); + return cpp11::as_sexp(ExtensionType__initialize(storage_type, extension_name, extension_metadata, r6_class)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__initialize(SEXP storage_type_sexp, SEXP extension_name_sexp, SEXP extension_metadata_sexp, SEXP r6_class_sexp){ + Rf_error("Cannot call ExtensionType__initialize(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +std::string ExtensionType__extension_name(const std::shared_ptr& type); +extern "C" SEXP _arrow_ExtensionType__extension_name(SEXP type_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + return cpp11::as_sexp(ExtensionType__extension_name(type)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__extension_name(SEXP type_sexp){ + Rf_error("Cannot call ExtensionType__extension_name(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +cpp11::raws ExtensionType__Serialize(const std::shared_ptr& type); +extern "C" SEXP _arrow_ExtensionType__Serialize(SEXP type_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + return cpp11::as_sexp(ExtensionType__Serialize(type)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__Serialize(SEXP type_sexp){ + Rf_error("Cannot call ExtensionType__Serialize(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExtensionType__storage_type(const std::shared_ptr& type); +extern "C" SEXP _arrow_ExtensionType__storage_type(SEXP type_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + return cpp11::as_sexp(ExtensionType__storage_type(type)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__storage_type(SEXP type_sexp){ + Rf_error("Cannot call ExtensionType__storage_type(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExtensionType__MakeArray(const std::shared_ptr& type, const std::shared_ptr& data); +extern "C" SEXP _arrow_ExtensionType__MakeArray(SEXP type_sexp, SEXP data_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + arrow::r::Input&>::type data(data_sexp); + return cpp11::as_sexp(ExtensionType__MakeArray(type, data)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__MakeArray(SEXP type_sexp, SEXP data_sexp){ + Rf_error("Cannot call ExtensionType__MakeArray(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +cpp11::environment ExtensionType__r6_class(const std::shared_ptr& type); +extern "C" SEXP _arrow_ExtensionType__r6_class(SEXP type_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + return cpp11::as_sexp(ExtensionType__r6_class(type)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionType__r6_class(SEXP type_sexp){ + Rf_error("Cannot call ExtensionType__r6_class(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr ExtensionArray__storage(const std::shared_ptr& array); +extern "C" SEXP _arrow_ExtensionArray__storage(SEXP array_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type array(array_sexp); + return cpp11::as_sexp(ExtensionArray__storage(array)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_ExtensionArray__storage(SEXP array_sexp){ + Rf_error("Cannot call ExtensionArray__storage(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +void arrow__RegisterRExtensionType(const std::shared_ptr& type); +extern "C" SEXP _arrow_arrow__RegisterRExtensionType(SEXP type_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type type(type_sexp); + arrow__RegisterRExtensionType(type); + return R_NilValue; +END_CPP11 +} +#else +extern "C" SEXP _arrow_arrow__RegisterRExtensionType(SEXP type_sexp){ + Rf_error("Cannot call arrow__RegisterRExtensionType(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// extension-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +void arrow__UnregisterRExtensionType(std::string type_name); +extern "C" SEXP _arrow_arrow__UnregisterRExtensionType(SEXP type_name_sexp){ +BEGIN_CPP11 + arrow::r::Input::type type_name(type_name_sexp); + arrow__UnregisterRExtensionType(type_name); + return R_NilValue; +END_CPP11 +} +#else +extern "C" SEXP _arrow_arrow__UnregisterRExtensionType(SEXP type_name_sexp){ + Rf_error("Cannot call arrow__UnregisterRExtensionType(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + // feather.cpp #if defined(ARROW_R_WITH_ARROW) void ipc___WriteFeather__Table(const std::shared_ptr& stream, const std::shared_ptr& table, int version, int chunk_size, arrow::Compression::type compression, int compression_level); @@ -8011,6 +8152,15 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, { "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, { "_arrow_compute___expr__type_id", (DL_FUNC) &_arrow_compute___expr__type_id, 2}, + { "_arrow_ExtensionType__initialize", (DL_FUNC) &_arrow_ExtensionType__initialize, 4}, + { "_arrow_ExtensionType__extension_name", (DL_FUNC) &_arrow_ExtensionType__extension_name, 1}, + { "_arrow_ExtensionType__Serialize", (DL_FUNC) &_arrow_ExtensionType__Serialize, 1}, + { "_arrow_ExtensionType__storage_type", (DL_FUNC) &_arrow_ExtensionType__storage_type, 1}, + { "_arrow_ExtensionType__MakeArray", (DL_FUNC) &_arrow_ExtensionType__MakeArray, 2}, + { "_arrow_ExtensionType__r6_class", (DL_FUNC) &_arrow_ExtensionType__r6_class, 1}, + { "_arrow_ExtensionArray__storage", (DL_FUNC) &_arrow_ExtensionArray__storage, 1}, + { "_arrow_arrow__RegisterRExtensionType", (DL_FUNC) &_arrow_arrow__RegisterRExtensionType, 1}, + { "_arrow_arrow__UnregisterRExtensionType", (DL_FUNC) &_arrow_arrow__UnregisterRExtensionType, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, { "_arrow_ipc___feather___Reader__version", (DL_FUNC) &_arrow_ipc___feather___Reader__version, 1}, { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 2}, diff --git a/r/src/datatype.cpp b/r/src/datatype.cpp index fd083f66d41..68b6c8fada5 100644 --- a/r/src/datatype.cpp +++ b/r/src/datatype.cpp @@ -101,6 +101,8 @@ const char* r6_class_name::get( return "StructType"; case Type::DICTIONARY: return "DictionaryType"; + case Type::EXTENSION: + return "ExtensionType"; default: break; diff --git a/r/src/extension-impl.cpp b/r/src/extension-impl.cpp new file mode 100644 index 00000000000..57c4874c973 --- /dev/null +++ b/r/src/extension-impl.cpp @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./arrow_types.h" + +#if defined(ARROW_R_WITH_ARROW) + +#include + +#include +#include +#include + +#include "./extension.h" +#include "./safe-call-into-r.h" + +bool RExtensionType::ExtensionEquals(const arrow::ExtensionType& other) const { + // Avoid materializing the R6 instance if at all possible + if (other.extension_name() != extension_name()) { + return false; + } + + if (other.Serialize() == Serialize()) { + return true; + } + + // With any ambiguity, we need to materialize the R6 instance and call its + // ExtensionEquals method. We can't do this on the non-R thread. + // After ARROW-15841, we can use SafeCallIntoR. + arrow::Result result = SafeCallIntoR([&]() { + cpp11::environment instance = r6_instance(); + cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]); + + std::shared_ptr other_shared = + ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize())); + cpp11::sexp other_r6 = cpp11::to_r6(other_shared, "ExtensionType"); + + cpp11::logicals result(instance_ExtensionEquals(other_r6)); + return cpp11::as_cpp(result); + }); + + if (!result.ok()) { + throw std::runtime_error(result.status().message()); + } + + return result.ValueUnsafe(); +} + +std::shared_ptr RExtensionType::MakeArray( + std::shared_ptr data) const { + std::shared_ptr new_data = data->Copy(); + std::unique_ptr cloned = Clone(); + new_data->type = std::shared_ptr(cloned.release()); + return std::make_shared(new_data); +} + +arrow::Result> RExtensionType::Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const { + std::unique_ptr cloned = Clone(); + cloned->storage_type_ = storage_type; + cloned->extension_metadata_ = serialized_data; + + // We could create an ephemeral R6 instance here, which will call the R6 instance's + // deserialize_instance() method, possibly erroring when the metadata is + // invalid or the deserialized values are invalid. The complexity of setting up + // an event loop from wherever this *might* be called is high and hard to + // predict. As a compromise, just create the instance when it is safe to + // do so. + if (GetMainRThread().IsMainThread()) { + r6_instance(); + } + + return std::shared_ptr(cloned.release()); +} + +std::string RExtensionType::ToString() const { + arrow::Result result = SafeCallIntoR([&]() { + cpp11::environment instance = r6_instance(); + cpp11::function instance_ToString(instance["ToString"]); + cpp11::sexp result = instance_ToString(); + return cpp11::as_cpp(result); + }); + + // In the event of an error (e.g., we are not on the main thread + // and we are not inside RunWithCapturedR()), just call the default method + if (!result.ok()) { + return ExtensionType::ToString(); + } else { + return result.ValueUnsafe(); + } +} + +cpp11::sexp RExtensionType::Convert( + const std::shared_ptr& array) const { + cpp11::environment instance = r6_instance(); + cpp11::function instance_Convert(instance["as_vector"]); + cpp11::sexp array_sexp = cpp11::to_r6(array, "ChunkedArray"); + return instance_Convert(array_sexp); +} + +std::unique_ptr RExtensionType::Clone() const { + RExtensionType* ptr = + new RExtensionType(storage_type(), extension_name_, extension_metadata_, r6_class_); + return std::unique_ptr(ptr); +} + +cpp11::environment RExtensionType::r6_instance( + std::shared_ptr storage_type, + const std::string& serialized_data) const { + // This is a version of to_r6<>() that is a more direct route to creating the object. + // This is done to avoid circular calls, since to_r6<>() has to go through + // ExtensionType$new(), which then calls back to C++ to get r6_class_ to then + // return the correct subclass. + std::unique_ptr cloned = Clone(); + cpp11::external_pointer> xp( + new std::shared_ptr(cloned.release())); + + cpp11::function r6_class_new(r6_class()["new"]); + return r6_class_new(xp); +} + +// [[arrow::export]] +cpp11::environment ExtensionType__initialize( + const std::shared_ptr& storage_type, std::string extension_name, + cpp11::raws extension_metadata, cpp11::environment r6_class) { + std::string metadata_string(extension_metadata.begin(), extension_metadata.end()); + auto r6_class_shared = std::make_shared(r6_class); + RExtensionType cpp_type(storage_type, extension_name, metadata_string, r6_class_shared); + return cpp_type.r6_instance(); +} + +// [[arrow::export]] +std::string ExtensionType__extension_name( + const std::shared_ptr& type) { + return type->extension_name(); +} + +// [[arrow::export]] +cpp11::raws ExtensionType__Serialize(const std::shared_ptr& type) { + std::string serialized_string = type->Serialize(); + cpp11::writable::raws bytes(serialized_string.begin(), serialized_string.end()); + return bytes; +} + +// [[arrow::export]] +std::shared_ptr ExtensionType__storage_type( + const std::shared_ptr& type) { + return type->storage_type(); +} + +// [[arrow::export]] +std::shared_ptr ExtensionType__MakeArray( + const std::shared_ptr& type, + const std::shared_ptr& data) { + return type->MakeArray(data); +} + +// [[arrow::export]] +cpp11::environment ExtensionType__r6_class( + const std::shared_ptr& type) { + auto r_type = + arrow::internal::checked_pointer_cast(type); + return r_type->r6_class(); +} + +// [[arrow::export]] +std::shared_ptr ExtensionArray__storage( + const std::shared_ptr& array) { + return array->storage(); +} + +// [[arrow::export]] +void arrow__RegisterRExtensionType(const std::shared_ptr& type) { + auto ext_type = std::dynamic_pointer_cast(type); + StopIfNotOk(arrow::RegisterExtensionType(ext_type)); +} + +// [[arrow::export]] +void arrow__UnregisterRExtensionType(std::string type_name) { + StopIfNotOk(arrow::UnregisterExtensionType(type_name)); +} + +#endif diff --git a/r/src/extension.h b/r/src/extension.h new file mode 100644 index 00000000000..fbd3ad48469 --- /dev/null +++ b/r/src/extension.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// or more contributor license agreements. See the NOTICE file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./arrow_types.h" + +#include +#include +#include + +// A wrapper around arrow::ExtensionType that allows R to register extension +// types whose Deserialize, ExtensionEquals, and Serialize methods are +// in meaningfully handled at the R level. At the C++ level, the type is +// already serialized to minimize calls to R from C++. +// +// Using a std::shared_ptr<> to wrap a cpp11::sexp type is unusual, but we +// need it here to avoid calling the copy constructor from another thread, +// since this might call into the R API. If we don't do this, we get crashes +// when reading a multi-file Dataset. +class RExtensionType : public arrow::ExtensionType { + public: + RExtensionType(const std::shared_ptr storage_type, + std::string extension_name, std::string extension_metadata, + std::shared_ptr r6_class) + : arrow::ExtensionType(storage_type), + extension_name_(extension_name), + extension_metadata_(extension_metadata), + r6_class_(r6_class) {} + + std::string extension_name() const { return extension_name_; } + + bool ExtensionEquals(const arrow::ExtensionType& other) const; + + std::shared_ptr MakeArray(std::shared_ptr data) const; + + arrow::Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const; + + std::string Serialize() const { return extension_metadata_; } + + std::string ToString() const; + + cpp11::sexp Convert(const std::shared_ptr& array) const; + + std::unique_ptr Clone() const; + + cpp11::environment r6_class() const { return *r6_class_; } + + cpp11::environment r6_instance(std::shared_ptr storage_type, + const std::string& serialized_data) const; + + cpp11::environment r6_instance() const { + return r6_instance(storage_type(), Serialize()); + } + + private: + std::string extension_name_; + std::string extension_metadata_; + std::string cached_to_string_; + std::shared_ptr r6_class_; +}; diff --git a/r/tests/testthat/_snaps/extension.md b/r/tests/testthat/_snaps/extension.md new file mode 100644 index 00000000000..4335958b8ae --- /dev/null +++ b/r/tests/testthat/_snaps/extension.md @@ -0,0 +1,10 @@ +# extension types can be created + + `extension_array` must be a ChunkedArray or ExtensionArray + i Got object of type character + +# vctrs extension type works + + `extension_array` must be a ChunkedArray or ExtensionArray + i Got object of type character + diff --git a/r/tests/testthat/test-extension.R b/r/tests/testthat/test-extension.R new file mode 100644 index 00000000000..cf82b2f1f26 --- /dev/null +++ b/r/tests/testthat/test-extension.R @@ -0,0 +1,345 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +test_that("extension types can be created", { + type <- new_extension_type( + int32(), + "arrow_r.simple_extension", + charToRaw("some custom metadata"), + ) + + expect_r6_class(type, "ExtensionType") + expect_identical(type$extension_name(), "arrow_r.simple_extension") + expect_true(type$storage_type() == int32()) + expect_identical(type$storage_id(), int32()$id) + expect_identical(type$extension_metadata(), charToRaw("some custom metadata")) + expect_identical(type$ToString(), "ExtensionType ") + + storage <- Array$create(1:10) + array <- type$WrapArray(storage) + expect_r6_class(array, "ExtensionArray") + expect_r6_class(array$type, "ExtensionType") + + expect_true(array$type == type) + expect_true(all(array$storage() == storage)) + + expect_identical(array$as_vector(), 1:10) + expect_identical(chunked_array(array)$as_vector(), 1:10) + + expect_snapshot_error( + type$as_vector("not an extension array or chunked array") + ) +}) + +test_that("extension type subclasses work", { + SomeExtensionTypeSubclass <- R6Class( + "SomeExtensionTypeSubclass", inherit = ExtensionType, + public = list( + some_custom_method = function() { + private$some_custom_field + }, + + deserialize_instance = function() { + private$some_custom_field <- head(self$extension_metadata(), 5) + } + ), + private = list( + some_custom_field = NULL + ) + ) + + type <- new_extension_type( + int32(), + "some_extension_subclass", + charToRaw("some custom metadata"), + type_class = SomeExtensionTypeSubclass + ) + + expect_r6_class(type, "SomeExtensionTypeSubclass") + expect_identical(type$some_custom_method(), charToRaw("some ")) + + register_extension_type(type) + + # create a new type instance with storage/metadata not identical + # to the registered type + type2 <- new_extension_type( + float64(), + "some_extension_subclass", + charToRaw("some other custom metadata"), + type_class = SomeExtensionTypeSubclass + ) + + ptr_type <- allocate_arrow_schema() + type2$export_to_c(ptr_type) + type3 <- DataType$import_from_c(ptr_type) + delete_arrow_schema(ptr_type) + + expect_identical(type3$extension_name(), "some_extension_subclass") + expect_identical(type3$some_custom_method(), type2$some_custom_method()) + expect_identical(type3$extension_metadata(), type2$extension_metadata()) + expect_true(type3$storage_type() == type2$storage_type()) + + array <- type3$WrapArray(Array$create(1:10)) + expect_r6_class(array, "ExtensionArray") + + unregister_extension_type("some_extension_subclass") +}) + +test_that("extension types can use UTF-8 for metadata", { + type <- new_extension_type( + int32(), + "arrow.test.simple_extension", + "\U0001f4a9\U0001f4a9\U0001f4a9\U0001f4a9" + ) + + expect_identical( + type$extension_metadata_utf8(), + "\U0001f4a9\U0001f4a9\U0001f4a9\U0001f4a9" + ) + + expect_match(type$ToString(), "\U0001f4a9", fixed = TRUE) +}) + +test_that("extension types can be printed that don't use UTF-8 for metadata", { + type <- new_extension_type( + int32(), + "arrow.test.simple_extension", + as.raw(0:5) + ) + + expect_match(type$ToString(), "00 01 02 03 04 05") +}) + +test_that("extension subclasses can override the ExtensionEquals method", { + SomeExtensionTypeSubclass <- R6Class( + "SomeExtensionTypeSubclass", inherit = ExtensionType, + public = list( + field_values = NULL, + + deserialize_instance = function() { + self$field_values <- unserialize(self$extension_metadata()) + }, + + ExtensionEquals = function(other) { + if (!inherits(other, "SomeExtensionTypeSubclass")) { + return(FALSE) + } + + setequal(names(other$field_values), names(self$field_values)) && + identical( + other$field_values[names(self$field_values)], + self$field_values + ) + } + ) + ) + + type <- new_extension_type( + int32(), + "some_extension_subclass", + serialize(list(field1 = "value1", field2 = "value2"), NULL), + type_class = SomeExtensionTypeSubclass + ) + + register_extension_type(type) + + expect_true(type$ExtensionEquals(type)) + expect_true(type$Equals(type)) + + type2 <- new_extension_type( + int32(), + "some_extension_subclass", + serialize(list(field2 = "value2", field1 = "value1"), NULL), + type_class = SomeExtensionTypeSubclass + ) + + expect_true(type$ExtensionEquals(type2)) + expect_true(type$Equals(type2)) + + unregister_extension_type("some_extension_subclass") +}) + +test_that("vctrs extension type works", { + custom_vctr <- vctrs::new_vctr( + 1:4, + attr_key = "attr_val", + class = "arrow_custom_test" + ) + + type <- vctrs_extension_type(custom_vctr) + expect_r6_class(type, "VctrsExtensionType") + expect_identical(type$ptype(), vctrs::vec_ptype(custom_vctr)) + expect_true(type$Equals(type)) + expect_match(type$ToString(), "arrow_custom_test") + + array_in <- vctrs_extension_array(custom_vctr) + expect_true(array_in$type$Equals(type)) + expect_identical(vctrs_extension_array(array_in), array_in) + + tf <- tempfile() + on.exit(unlink(tf)) + write_feather(arrow_table(col = array_in), tf) + table_out <- read_feather(tf, as_data_frame = FALSE) + array_out <- table_out$col$chunk(0) + + expect_r6_class(array_out$type, "VctrsExtensionType") + expect_r6_class(array_out, "ExtensionArray") + + expect_true(array_out$type$Equals(type)) + expect_identical( + array_out$as_vector(), + custom_vctr + ) + + chunked_array_out <- table_out$col + expect_true(chunked_array_out$type$Equals(type)) + expect_identical( + chunked_array_out$as_vector(), + custom_vctr + ) + + expect_snapshot_error( + type$as_vector("not an extension array or chunked array") + ) +}) + +test_that("chunked arrays can roundtrip extension types", { + custom_vctr1 <- vctrs::new_vctr(1:4, class = "arrow_custom_test") + custom_vctr2 <- vctrs::new_vctr(5:8, class = "arrow_custom_test") + custom_array1 <- vctrs_extension_array(custom_vctr1) + custom_array2 <- vctrs_extension_array(custom_vctr2) + + custom_chunked <- chunked_array(custom_array1, custom_array2) + expect_r6_class(custom_chunked$type, "VctrsExtensionType") + expect_identical( + custom_chunked$as_vector(), + vctrs::new_vctr(1:8, class = "arrow_custom_test") + ) +}) + +test_that("RecordBatch can roundtrip extension types", { + custom_vctr <- vctrs::new_vctr(1:8, class = "arrow_custom_test") + custom_array <- vctrs_extension_array(custom_vctr) + normal_vctr <- letters[1:8] + + custom_record_batch <- record_batch(custom = custom_array) + expect_identical( + custom_record_batch$to_data_frame(), + tibble::tibble( + custom = custom_vctr + ) + ) + + mixed_record_batch <- record_batch( + custom = custom_array, + normal = normal_vctr + ) + expect_identical( + mixed_record_batch$to_data_frame(), + tibble::tibble( + custom = custom_vctr, + normal = normal_vctr + ) + ) + + # check both column orders, since column order should stay in the same + # order whether the colunns are are extension types or not + mixed_record_batch2 <- record_batch( + normal = normal_vctr, + custom = custom_array + ) + expect_identical( + mixed_record_batch2$to_data_frame(), + tibble::tibble( + normal = normal_vctr, + custom = custom_vctr + ) + ) +}) + +test_that("Table can roundtrip extension types", { + custom_vctr <- vctrs::new_vctr(1:8, class = "arrow_custom_test") + custom_array <- vctrs_extension_array(custom_vctr) + normal_vctr <- letters[1:8] + + custom_table <- arrow_table(custom = custom_array) + expect_identical( + custom_table$to_data_frame(), + tibble::tibble( + custom = custom_vctr + ) + ) + + mixed_table <- arrow_table( + custom = custom_array, + normal = normal_vctr + ) + expect_identical( + mixed_table$to_data_frame(), + tibble::tibble( + custom = custom_vctr, + normal = normal_vctr + ) + ) + + # check both column orders, since column order should stay in the same + # order whether the colunns are are extension types or not + mixed_table2 <- arrow_table( + normal = normal_vctr, + custom = custom_array + ) + expect_identical( + mixed_table2$to_data_frame(), + tibble::tibble( + normal = normal_vctr, + custom = custom_vctr + ) + ) +}) + +test_that("Dataset/arrow_dplyr_query can roundtrip extension types", { + skip_if_not_available("dataset") + + tf <- tempfile() + on.exit(unlink(tf, recursive = TRUE)) + + df <- expand.grid( + number = 1:10, + letter = letters, + stringsAsFactors = FALSE, + KEEP.OUT.ATTRS = FALSE + ) %>% + tibble::as_tibble() + + df$extension <- vctrs::new_vctr(df$letter, class = "arrow_custom_vctr") + + table <- arrow_table( + number = df$number, + letter = df$letter, + extension = vctrs_extension_array(df$extension) + ) + + table %>% + dplyr::group_by(number) %>% + write_dataset(tf) + + roundtripped <- open_dataset(tf) %>% + dplyr::select(number, letter, extension) %>% + dplyr::collect() + + expect_identical(unclass(roundtripped$extension), roundtripped$letter) +})