diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 47b10b1bb10..9b40495413e 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -31,6 +31,11 @@ #' @importFrom vctrs s3_register vec_size vec_cast vec_unique .onLoad <- function(...) { + if (arrow_available()) { + # Make sure C++ knows on which thread it is safe to call the R API + InitializeMainRThread() + } + dplyr_methods <- paste0( "dplyr::", c( diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index c9468f52ae3..332e797ee04 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1704,6 +1704,14 @@ ipc___RecordBatchStreamWriter__Open <- function(stream, schema, use_legacy_forma .Call(`_arrow_ipc___RecordBatchStreamWriter__Open`, stream, schema, use_legacy_format, metadata_version) } +InitializeMainRThread <- function() { + invisible(.Call(`_arrow_InitializeMainRThread`)) +} + +TestSafeCallIntoR <- function(r_fun_that_returns_a_string, opt) { + .Call(`_arrow_TestSafeCallIntoR`, r_fun_that_returns_a_string, opt) +} + Array__GetScalar <- function(x, i) { .Call(`_arrow_Array__GetScalar`, x, i) } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 59762790fa3..23681a7927c 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -6711,6 +6711,37 @@ extern "C" SEXP _arrow_ipc___RecordBatchStreamWriter__Open(SEXP stream_sexp, SEX } #endif +// safe-call-into-r-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +void InitializeMainRThread(); +extern "C" SEXP _arrow_InitializeMainRThread(){ +BEGIN_CPP11 + InitializeMainRThread(); + return R_NilValue; +END_CPP11 +} +#else +extern "C" SEXP _arrow_InitializeMainRThread(){ + Rf_error("Cannot call InitializeMainRThread(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + +// safe-call-into-r-impl.cpp +#if defined(ARROW_R_WITH_ARROW) +std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string, std::string opt); +extern "C" SEXP _arrow_TestSafeCallIntoR(SEXP r_fun_that_returns_a_string_sexp, SEXP opt_sexp){ +BEGIN_CPP11 + arrow::r::Input::type r_fun_that_returns_a_string(r_fun_that_returns_a_string_sexp); + arrow::r::Input::type opt(opt_sexp); + return cpp11::as_sexp(TestSafeCallIntoR(r_fun_that_returns_a_string, opt)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_TestSafeCallIntoR(SEXP r_fun_that_returns_a_string_sexp, SEXP opt_sexp){ + Rf_error("Cannot call TestSafeCallIntoR(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + // scalar.cpp #if defined(ARROW_R_WITH_ARROW) std::shared_ptr Array__GetScalar(const std::shared_ptr& x, int64_t i); @@ -7541,32 +7572,6 @@ extern "C" SEXP _arrow_Array__infer_type(SEXP x_sexp){ } #endif -#if defined(ARROW_R_WITH_ARROW) -extern "C" SEXP _arrow_Table__Reset(SEXP r6) { -BEGIN_CPP11 -arrow::r::r6_reset_pointer(r6); -END_CPP11 -return R_NilValue; -} -#else -extern "C" SEXP _arrow_Table__Reset(SEXP r6){ - Rf_error("Cannot call Table(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - -#if defined(ARROW_R_WITH_ARROW) -extern "C" SEXP _arrow_RecordBatch__Reset(SEXP r6) { -BEGIN_CPP11 -arrow::r::r6_reset_pointer(r6); -END_CPP11 -return R_NilValue; -} -#else -extern "C" SEXP _arrow_RecordBatch__Reset(SEXP r6){ - Rf_error("Cannot call RecordBatch(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - extern "C" SEXP _arrow_available() { return Rf_ScalarLogical( #if defined(ARROW_R_WITH_ARROW) @@ -8044,6 +8049,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchWriter__Close", (DL_FUNC) &_arrow_ipc___RecordBatchWriter__Close, 1}, { "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchFileWriter__Open, 4}, { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, + { "_arrow_InitializeMainRThread", (DL_FUNC) &_arrow_InitializeMainRThread, 0}, + { "_arrow_TestSafeCallIntoR", (DL_FUNC) &_arrow_TestSafeCallIntoR, 2}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, { "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2}, @@ -8097,8 +8104,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_GetIOThreadPoolCapacity", (DL_FUNC) &_arrow_GetIOThreadPoolCapacity, 0}, { "_arrow_SetIOThreadPoolCapacity", (DL_FUNC) &_arrow_SetIOThreadPoolCapacity, 1}, { "_arrow_Array__infer_type", (DL_FUNC) &_arrow_Array__infer_type, 1}, - { "_arrow_Table__Reset", (DL_FUNC) &_arrow_Table__Reset, 1}, - { "_arrow_RecordBatch__Reset", (DL_FUNC) &_arrow_RecordBatch__Reset, 1}, {NULL, NULL, 0} }; extern "C" void R_init_arrow(DllInfo* dll){ diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp new file mode 100644 index 00000000000..aa0645aa7b4 --- /dev/null +++ b/r/src/safe-call-into-r-impl.cpp @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./arrow_types.h" +#if defined(ARROW_R_WITH_ARROW) + +#include +#include +#include "./safe-call-into-r.h" + +MainRThread& GetMainRThread() { + static MainRThread main_r_thread; + return main_r_thread; +} + +// [[arrow::export]] +void InitializeMainRThread() { GetMainRThread().Initialize(); } + +// [[arrow::export]] +std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string, + std::string opt) { + if (opt == "async_with_executor") { + std::thread* thread_ptr; + + auto result = + RunWithCapturedR([&thread_ptr, r_fun_that_returns_a_string]() { + auto fut = arrow::Future::Make(); + thread_ptr = new std::thread([fut, r_fun_that_returns_a_string]() mutable { + auto result = SafeCallIntoR([&] { + return cpp11::as_cpp(r_fun_that_returns_a_string()); + }); + + fut.MarkFinished(result); + }); + + return fut; + }); + + thread_ptr->join(); + delete thread_ptr; + + return arrow::ValueOrStop(result); + } else if (opt == "async_without_executor") { + std::thread* thread_ptr; + + auto fut = arrow::Future::Make(); + thread_ptr = new std::thread([fut, r_fun_that_returns_a_string]() mutable { + auto result = SafeCallIntoR( + [&] { return cpp11::as_cpp(r_fun_that_returns_a_string()); }); + + if (result.ok()) { + fut.MarkFinished(result.ValueUnsafe()); + } else { + fut.MarkFinished(result.status()); + } + }); + + thread_ptr->join(); + delete thread_ptr; + + // We should be able to get this far, but fut will contain an error + // because it tried to evaluate R code from another thread + return arrow::ValueOrStop(fut.result()); + + } else if (opt == "on_main_thread") { + auto result = SafeCallIntoR( + [&]() { return cpp11::as_cpp(r_fun_that_returns_a_string()); }); + arrow::StopIfNotOk(result.status()); + return result.ValueUnsafe(); + } else { + cpp11::stop("Unknown `opt`"); + } +} + +#endif diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h new file mode 100644 index 00000000000..1a27507b788 --- /dev/null +++ b/r/src/safe-call-into-r.h @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef SAFE_CALL_INTO_R_INCLUDED +#define SAFE_CALL_INTO_R_INCLUDED + +#include "./arrow_types.h" + +#include +#include + +#include +#include + +// The MainRThread class keeps track of the thread on which it is safe +// to call the R API to facilitate its safe use (or erroring +// if it is not safe). The MainRThread singleton can be accessed from +// any thread using GetMainRThread(); the preferred way to call +// the R API where it may not be safe to do so is to use +// SafeCallIntoR([&]() { ... }). +class MainRThread { + public: + MainRThread() : initialized_(false), executor_(nullptr) {} + + // Call this method from the R thread (e.g., on package load) + // to save an internal copy of the thread id. + void Initialize() { + thread_id_ = std::this_thread::get_id(); + initialized_ = true; + SetError(R_NilValue); + } + + bool IsInitialized() { return initialized_; } + + // Check if the current thread is the main R thread + bool IsMainThread() { return initialized_ && std::this_thread::get_id() == thread_id_; } + + // The Executor that is running on the main R thread, if it exists + arrow::internal::Executor*& Executor() { return executor_; } + + // Save an error token generated from a cpp11::unwind_exception + // so that it can be properly handled after some cleanup code + // has run (e.g., cancelling some futures or waiting for them + // to finish). + void SetError(cpp11::sexp token) { error_token_ = token; } + + void ResetError() { error_token_ = R_NilValue; } + + // Check if there is a saved error + bool HasError() { return error_token_ != R_NilValue; } + + // Throw a cpp11::unwind_exception() with the saved token if it exists + void ClearError() { + if (HasError()) { + cpp11::unwind_exception e(error_token_); + ResetError(); + throw e; + } + } + + private: + bool initialized_; + std::thread::id thread_id_; + cpp11::sexp error_token_; + arrow::internal::Executor* executor_; +}; + +// Retrieve the MainRThread singleton +MainRThread& GetMainRThread(); + +// Call into R and return a C++ object. Note that you can't return +// a SEXP (use cpp11::as_cpp to convert it to a C++ type inside +// `fun`). +template +arrow::Future SafeCallIntoRAsync(std::function fun) { + MainRThread& main_r_thread = GetMainRThread(); + if (main_r_thread.IsMainThread()) { + // If we're on the main thread, run the task immediately and let + // the cpp11::unwind_exception be thrown since it will be caught + // at the top level. + return fun(); + } else if (main_r_thread.Executor() != nullptr) { + // If we are not on the main thread and have an Executor, + // use it to run the task on the main R thread. We can't throw + // a cpp11::unwind_exception here, so we need to propagate it back + // to RunWithCapturedR through the MainRThread singleton. + return DeferNotOk(main_r_thread.Executor()->Submit([fun]() { + if (GetMainRThread().HasError()) { + return arrow::Result(arrow::Status::UnknownError("R code execution error")); + } + + try { + return arrow::Result(fun()); + } catch (cpp11::unwind_exception& e) { + GetMainRThread().SetError(e.token); + return arrow::Result(arrow::Status::UnknownError("R code execution error")); + } + })); + } else { + return arrow::Status::NotImplemented( + "Call to R from a non-R thread without calling RunWithCapturedR"); + } +} + +template +arrow::Result SafeCallIntoR(std::function fun) { + arrow::Future future = SafeCallIntoRAsync(std::move(fun)); + return future.result(); +} + +template +arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { + if (GetMainRThread().Executor() != nullptr) { + return arrow::Status::AlreadyExists("Attempt to use more than one R Executor()"); + } + + GetMainRThread().ResetError(); + + arrow::Result result = arrow::internal::SerialExecutor::RunInSerialExecutor( + [make_arrow_call](arrow::internal::Executor* executor) { + GetMainRThread().Executor() = executor; + return make_arrow_call(); + }); + + GetMainRThread().Executor() = nullptr; + GetMainRThread().ClearError(); + + return result; +} + +#endif diff --git a/r/tests/testthat/test-safe-call-into-r.R b/r/tests/testthat/test-safe-call-into-r.R new file mode 100644 index 00000000000..e9438de58be --- /dev/null +++ b/r/tests/testthat/test-safe-call-into-r.R @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Note that TestSafeCallIntoR is defined in safe-call-into-r-impl.cpp + +test_that("SafeCallIntoR works from the main R thread", { + skip_on_cran() + + expect_identical( + TestSafeCallIntoR(function() "string one!", opt = "on_main_thread"), + "string one!" + ) + + expect_error( + TestSafeCallIntoR(function() stop("an error!"), opt = "on_main_thread"), + "an error!" + ) +}) + +test_that("SafeCallIntoR works within RunWithCapturedR", { + skip_on_cran() + + expect_identical( + TestSafeCallIntoR(function() "string one!", opt = "async_with_executor"), + "string one!" + ) + + expect_error( + TestSafeCallIntoR(function() stop("an error!"), opt = "async_with_executor"), + "an error!" + ) +}) + +test_that("SafeCallIntoR errors from the non-R thread", { + skip_on_cran() + + expect_error( + TestSafeCallIntoR(function() "string one!", opt = "async_without_executor"), + "Call to R from a non-R thread" + ) + + expect_error( + TestSafeCallIntoR(function() stop("an error!"), opt = "async_without_executor"), + "Call to R from a non-R thread" + ) +})