diff --git a/java/lance-jni/src/async_scanner.rs b/java/lance-jni/src/async_scanner.rs new file mode 100644 index 00000000000..eada9287c47 --- /dev/null +++ b/java/lance-jni/src/async_scanner.rs @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::sync::Arc; + +use crate::RT; +use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; +use crate::blocking_scanner::{ScannerOptions, build_scanner_with_options}; +use crate::dispatcher::{DISPATCHER, DispatcherMessage}; +use crate::error::Result; +use crate::task_tracker::{TASK_TRACKER, TaskInfo}; +use arrow::ffi::FFI_ArrowSchema; +use jni::JNIEnv; +use jni::objects::JObject; +use jni::sys::{jboolean, jint, jlong}; +use lance::dataset::scanner::Scanner; +use lance_io::ffi::to_ffi_arrow_array_stream; + +pub const NATIVE_ASYNC_SCANNER: &str = "nativeAsyncScannerHandle"; + +/// Async scanner that spawns Tokio tasks for non-blocking I/O +pub struct AsyncScanner { + pub(crate) inner: Arc, +} + +/// RAII guard that ensures task cleanup even on panic or early return +/// +/// This guard prevents memory leaks in the task tracker by guaranteeing +/// that task_id is removed from the HashMap when the guard is dropped, +/// regardless of how the async task terminates (normal completion, panic, +/// or cancellation). +struct TaskCleanupGuard { + task_id: u64, +} + +impl TaskCleanupGuard { + fn new(task_id: u64) -> Self { + Self { task_id } + } +} + +impl Drop for TaskCleanupGuard { + fn drop(&mut self) { + // GUARANTEED to run when guard goes out of scope + // Works even if the task panics or returns early + // + // Note: We spawn a detached task instead of using block_on() + // because Drop may be called from within a tokio runtime context + let task_id = self.task_id; + RT.spawn(async move { + TASK_TRACKER.complete(task_id).await; + log::debug!("Task {} cleaned up via RAII guard", task_id); + }); + } +} + +impl AsyncScanner { + pub fn create(scanner: Scanner) -> Self { + Self { + inner: Arc::new(scanner), + } + } + + /// Start an async scan task (static method to avoid holding locks) + pub fn start_scan_with_scanner( + scanner: Arc, + task_id: u64, + scanner_global_ref: jni::objects::GlobalRef, + ) { + // Two-phase registration to prevent race condition: + // 1. Pre-register with placeholder handle BEFORE spawning + // 2. Spawn the actual task + // 3. Update registration with real handle + // This ensures task is registered before cleanup can run + + // Clone for the spawned task + let global_ref_for_task = scanner_global_ref.clone(); + + // Step 1: Pre-register with placeholder handle + let placeholder_handle = RT.spawn(async { + // Placeholder task that does nothing + // Will be aborted when real handle is registered + }); + + RT.block_on(async { + TASK_TRACKER + .register( + task_id, + TaskInfo { + scanner_global_ref: scanner_global_ref.clone(), + cancel_handle: placeholder_handle, + }, + ) + .await; + }); + + // Step 2: Spawn the actual task + let handle = RT.spawn(async move { + // RAII guard ensures cleanup on normal exit, panic, or cancellation + let _cleanup_guard = TaskCleanupGuard::new(task_id); + + let result = match scanner.try_into_stream().await { + Ok(stream) => { + // Convert to FFI pointer + match to_ffi_arrow_array_stream(stream, RT.handle().clone()) { + Ok(ffi_stream) => { + let ptr = Box::into_raw(Box::new(ffi_stream)) as i64; + Ok(ptr) + } + Err(e) => Err(e.to_string()), + } + } + Err(e) => Err(e.to_string()), + }; + + // Send result to dispatcher for Java completion + let dispatcher = match DISPATCHER.get() { + Some(d) => d, + None => { + log::error!( + "Dispatcher not initialized - cannot complete task {}. \ + This indicates a critical initialization failure.", + task_id + ); + // Clean up the FFI stream pointer to prevent memory leak + if let Ok(ptr) = result { + unsafe { + drop(Box::from_raw( + ptr as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + )); + } + log::debug!("Cleaned up FFI stream pointer for task {}", task_id); + } + return; + } + }; + + // Save the pointer before sending so we can clean up on failure + let result_ptr = result.as_ref().ok().copied(); + + if let Err(e) = dispatcher.send(DispatcherMessage { + scanner_global_ref: global_ref_for_task, + task_id, + result, + }) { + log::error!( + "Failed to send completion message for task {}: {}", + task_id, + e + ); + // Clean up the FFI stream pointer to prevent memory leak + if let Some(ptr) = result_ptr { + unsafe { + drop(Box::from_raw( + ptr as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + )); + } + log::debug!("Cleaned up FFI stream pointer for task {}", task_id); + } + } + + // _cleanup_guard.drop() called here automatically, removing task from tracker + }); + + // Step 3: Update registration with real handle + RT.block_on(async { + TASK_TRACKER.update_handle(task_id, handle).await; + }); + } +} + +// JNI Exports + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_AsyncScanner_createAsyncScanner<'local>( + mut env: JNIEnv<'local>, + _class: JObject<'local>, + jdataset: JObject<'local>, + fragment_ids_obj: JObject<'local>, + columns_obj: JObject<'local>, + substrait_filter_obj: JObject<'local>, + filter_obj: JObject<'local>, + batch_size_obj: JObject<'local>, + limit_obj: JObject<'local>, + offset_obj: JObject<'local>, + query_obj: JObject<'local>, + fts_query_obj: JObject<'local>, + prefilter: jboolean, + with_row_id: jboolean, + with_row_address: jboolean, + batch_readahead: jint, + column_orderings: JObject<'local>, + use_scalar_index: jboolean, + substrait_aggregate_obj: JObject<'local>, +) -> JObject<'local> { + crate::ok_or_throw!( + env, + inner_create_async_scanner( + &mut env, + jdataset, + fragment_ids_obj, + columns_obj, + substrait_filter_obj, + filter_obj, + batch_size_obj, + limit_obj, + offset_obj, + query_obj, + fts_query_obj, + prefilter, + with_row_id, + with_row_address, + batch_readahead, + column_orderings, + use_scalar_index, + substrait_aggregate_obj, + ) + ) +} + +#[allow(clippy::too_many_arguments)] +fn inner_create_async_scanner<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject<'local>, + fragment_ids_obj: JObject<'local>, + columns_obj: JObject<'local>, + substrait_filter_obj: JObject<'local>, + filter_obj: JObject<'local>, + batch_size_obj: JObject<'local>, + limit_obj: JObject<'local>, + offset_obj: JObject<'local>, + query_obj: JObject<'local>, + fts_query_obj: JObject<'local>, + prefilter: jboolean, + with_row_id: jboolean, + with_row_address: jboolean, + batch_readahead: jint, + column_orderings: JObject<'local>, + use_scalar_index: jboolean, + substrait_aggregate_obj: JObject<'local>, +) -> Result> { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; + let dataset = dataset_guard.inner.clone(); + drop(dataset_guard); + + let options = ScannerOptions { + fragment_ids_obj, + columns_obj, + substrait_filter_obj, + filter_obj, + batch_size_obj, + limit_obj, + offset_obj, + query_obj, + fts_query_obj, + prefilter, + with_row_id, + with_row_address, + batch_readahead, + column_orderings, + use_scalar_index, + substrait_aggregate_obj, + }; + + let scanner = build_scanner_with_options(env, &dataset, options)?; + + let async_scanner = AsyncScanner::create(scanner); + + // Create Java AsyncScanner object + let j_scanner = env.new_object("org/lance/ipc/AsyncScanner", "()V", &[])?; + + // Attach native handle + unsafe { env.set_rust_field(&j_scanner, NATIVE_ASYNC_SCANNER, async_scanner)? }; + + Ok(j_scanner) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_AsyncScanner_nativeStartScan( + mut env: JNIEnv, + j_scanner: JObject, + task_id: jlong, +) { + ok_or_throw_without_return!(env, inner_start_scan(&mut env, j_scanner, task_id as u64)); +} + +fn inner_start_scan(env: &mut JNIEnv, j_scanner: JObject, task_id: u64) -> Result<()> { + // Create global reference first, before borrowing scanner + let scanner_global_ref = env.new_global_ref(&j_scanner)?; + + // Clone the Arc and drop the MutexGuard before calling start_scan, + // which does block_on internally. Holding the guard across block_on risks deadlock. + let scanner = { + let guard = + unsafe { env.get_rust_field::<_, _, AsyncScanner>(&j_scanner, NATIVE_ASYNC_SCANNER)? }; + guard.inner.clone() + }; + + AsyncScanner::start_scan_with_scanner(scanner, task_id, scanner_global_ref); + Ok(()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_AsyncScanner_nativeCancelTask( + _env: JNIEnv, + _j_scanner: JObject, + task_id: jlong, +) { + RT.block_on(async { + TASK_TRACKER.cancel(task_id as u64).await; + }); +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_AsyncScanner_releaseNativeScanner( + mut env: JNIEnv, + j_scanner: JObject, +) { + ok_or_throw_without_return!(env, inner_release_async_scanner(&mut env, j_scanner)); +} + +fn inner_release_async_scanner(env: &mut JNIEnv, j_scanner: JObject) -> Result<()> { + let _: AsyncScanner = unsafe { env.take_rust_field(j_scanner, NATIVE_ASYNC_SCANNER) }?; + Ok(()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_AsyncScanner_importFfiSchema( + mut env: JNIEnv, + j_scanner: JObject, + schema_addr: jlong, +) { + ok_or_throw_without_return!( + env, + inner_import_async_ffi_schema(&mut env, j_scanner, schema_addr) + ); +} + +fn inner_import_async_ffi_schema( + env: &mut JNIEnv, + j_scanner: JObject, + schema_addr: jlong, +) -> Result<()> { + let scanner_guard = + unsafe { env.get_rust_field::<_, _, AsyncScanner>(j_scanner, NATIVE_ASYNC_SCANNER)? }; + + let schema = RT.block_on(scanner_guard.inner.schema())?; + let ffi_schema = FFI_ArrowSchema::try_from(&*schema)?; + unsafe { std::ptr::write_unaligned(schema_addr as *mut FFI_ArrowSchema, ffi_schema) } + Ok(()) +} diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 93a441f3902..5a369b98a73 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -58,7 +58,16 @@ impl BlockingScanner { } } -fn build_full_text_search_query<'a>(env: &mut JNIEnv<'a>, java_obj: JObject) -> Result { +/////////////////// +// Shared Helpers // +/////////////////// + +/// Build FTS query from Java FullTextQuery object +/// Made pub(crate) to be reused by async_scanner +pub(crate) fn build_full_text_search_query<'a>( + env: &mut JNIEnv<'a>, + java_obj: JObject, +) -> Result { let type_obj = env .call_method( &java_obj, @@ -193,88 +202,40 @@ fn build_full_text_search_query<'a>(env: &mut JNIEnv<'a>, java_obj: JObject) -> } } -/////////////////// -// Write Methods // -/////////////////// -#[unsafe(no_mangle)] -pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>( - mut env: JNIEnv<'local>, - _reader: JObject, - jdataset: JObject, - fragment_ids_obj: JObject, // Optional> - columns_obj: JObject, // Optional> - substrait_filter_obj: JObject, // Optional - filter_obj: JObject, // Optional - batch_size_obj: JObject, // Optional - limit_obj: JObject, // Optional - offset_obj: JObject, // Optional - query_obj: JObject, // Optional - fts_query_obj: JObject, // Optional - prefilter: jboolean, // boolean - with_row_id: jboolean, // boolean - with_row_address: jboolean, // boolean - batch_readahead: jint, // int - column_orderings: JObject, // Optional> - use_scalar_index: jboolean, // boolean - substrait_aggregate_obj: JObject, // Optional -) -> JObject<'local> { - ok_or_throw!( - env, - inner_create_scanner( - &mut env, - jdataset, - fragment_ids_obj, - columns_obj, - substrait_filter_obj, - filter_obj, - batch_size_obj, - limit_obj, - offset_obj, - query_obj, - fts_query_obj, - prefilter, - with_row_id, - with_row_address, - batch_readahead, - column_orderings, - use_scalar_index, - substrait_aggregate_obj - ) - ) +/// Scanner options passed from JNI - shared between blocking and async scanners +pub(crate) struct ScannerOptions<'a> { + pub fragment_ids_obj: JObject<'a>, + pub columns_obj: JObject<'a>, + pub substrait_filter_obj: JObject<'a>, + pub filter_obj: JObject<'a>, + pub batch_size_obj: JObject<'a>, + pub limit_obj: JObject<'a>, + pub offset_obj: JObject<'a>, + pub query_obj: JObject<'a>, + pub fts_query_obj: JObject<'a>, + pub prefilter: jboolean, + pub with_row_id: jboolean, + pub with_row_address: jboolean, + pub batch_readahead: jint, + pub column_orderings: JObject<'a>, + pub use_scalar_index: jboolean, + pub substrait_aggregate_obj: JObject<'a>, } -#[allow(clippy::too_many_arguments)] -fn inner_create_scanner<'local>( - env: &mut JNIEnv<'local>, - jdataset: JObject, - fragment_ids_obj: JObject, - columns_obj: JObject, - substrait_filter_obj: JObject, - filter_obj: JObject, - batch_size_obj: JObject, - limit_obj: JObject, - offset_obj: JObject, - query_obj: JObject, - fts_query_obj: JObject, - prefilter: jboolean, - with_row_id: jboolean, - with_row_address: jboolean, - batch_readahead: jint, - column_orderings: JObject, - use_scalar_index: jboolean, - substrait_aggregate_obj: JObject, -) -> Result> { - let fragment_ids_opt = env.get_ints_opt(&fragment_ids_obj)?; - let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; - - let mut scanner = dataset_guard.inner.scan(); +/// Build a scanner with options applied - shared by blocking and async scanners +pub(crate) fn build_scanner_with_options<'a>( + env: &mut JNIEnv<'a>, + dataset: &lance::Dataset, + options: ScannerOptions<'a>, +) -> Result { + let mut scanner = dataset.scan(); // handle fragment_ids + let fragment_ids_opt = env.get_ints_opt(&options.fragment_ids_obj)?; if let Some(fragment_ids) = fragment_ids_opt { let mut fragments = Vec::with_capacity(fragment_ids.len()); for fragment_id in fragment_ids { - let Some(fragment) = dataset_guard.inner.get_fragment(fragment_id as usize) else { + let Some(fragment) = dataset.get_fragment(fragment_id as usize) else { return Err(Error::input_error(format!( "Fragment {fragment_id} not found" ))); @@ -283,49 +244,48 @@ fn inner_create_scanner<'local>( } scanner.with_fragments(fragments); } - drop(dataset_guard); - let columns_opt = env.get_strings_opt(&columns_obj)?; + let columns_opt = env.get_strings_opt(&options.columns_obj)?; if let Some(columns) = columns_opt { scanner.project(&columns)?; }; - let substrait_opt = env.get_bytes_opt(&substrait_filter_obj)?; + let substrait_opt = env.get_bytes_opt(&options.substrait_filter_obj)?; if let Some(substrait) = substrait_opt { RT.block_on(async { scanner.filter_substrait(substrait) })?; } - let filter_opt = env.get_string_opt(&filter_obj)?; + let filter_opt = env.get_string_opt(&options.filter_obj)?; if let Some(filter) = filter_opt { scanner.filter(filter.as_str())?; } - let batch_size_opt = env.get_long_opt(&batch_size_obj)?; + let batch_size_opt = env.get_long_opt(&options.batch_size_obj)?; if let Some(batch_size) = batch_size_opt { scanner.batch_size(batch_size as usize); } - let limit_opt = env.get_long_opt(&limit_obj)?; - let offset_opt = env.get_long_opt(&offset_obj)?; + let limit_opt = env.get_long_opt(&options.limit_obj)?; + let offset_opt = env.get_long_opt(&options.offset_obj)?; scanner .limit(limit_opt, offset_opt) .map_err(|err| Error::input_error(err.to_string()))?; - if with_row_id == JNI_TRUE { + if options.with_row_id == JNI_TRUE { scanner.with_row_id(); } - if with_row_address == JNI_TRUE { + if options.with_row_address == JNI_TRUE { scanner.with_row_address(); } - if prefilter == JNI_TRUE { + if options.prefilter == JNI_TRUE { scanner.prefilter(true); } - scanner.use_scalar_index(use_scalar_index == JNI_TRUE); + scanner.use_scalar_index(options.use_scalar_index == JNI_TRUE); - env.get_optional(&query_obj, |env, java_obj| { + env.get_optional(&options.query_obj, |env, java_obj| { // Set column and key for nearest search let column = env.get_string_from_method(&java_obj, "getColumn")?; let key_array = env.get_vec_f32_from_method(&java_obj, "getKey")?; @@ -363,16 +323,16 @@ fn inner_create_scanner<'local>( Ok(()) })?; - env.get_optional(&fts_query_obj, |env, java_obj| { + env.get_optional(&options.fts_query_obj, |env, java_obj| { let fts_query = build_full_text_search_query(env, java_obj)?; let full_text_query = FullTextSearchQuery::new_query(fts_query); scanner.full_text_search(full_text_query)?; Ok(()) })?; - scanner.batch_readahead(batch_readahead as usize); + scanner.batch_readahead(options.batch_readahead as usize); - env.get_optional(&column_orderings, |env, java_obj| { + env.get_optional(&options.column_orderings, |env, java_obj| { let list = env.get_list(&java_obj)?; let mut iter = list.iter(env)?; let mut results = Vec::with_capacity(list.size(env)? as usize); @@ -391,11 +351,111 @@ fn inner_create_scanner<'local>( Ok(()) })?; - let substrait_aggregate_opt = env.get_bytes_opt(&substrait_aggregate_obj)?; + let substrait_aggregate_opt = env.get_bytes_opt(&options.substrait_aggregate_obj)?; if let Some(substrait_aggregate) = substrait_aggregate_opt { scanner.aggregate(AggregateExpr::substrait(substrait_aggregate))?; } + Ok(scanner) +} + +/////////////////// +// Write Methods // +/////////////////// +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>( + mut env: JNIEnv<'local>, + _reader: JObject<'local>, + jdataset: JObject<'local>, + fragment_ids_obj: JObject<'local>, // Optional> + columns_obj: JObject<'local>, // Optional> + substrait_filter_obj: JObject<'local>, // Optional + filter_obj: JObject<'local>, // Optional + batch_size_obj: JObject<'local>, // Optional + limit_obj: JObject<'local>, // Optional + offset_obj: JObject<'local>, // Optional + query_obj: JObject<'local>, // Optional + fts_query_obj: JObject<'local>, // Optional + prefilter: jboolean, // boolean + with_row_id: jboolean, // boolean + with_row_address: jboolean, // boolean + batch_readahead: jint, // int + column_orderings: JObject<'local>, // Optional> + use_scalar_index: jboolean, // boolean + substrait_aggregate_obj: JObject<'local>, // Optional +) -> JObject<'local> { + ok_or_throw!( + env, + inner_create_scanner( + &mut env, + jdataset, + fragment_ids_obj, + columns_obj, + substrait_filter_obj, + filter_obj, + batch_size_obj, + limit_obj, + offset_obj, + query_obj, + fts_query_obj, + prefilter, + with_row_id, + with_row_address, + batch_readahead, + column_orderings, + use_scalar_index, + substrait_aggregate_obj + ) + ) +} + +#[allow(clippy::too_many_arguments)] +fn inner_create_scanner<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject<'local>, + fragment_ids_obj: JObject<'local>, + columns_obj: JObject<'local>, + substrait_filter_obj: JObject<'local>, + filter_obj: JObject<'local>, + batch_size_obj: JObject<'local>, + limit_obj: JObject<'local>, + offset_obj: JObject<'local>, + query_obj: JObject<'local>, + fts_query_obj: JObject<'local>, + prefilter: jboolean, + with_row_id: jboolean, + with_row_address: jboolean, + batch_readahead: jint, + column_orderings: JObject<'local>, + use_scalar_index: jboolean, + substrait_aggregate_obj: JObject<'local>, +) -> Result> { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; + let dataset = dataset_guard.inner.clone(); + drop(dataset_guard); + + let options = ScannerOptions { + fragment_ids_obj, + columns_obj, + substrait_filter_obj, + filter_obj, + batch_size_obj, + limit_obj, + offset_obj, + query_obj, + fts_query_obj, + prefilter, + with_row_id, + with_row_address, + batch_readahead, + column_orderings, + use_scalar_index, + substrait_aggregate_obj, + }; + + let scanner = build_scanner_with_options(env, &dataset, options)?; + let scanner = BlockingScanner::create(scanner); scanner.into_java(env) } diff --git a/java/lance-jni/src/dispatcher.rs b/java/lance-jni/src/dispatcher.rs new file mode 100644 index 00000000000..a5efadc8cea --- /dev/null +++ b/java/lance-jni/src/dispatcher.rs @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use jni::JavaVM; +use jni::objects::GlobalRef; +use std::sync::{Arc, OnceLock}; +use tokio::sync::mpsc; + +/// Message sent from Tokio tasks to the dispatcher thread +pub struct DispatcherMessage { + pub scanner_global_ref: GlobalRef, + pub task_id: u64, + pub result: Result, // Ok(stream_ptr) or Err(error_msg) +} + +/// Global dispatcher instance initialized in JNI_OnLoad +pub static DISPATCHER: OnceLock> = OnceLock::new(); + +/// Dispatcher manages a persistent JNI thread for completing Java futures +#[derive(Debug)] +pub struct Dispatcher { + tx: mpsc::UnboundedSender, +} + +impl Dispatcher { + /// Initialize the dispatcher with a persistent JNI thread + pub fn initialize(jvm: Arc) -> Arc { + let (tx, mut rx) = mpsc::unbounded_channel::(); + + // Spawn persistent dispatcher thread + std::thread::Builder::new() + .name("lance-jni-dispatcher".to_string()) + .spawn(move || { + // Attach ONCE and never detach - this is the key optimization + let mut env = jvm + .attach_current_thread_permanently() + .expect("Failed to attach dispatcher to JVM"); + + log::info!("JNI dispatcher thread started"); + + // Cache method IDs for completeTask and failTask + let async_scanner_class = env + .find_class("org/lance/ipc/AsyncScanner") + .expect("AsyncScanner class not found"); + let complete_method = env + .get_method_id(&async_scanner_class, "completeTask", "(JJ)V") + .expect("completeTask method not found"); + let fail_method = env + .get_method_id(&async_scanner_class, "failTask", "(JLjava/lang/String;)V") + .expect("failTask method not found"); + + // Event loop: block waiting for completions + while let Some(msg) = rx.blocking_recv() { + let scanner_obj = msg.scanner_global_ref.as_obj(); + + match msg.result { + Err(error) => { + handle_error(&mut env, scanner_obj, fail_method, msg.task_id, &error) + } + Ok(result_ptr) => handle_success( + &mut env, + scanner_obj, + complete_method, + msg.task_id, + result_ptr, + ), + } + } + + log::info!("JNI dispatcher thread shutting down"); + }) + .expect("Failed to spawn dispatcher thread"); + + Arc::new(Self { tx }) + } + + /// Send a completion message to the dispatcher + pub fn send(&self, msg: DispatcherMessage) -> std::result::Result<(), String> { + self.tx + .send(msg) + .map_err(|e| format!("Failed to send message to dispatcher: {}", e)) + } +} + +/// Handle error completion by calling failTask on Java side +fn handle_error( + env: &mut jni::JNIEnv, + scanner_obj: &jni::objects::JObject, + fail_method: jni::objects::JMethodID, + task_id: u64, + error: &str, +) { + let error_jstr = match env.new_string(error) { + Ok(s) => s, + Err(e) => { + log::error!("Failed to create JString for error: {:?}", e); + let _ = env.exception_clear(); + return; + } + }; + + let result = unsafe { + env.call_method_unchecked( + scanner_obj, + fail_method, + jni::signature::ReturnType::Primitive(jni::signature::Primitive::Void), + &[ + jni::sys::jvalue { j: task_id as i64 }, + jni::sys::jvalue { + l: error_jstr.as_raw(), + }, + ], + ) + }; + + if let Err(e) = result { + log::error!("Failed to call failTask: {:?}", e); + // Clear any pending JNI exception to protect the dispatcher loop + let _ = env.exception_clear(); + } +} + +/// Handle success completion by calling completeTask on Java side +fn handle_success( + env: &mut jni::JNIEnv, + scanner_obj: &jni::objects::JObject, + complete_method: jni::objects::JMethodID, + task_id: u64, + result_ptr: i64, +) { + let result = unsafe { + env.call_method_unchecked( + scanner_obj, + complete_method, + jni::signature::ReturnType::Primitive(jni::signature::Primitive::Void), + &[ + jni::sys::jvalue { j: task_id as i64 }, + jni::sys::jvalue { j: result_ptr }, + ], + ) + }; + + if let Err(e) = result { + log::error!("Failed to call completeTask: {:?}", e); + // Clear any pending JNI exception to protect the dispatcher loop + let _ = env.exception_clear(); + // Clean up the FFI stream since Java won't receive it + unsafe { + drop(Box::from_raw( + result_ptr as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + )); + } + log::debug!( + "Cleaned up FFI stream pointer for task {} after completeTask failure", + task_id + ); + } +} diff --git a/java/lance-jni/src/lib.rs b/java/lance-jni/src/lib.rs index 53ce125aca8..90be9b3ef80 100644 --- a/java/lance-jni/src/lib.rs +++ b/java/lance-jni/src/lib.rs @@ -39,10 +39,12 @@ macro_rules! ok_or_throw_with_return { }; } +mod async_scanner; mod blocking_blob; mod blocking_dataset; mod blocking_scanner; mod delta; +mod dispatcher; pub mod error; pub mod ffi; mod file_reader; @@ -56,6 +58,7 @@ mod schema; mod session; mod sql; mod storage_options; +mod task_tracker; pub mod traits; mod transaction; pub mod utils; @@ -151,3 +154,23 @@ pub extern "system" fn Java_org_lance_JniLoader_initLanceLogger() { log::set_max_level(max_level); // todo: add tracing } + +/// JNI_OnLoad - Called when the JVM loads the native library +/// Initializes the global dispatcher for async operations +#[unsafe(no_mangle)] +pub extern "system" fn JNI_OnLoad( + vm: jni::JavaVM, + _reserved: *mut std::ffi::c_void, +) -> jni::sys::jint { + let jvm_arc = Arc::new(vm); + + // Initialize global dispatcher with persistent thread + let dispatcher = dispatcher::Dispatcher::initialize(jvm_arc); + + // Set the global DISPATCHER (will panic if called more than once) + dispatcher::DISPATCHER + .set(dispatcher) + .expect("Dispatcher already initialized"); + + jni::sys::JNI_VERSION_1_8 +} diff --git a/java/lance-jni/src/task_tracker.rs b/java/lance-jni/src/task_tracker.rs new file mode 100644 index 00000000000..bc9d9b0519f --- /dev/null +++ b/java/lance-jni/src/task_tracker.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use jni::objects::GlobalRef; +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; +use tokio::sync::RwLock; + +pub type TaskId = u64; + +/// Information about an in-flight async task +pub struct TaskInfo { + #[allow(dead_code)] // Used for cleanup when task is cancelled + pub scanner_global_ref: GlobalRef, + pub cancel_handle: tokio::task::JoinHandle<()>, +} + +/// Thread-safe task registry for managing async scan operations +pub struct TaskTracker { + tasks: Arc>>, +} + +impl TaskTracker { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Register a new task + pub async fn register(&self, task_id: TaskId, info: TaskInfo) { + let mut tasks = self.tasks.write().await; + tasks.insert(task_id, info); + } + + /// Update the cancel handle for a task (used in two-phase registration) + /// Returns true if task was found and updated, false if task already completed + pub async fn update_handle( + &self, + task_id: TaskId, + cancel_handle: tokio::task::JoinHandle<()>, + ) -> bool { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(&task_id) { + // Abort the old placeholder handle and replace with real handle + task_info.cancel_handle.abort(); + task_info.cancel_handle = cancel_handle; + true + } else { + // Task already completed before we could update - abort the handle + cancel_handle.abort(); + false + } + } + + /// Mark a task as complete and return its info + pub async fn complete(&self, task_id: TaskId) -> Option { + let mut tasks = self.tasks.write().await; + tasks.remove(&task_id) + } + + /// Cancel a task by ID + pub async fn cancel(&self, task_id: TaskId) { + let info = { + let mut tasks = self.tasks.write().await; + tasks.remove(&task_id) + }; + + if let Some(info) = info { + info.cancel_handle.abort(); + } + } + + // TODO: Implement timeout-based cleanup for defense-in-depth + // + // While TaskCleanupGuard (RAII pattern) ensures cleanup in normal and panic cases, + // a background cleanup task provides additional safety against edge cases: + // + // Proposed implementation: + // ``` + // pub async fn cleanup_stale_tasks(&self, max_age: Duration) { + // let mut tasks = self.tasks.write().await; + // let now = Instant::now(); + // tasks.retain(|task_id, info| { + // let is_finished = info.cancel_handle.is_finished(); + // let is_stale = info.created_at.elapsed() > max_age; + // + // if is_finished || is_stale { + // log::warn!("Cleaning up stale/finished task {}", task_id); + // false // remove from HashMap + // } else { + // true // keep in HashMap + // } + // }); + // } + // + // // In JNI_OnLoad or module initialization: + // RT.spawn(async { + // loop { + // tokio::time::sleep(Duration::from_secs(60)).await; + // TASK_TRACKER.cleanup_stale_tasks(Duration::from_secs(300)).await; + // } + // }); + // ``` + // + // This would require adding `created_at: Instant` field to TaskInfo. +} + +/// Global task tracker instance +pub static TASK_TRACKER: LazyLock = LazyLock::new(TaskTracker::new); diff --git a/java/src/main/java/org/lance/ipc/AsyncScanner.java b/java/src/main/java/org/lance/ipc/AsyncScanner.java new file mode 100644 index 00000000000..59ecdebd750 --- /dev/null +++ b/java/src/main/java/org/lance/ipc/AsyncScanner.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.ipc; + +import org.lance.Dataset; +import org.lance.LockManager; + +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Async scanner that provides non-blocking scan operations via CompletableFuture. + * + *

This scanner spawns async I/O tasks in Rust and completes Java futures when data is ready, + * preventing thread starvation in Java query engines like Presto/Trino. + */ +public class AsyncScanner implements AutoCloseable { + private static final AtomicLong TASK_ID_GENERATOR = new AtomicLong(1); + private final ConcurrentHashMap> pendingTasks = + new ConcurrentHashMap<>(); + + private BufferAllocator allocator; + private final LockManager lockManager = new LockManager(); + private long nativeAsyncScannerHandle; + + private AsyncScanner() {} + + /** + * Create an AsyncScanner. + * + * @param dataset the dataset to scan + * @param options scan options + * @param allocator allocator + * @return an AsyncScanner + */ + public static AsyncScanner create( + Dataset dataset, ScanOptions options, BufferAllocator allocator) { + Preconditions.checkNotNull(dataset); + Preconditions.checkNotNull(options); + Preconditions.checkNotNull(allocator); + AsyncScanner scanner = + createAsyncScanner( + dataset, + options.getFragmentIds(), + options.getColumns(), + options.getSubstraitFilter(), + options.getFilter(), + options.getBatchSize(), + options.getLimit(), + options.getOffset(), + options.getNearest(), + options.getFullTextQuery(), + options.isPrefilter(), + options.isWithRowId(), + options.isWithRowAddress(), + options.getBatchReadahead(), + options.getColumnOrderings(), + options.isUseScalarIndex(), + options.getSubstraitAggregate()); + scanner.allocator = allocator; + return scanner; + } + + static native AsyncScanner createAsyncScanner( + Dataset dataset, + Optional> fragmentIds, + Optional> columns, + Optional substraitFilter, + Optional filter, + Optional batchSize, + Optional limit, + Optional offset, + Optional query, + Optional fullTextQuery, + boolean prefilter, + boolean withRowId, + boolean withRowAddress, + int batchReadahead, + Optional> columnOrderings, + boolean useScalarIndex, + Optional substraitAggregate); + + /** + * Asynchronously scan batches and return a CompletableFuture. + * + * @return a CompletableFuture that will be completed with an ArrowReader when data is ready + */ + public CompletableFuture scanBatchesAsync() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + if (nativeAsyncScannerHandle == 0) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalStateException("Scanner is closed")); + return future; + } + + long taskId = TASK_ID_GENERATOR.getAndIncrement(); + CompletableFuture streamPtrFuture = new CompletableFuture<>(); + pendingTasks.put(taskId, streamPtrFuture); + + // Start async scan in Rust + nativeStartScan(taskId); + + // Transform stream pointer to ArrowReader + return streamPtrFuture.handle( + (streamPtr, error) -> { + pendingTasks.remove(taskId); + + if (error != null) { + throw new RuntimeException("Scan failed", error); + } + + if (streamPtr < 0) { + throw new RuntimeException("Native scan error"); + } + + try { + ArrowArrayStream stream = ArrowArrayStream.wrap(streamPtr); + return Data.importArrayStream(allocator, stream); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + /** Called by Rust dispatcher thread via JNI to complete a task successfully. */ + private void completeTask(long taskId, long resultPtr) { + CompletableFuture future = pendingTasks.get(taskId); + if (future != null) { + future.complete(resultPtr); + } + } + + /** Called by Rust dispatcher thread via JNI to fail a task with an error. */ + private void failTask(long taskId, String errorMessage) { + CompletableFuture future = pendingTasks.get(taskId); + if (future != null) { + future.completeExceptionally(new RuntimeException(errorMessage)); + } + } + + private native void nativeStartScan(long taskId); + + /** + * Get schema (synchronous operation). + * + * @return the schema + */ + public Schema schema() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeAsyncScannerHandle != 0, "Scanner is closed"); + try (ArrowSchema ffiSchema = ArrowSchema.allocateNew(allocator)) { + importFfiSchema(ffiSchema.memoryAddress()); + return Data.importSchema(allocator, ffiSchema, null); + } + } + } + + private native void importFfiSchema(long arrowSchemaMemoryAddress); + + /** + * Closes this scanner and releases any system resources associated with it. If the scanner is + * already closed, then invoking this method has no effect. + */ + @Override + public void close() throws Exception { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + if (nativeAsyncScannerHandle != 0) { + // Cancel all pending tasks + for (Long taskId : pendingTasks.keySet()) { + nativeCancelTask(taskId); + } + pendingTasks.clear(); + + releaseNativeScanner(); + nativeAsyncScannerHandle = 0; + } + } + } + + private native void nativeCancelTask(long taskId); + + /** Native method to release the async scanner resources. */ + private native void releaseNativeScanner(); +} diff --git a/java/src/test/java/org/lance/AsyncScannerTest.java b/java/src/test/java/org/lance/AsyncScannerTest.java new file mode 100644 index 00000000000..98f46887b64 --- /dev/null +++ b/java/src/test/java/org/lance/AsyncScannerTest.java @@ -0,0 +1,311 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance; + +import org.lance.ipc.AsyncScanner; +import org.lance.ipc.ScanOptions; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Example tests demonstrating AsyncScanner usage with CompletableFuture-based API. + * + *

AsyncScanner provides non-blocking scan operations that prevent thread starvation in Java + * query engines like Presto/Trino. + */ +public class AsyncScannerTest { + private static Dataset dataset; + + @BeforeAll + static void setup() {} + + @AfterAll + static void tearDown() { + if (dataset != null) { + dataset.close(); + } + } + + /** + * Example 1: Basic async scan with CompletableFuture. + * + *

This shows the simplest usage - create an async scanner and wait for results. + */ + @Test + void testBasicAsyncScan(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_basic").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 40; + + try (Dataset dataset = testDataset.write(1, totalRows)) { + // Create AsyncScanner with same options as LanceScanner + ScanOptions options = new ScanOptions.Builder().batchSize(20L).build(); + + try (AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator)) { + // Start async scan - returns CompletableFuture + CompletableFuture future = scanner.scanBatchesAsync(); + + // Wait for result (blocks current thread, but doesn't block Rust I/O threads) + ArrowReader reader = future.get(10, TimeUnit.SECONDS); + assertNotNull(reader); + + // Read all batches + int rowCount = 0; + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + rowCount += root.getRowCount(); + } + + assertEquals(totalRows, rowCount, "Should read all rows"); + reader.close(); + } + } + } + } + + /** + * Example 2: Async scan with filter. + * + *

Shows how to use async scanner with SQL-like filters. + */ + @Test + void testAsyncScanWithFilter(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_filter").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + try (Dataset dataset = testDataset.write(1, 40)) { + // Scan with filter - only rows where id < 20 + ScanOptions options = new ScanOptions.Builder().filter("id < 20").build(); + + try (AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator)) { + CompletableFuture future = scanner.scanBatchesAsync(); + + ArrowReader reader = future.get(10, TimeUnit.SECONDS); + int rowCount = 0; + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + rowCount += root.getRowCount(); + } + + assertEquals(20, rowCount, "Should read only filtered rows"); + reader.close(); + } + } + } + } + + /** + * Example 3: Multiple concurrent async scans. + * + *

Shows how to run multiple scans in parallel without blocking threads. This is the key + * benefit for query engines like Presto/Trino. + */ + @Test + void testConcurrentAsyncScans(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_concurrent").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + int totalRows = 100; + + try (Dataset dataset = testDataset.write(1, totalRows)) { + // Create 5 concurrent scans with different filters + List> futures = new ArrayList<>(); + + for (int i = 0; i < 5; i++) { + final int rangeStart = i * 20; + final int rangeEnd = rangeStart + 20; + String filter = String.format("id >= %d AND id < %d", rangeStart, rangeEnd); + + ScanOptions options = new ScanOptions.Builder().filter(filter).build(); + + AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator); + + // Chain async operations: scan -> read -> count rows -> cleanup + CompletableFuture future = + scanner + .scanBatchesAsync() + .thenApply( + reader -> { + try { + int count = 0; + while (reader.loadNextBatch()) { + count += reader.getVectorSchemaRoot().getRowCount(); + } + reader.close(); + scanner.close(); + return count; + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + futures.add(future); + } + + // Wait for all scans to complete + CompletableFuture allDone = + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); + allDone.get(30, TimeUnit.SECONDS); + + // Verify each scan read the expected number of rows + for (CompletableFuture future : futures) { + assertEquals(20, future.get(), "Each range should have 20 rows"); + } + } + } + } + + /** + * Example 4: Async scan with error handling. + * + *

Shows how to handle errors in async operations. + */ + @Test + void testAsyncScanErrorHandling(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_error").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + try (Dataset dataset = testDataset.write(1, 40)) { + ScanOptions options = new ScanOptions.Builder().build(); + + try (AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator)) { + CompletableFuture future = + scanner + .scanBatchesAsync() + .whenComplete( + (reader, error) -> { + if (error != null) { + // Handle error + System.err.println("Scan failed: " + error.getMessage()); + } else { + // Process successful result + assertNotNull(reader); + } + }); + + ArrowReader reader = future.get(10, TimeUnit.SECONDS); + assertNotNull(reader); + reader.close(); + } + } + } + } + + /** + * Example 5: Async scan with projection (column selection). + * + *

Shows how to select specific columns for better performance. + */ + @Test + void testAsyncScanWithProjection(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_projection").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + try (Dataset dataset = testDataset.write(1, 40)) { + // Select only "id" column + ScanOptions options = new ScanOptions.Builder().columns(List.of("id")).build(); + + try (AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator)) { + CompletableFuture future = scanner.scanBatchesAsync(); + + ArrowReader reader = future.get(10, TimeUnit.SECONDS); + + // Verify schema has only one column + assertEquals(1, reader.getVectorSchemaRoot().getFieldVectors().size()); + assertEquals("id", reader.getVectorSchemaRoot().getVector(0).getName()); + + reader.close(); + } + } + } + } + + /** + * Example 6: Using thenCompose for sequential async operations. + * + *

Shows how to chain multiple async operations together. + */ + @Test + void testAsyncChaining(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_chaining").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + try (Dataset dataset = testDataset.write(1, 40)) { + ScanOptions options = new ScanOptions.Builder().build(); + + try (AsyncScanner scanner = AsyncScanner.create(dataset, options, allocator)) { + // Chain operations: scan -> read first batch -> extract values + CompletableFuture> future = + scanner + .scanBatchesAsync() + .thenApply( + reader -> { + try { + List values = new ArrayList<>(); + if (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + IntVector idVector = (IntVector) root.getVector("id"); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(idVector.get(i)); + } + } + reader.close(); + return values; + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List values = future.get(10, TimeUnit.SECONDS); + assertTrue(values.size() > 0, "Should read some values"); + } + } + } + } +}