Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ S3method(as_record_batch,arrow_dplyr_query)
S3method(as_record_batch,data.frame)
S3method(as_record_batch,pyarrow.lib.RecordBatch)
S3method(as_record_batch,pyarrow.lib.Table)
S3method(as_record_batch_reader,"function")
S3method(as_record_batch_reader,Dataset)
S3method(as_record_batch_reader,RecordBatch)
S3method(as_record_batch_reader,RecordBatchReader)
Expand Down
4 changes: 4 additions & 0 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 54 additions & 14 deletions r/R/dataset-scan.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ tail_from_batches <- function(batches, n) {
#' @param FUN A function or `purrr`-style lambda expression to apply to each
#' batch. It must return a RecordBatch or something coercible to one via
#' `as_record_batch()'.
#' @param .schema An optional [schema()]. If NULL, the schema will be inferred
#' from the first batch.
#' @param .lazy Use `TRUE` to evaluate `FUN` lazily as batches are read from
#' the result; use `FALSE` to evaluate `FUN` on all batches before returning
#' the reader.
#' @param ... Additional arguments passed to `FUN`
#' @param .data.frame Deprecated argument, ignored
#' @return An `arrow_dplyr_query`.
#' @export
map_batches <- function(X, FUN, ..., .data.frame = NULL) {
map_batches <- function(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL) {
if (!is.null(.data.frame)) {
warning(
"The .data.frame argument is deprecated. ",
Expand All @@ -197,25 +202,60 @@ map_batches <- function(X, FUN, ..., .data.frame = NULL) {
}
FUN <- as_mapper(FUN)
reader <- as_record_batch_reader(X)
dots <- rlang::list2(...)

# TODO: for future consideration
# * Move eval to C++ and make it a generator so it can stream, not block
# * Accept an output schema argument: with that, we could make this lazy (via collapse)
batch <- reader$read_next_batch()
res <- vector("list", 1024)
i <- 0L
while (!is.null(batch)) {
i <- i + 1L
res[[i]] <- as_record_batch(FUN(batch, ...))
# If no schema is supplied, we have to evaluate the first batch here
if (is.null(.schema)) {
batch <- reader$read_next_batch()
if (is.null(batch)) {
abort("Can't infer schema from a RecordBatchReader with zero batches")
}

first_result <- as_record_batch(do.call(FUN, c(list(batch), dots)))
.schema <- first_result$schema
fun <- function() {
if (!is.null(first_result)) {
result <- first_result
first_result <<- NULL
result
} else {
batch <- reader$read_next_batch()
if (is.null(batch)) {
NULL
} else {
as_record_batch(
do.call(FUN, c(list(batch), dots)),
schema = .schema
)
}
}
}
} else {
fun <- function() {
batch <- reader$read_next_batch()
if (is.null(batch)) {
return(NULL)
}

as_record_batch(
do.call(FUN, c(list(batch), dots)),
schema = .schema
)
}
}

# Trim list back
if (i < length(res)) {
res <- res[seq_len(i)]
reader_out <- as_record_batch_reader(fun, schema = .schema)

# TODO(ARROW-17178) because there are some restrictions on evaluating
# reader_out in some ExecPlans, the default .lazy is FALSE for now.
if (!.lazy) {
reader_out <- RecordBatchReader$create(
batches = reader_out$batches(),
schema = .schema
)
}

RecordBatchReader$create(batches = res)
reader_out
}

#' @usage NULL
Expand Down
9 changes: 9 additions & 0 deletions r/R/record-batch-reader.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ RecordBatchFileReader$create <- function(file) {
#' Convert an object to an Arrow RecordBatchReader
#'
#' @param x An object to convert to a [RecordBatchReader]
#' @param schema The [schema()] that must match the schema returned by each
#' call to `x` when `x` is a function.
#' @param ... Passed to S3 methods
#'
#' @return A [RecordBatchReader]
Expand Down Expand Up @@ -234,6 +236,13 @@ as_record_batch_reader.Dataset <- function(x, ...) {
Scanner$create(x)$ToRecordBatchReader()
}

#' @rdname as_record_batch_reader
#' @export
as_record_batch_reader.function <- function(x, ..., schema) {
assert_that(inherits(schema, "Schema"))
RecordBatchReader__from_function(x, schema)
}

#' @rdname as_record_batch_reader
#' @export
as_record_batch_reader.arrow_dplyr_query <- function(x, ...) {
Expand Down
6 changes: 6 additions & 0 deletions r/man/as_record_batch_reader.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion r/man/map_batches.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions r/src/recordbatchreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "./arrow_types.h"
#include "./safe-call-into-r.h"

#include <arrow/ipc/reader.h>
#include <arrow/table.h>
Expand Down Expand Up @@ -54,6 +55,50 @@ std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_batches(
}
}

class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
public:
RFunctionRecordBatchReader(cpp11::sexp fun,
const std::shared_ptr<arrow::Schema>& schema)
: fun_(fun), schema_(schema) {}

std::shared_ptr<arrow::Schema> schema() const { return schema_; }

arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
cpp11::sexp result_sexp = fun_();
if (result_sexp == R_NilValue) {
return std::shared_ptr<arrow::RecordBatch>(nullptr);
} else if (!Rf_inherits(result_sexp, "RecordBatch")) {
cpp11::stop("Expected fun() to return an arrow::RecordBatch");
}

return cpp11::as_cpp<std::shared_ptr<arrow::RecordBatch>>(result_sexp);
});

RETURN_NOT_OK(batch);

if (batch.ValueUnsafe().get() != nullptr &&
!batch.ValueUnsafe()->schema()->Equals(schema_)) {
return arrow::Status::Invalid("Expected fun() to return batch with schema '",
schema_->ToString(), "' but got batch with schema '",
batch.ValueUnsafe()->schema()->ToString(), "'");
}

*batch_out = batch.ValueUnsafe();
return arrow::Status::OK();
}

private:
cpp11::function fun_;
std::shared_ptr<arrow::Schema> schema_;
};

// [[arrow::export]]
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(
cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema) {
return std::make_shared<RFunctionRecordBatchReader>(fun_sexp, schema);
}

// [[arrow::export]]
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(
const std::shared_ptr<arrow::Table>& table) {
Expand Down
2 changes: 1 addition & 1 deletion r/src/safe-call-into-r-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool CanRunWithCapturedR() {
on_old_windows = on_old_windows_fun();
}

return !on_old_windows;
return !on_old_windows && GetMainRThread().Executor() == nullptr;
#else
return false;
#endif
Expand Down
4 changes: 3 additions & 1 deletion r/src/safe-call-into-r.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
// and crash R in older versions (ARROW-16201). Crashes also occur
// on 32-bit R builds on R 3.6 and lower. Implementation provided
// in safe-call-into-r-impl.cpp so that we can skip some tests
// when this feature is not provided.
// when this feature is not provided. This also checks that there
// is not already an event loop registered (via MainRThread::Executor()),
// because only one of these can exist at any given time.
bool CanRunWithCapturedR();

// The MainRThread class keeps track of the thread on which it is safe
Expand Down
86 changes: 86 additions & 0 deletions r/tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,92 @@ test_that("map_batches", {
)
})

test_that("map_batches with explicit schema", {
fun_with_dots <- function(batch, first_col, first_col_val) {
record_batch(
!! first_col := first_col_val,
b = batch$a$cast(float64())
)
}

empty_reader <- RecordBatchReader$create(
batches = list(),
schema = schema(a = int32())
)
expect_equal(
map_batches(
empty_reader,
fun_with_dots,
"first_col_name",
"first_col_value",
.schema = schema(first_col_name = string(), b = float64())
)$read_table(),
arrow_table(first_col_name = character(), b = double())
)

reader <- RecordBatchReader$create(
batches = list(
record_batch(a = 1, b = "two"),
record_batch(a = 2, b = "three")
)
)
expect_equal(
map_batches(
reader,
fun_with_dots,
"first_col_name",
"first_col_value",
.schema = schema(first_col_name = string(), b = float64())
)$read_table(),
arrow_table(
first_col_name = c("first_col_value", "first_col_value"),
b = as.numeric(1:2)
)
)
})

test_that("map_batches without explicit schema", {
fun_with_dots <- function(batch, first_col, first_col_val) {
record_batch(
!! first_col := first_col_val,
b = batch$a$cast(float64())
)
}

empty_reader <- RecordBatchReader$create(
batches = list(),
schema = schema(a = int32())
)
expect_error(
map_batches(
empty_reader,
fun_with_dots,
"first_col_name",
"first_col_value"
)$read_table(),
"Can't infer schema"
)

reader <- RecordBatchReader$create(
batches = list(
record_batch(a = 1, b = "two"),
record_batch(a = 2, b = "three")
)
)
expect_equal(
map_batches(
reader,
fun_with_dots,
"first_col_name",
"first_col_value"
)$read_table(),
arrow_table(
first_col_name = c("first_col_value", "first_col_value"),
b = as.numeric(1:2)
)
)
})

test_that("head/tail", {
# head/tail with no query are still deterministic order
ds <- open_dataset(dataset_dir)
Expand Down
Loading