diff --git a/Cargo.lock b/Cargo.lock
index cfcc4899c96..5a8a86dd57e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4886,6 +4886,8 @@ dependencies = [
  "log",
  "pin-project",
  "prost",
+ "prost-build",
+ "protobuf-src",
  "snafu",
  "tokio",
  "tracing",
diff --git a/protos/filtered_read.proto b/protos/filtered_read.proto
new file mode 100644
index 00000000000..d81f6b02cfb
--- /dev/null
+++ b/protos/filtered_read.proto
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+syntax = "proto3";
+
+package lance.datafusion;
+
+import "table_identifier.proto";
+
+message U64Range {
+  uint64 start = 1;
+  uint64 end = 2;
+}
+
+message ProjectionProto {
+  repeated int32 field_ids = 1;
+  bool with_row_id = 2;
+  bool with_row_addr = 3;
+  bool with_row_last_updated_at_version = 4;
+  bool with_row_created_at_version = 5;
+  BlobHandlingProto blob_handling = 6;
+}
+
+message BlobHandlingProto {
+  oneof mode {
+    // All blobs read as binary
+    bool all_binary = 1;
+    // Blobs as descriptions, other binary as binary (default)
+    bool blobs_descriptions = 2;
+    // All binary columns as descriptions
+    bool all_descriptions = 3;
+    // Specific blobs read as binary, rest as descriptions (non-blob binary stays binary)
+    FieldIdSet some_blobs_binary = 4;
+    // Specific columns as binary, all other binary as descriptions
+    FieldIdSet some_binary = 5;
+  }
+}
+
+message FieldIdSet {
+  repeated uint32 field_ids = 1;
+}
+
+message FilteredReadThreadingModeProto {
+  oneof mode {
+    uint64 one_partition_multiple_threads = 1;
+    uint64 multiple_partitions = 2;
+  }
+}
+
+// Serializable form of FilteredReadOptions.
+message FilteredReadOptionsProto {
+  optional U64Range scan_range_before_filter = 1;
+  optional U64Range scan_range_after_filter = 2;
+  bool with_deleted_rows = 3;
+  optional uint32 batch_size = 4;
+  optional uint64 fragment_readahead = 5;
+  repeated uint64 fragment_ids = 6;
+  ProjectionProto projection = 7;
+  optional bytes refine_filter_substrait = 8;
+  optional bytes full_filter_substrait = 9;
+  FilteredReadThreadingModeProto threading_mode = 10;
+  optional uint64 io_buffer_size_bytes = 11;
+  // Arrow IPC schema for decoding Substrait filters (may be wider than projection).
+  optional bytes filter_schema_ipc = 12;
+}
+
+// Serializable form of FilteredReadPlan (planned/distributed mode).
+// RowAddrTreeMap is serialized via its built-in serialize_into/deserialize_from.
+// Per-fragment filters are Substrait-encoded and deduplicated.
+message FilteredReadPlanProto {
+  bytes row_addr_tree_map = 1;
+  optional U64Range scan_range_after_filter = 2;
+  // Arrow IPC schema for decoding Substrait filters (matches the schema used at encode time).
+  optional bytes filter_schema_ipc = 3;
+  // Per-fragment filter mapping. Key is fragment id, value is a list index into
+  // filter_expressions. Multiple fragments can share the same list index when
+  // they have the same filter, avoiding duplicate Substrait encoding.
+  map<uint32, uint32> fragment_filter_ids = 4;
+  // Deduplicated Substrait-encoded filter expressions. Each entry is referenced
+  // by one or more values in fragment_filter_ids.
+  repeated bytes filter_expressions = 5;
+}
+
+// Top-level wrapper for FilteredReadExec serialization.
+message FilteredReadExecProto {
+  TableIdentifier table = 1;
+  FilteredReadOptionsProto options = 2;
+  // FilteredRead has two modes:
+  // - Plan-then-execute (distributed): the planner creates a FilteredReadPlan
+  //   and sends it to a remote executor.
+  // - Plan-and-execute (local): the executor creates the plan itself at
+  //   execution time.
+  optional FilteredReadPlanProto plan = 3;
+  // Note: FilteredReadExec.index_input (child ExecutionPlan) is NOT serialized here.
+  // DataFusion's PhysicalExtensionCodec handles child plans automatically: it walks
+  // the plan tree via children() / with_new_children(), serializes each node, and
+  // passes deserialized children back as the `inputs` parameter in try_decode.
+  // This means any ExecutionPlan in the tree (including index_input) must also
+  // implement try_encode/try_decode in the PhysicalExtensionCodec.
+  // TODO: implement serialize/deserialize for lance-specific index input ExecutionPlans.
+}
diff --git a/protos/table_identifier.proto b/protos/table_identifier.proto
new file mode 100644
index 00000000000..f7e70b21c84
--- /dev/null
+++ b/protos/table_identifier.proto
@@ -0,0 +1,18 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+syntax = "proto3";
+
+package lance.datafusion;
+
+// Identifies a Lance dataset for remote reconstruction.
+//
+// Two modes:
+// 1. uri + serialized_manifest (fast): remote executor skips the manifest read.
+// 2. uri + version + etag (lightweight): remote executor loads the manifest from storage.
+message TableIdentifier {
+  string uri = 1;
+  uint64 version = 2;
+  optional string manifest_etag = 3;
+  optional bytes serialized_manifest = 4;
+}
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 090400fd4b3..aa81aee745c 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -4048,6 +4048,7 @@ dependencies = [
  "log",
  "pin-project",
  "prost",
+ "prost-build",
  "snafu",
  "tokio",
  "tracing",
diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs
index 528437341ca..65e44e0d38d 100644
--- a/rust/lance-core/src/datatypes/schema.rs
+++ b/rust/lance-core/src/datatypes/schema.rs
@@ -1024,7 +1024,7 @@ impl Projectable for Schema {
 }
 
 /// Specifies how to handle blob columns when projecting
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, PartialEq)]
 pub enum BlobHandling {
     /// Read all blobs as binary
     AllBinary,
diff --git a/rust/lance-datafusion/Cargo.toml b/rust/lance-datafusion/Cargo.toml
index 47315ce4712..9a561fa9b88 100644
--- a/rust/lance-datafusion/Cargo.toml
+++ b/rust/lance-datafusion/Cargo.toml
@@ -36,11 +36,16 @@ snafu.workspace = true
 tokio.workspace = true
 tracing.workspace = true
 
+[build-dependencies]
+prost-build.workspace = true
+protobuf-src = { version = "2.1", optional = true }
+
 [dev-dependencies]
 lance-datagen.workspace = true
 
 [features]
 substrait = ["dep:datafusion-substrait"]
+protoc = ["dep:protobuf-src"]
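+# With this feature enabled, `protobuf-src` builds and vendors `protoc`, so no
+# system protobuf compiler is needed (e.g. `cargo build -p lance-datafusion --features protoc`).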
 
 [lints]
 workspace = true
diff --git a/rust/lance-datafusion/build.rs b/rust/lance-datafusion/build.rs
new file mode 100644
index 00000000000..b68f793fbb7
--- /dev/null
+++ b/rust/lance-datafusion/build.rs
@@ -0,0 +1,25 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use std::io::Result;
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=protos");
+
+    // Use vendored protobuf compiler if requested.
+    #[cfg(feature = "protoc")]
+    std::env::set_var("PROTOC", protobuf_src::protoc());
+
+    let mut prost_build = prost_build::Config::new();
+    prost_build.protoc_arg("--experimental_allow_proto3_optional");
+    prost_build.enable_type_names();
+    prost_build.compile_protos(
+        &[
+            "./protos/table_identifier.proto",
+            "./protos/filtered_read.proto",
+        ],
+        &["./protos"],
+    )?;
+
+    Ok(())
+}
diff --git a/rust/lance-datafusion/protos b/rust/lance-datafusion/protos
new file mode 120000
index 00000000000..69d0d0d54b0
--- /dev/null
+++ b/rust/lance-datafusion/protos
@@ -0,0 +1 @@
+../../protos
\ No newline at end of file
diff --git a/rust/lance-datafusion/src/lib.rs b/rust/lance-datafusion/src/lib.rs
index 0ef51216ae8..ecc78672924 100644
--- a/rust/lance-datafusion/src/lib.rs
+++ b/rust/lance-datafusion/src/lib.rs
@@ -10,6 +10,17 @@ pub mod expr;
 pub mod logical_expr;
 pub mod planner;
 pub mod projection;
+pub mod pb {
+    #![allow(clippy::all)]
+    #![allow(non_upper_case_globals)]
+    #![allow(non_camel_case_types)]
+    #![allow(non_snake_case)]
+    #![allow(unused)]
+    #![allow(improper_ctypes)]
+    #![allow(clippy::upper_case_acronyms)]
+    #![allow(clippy::use_self)]
+    include!(concat!(env!("OUT_DIR"), "/lance.datafusion.rs"));
+}
 pub mod spill;
 pub mod sql;
 #[cfg(feature = "substrait")]
diff --git a/rust/lance-datafusion/src/substrait.rs b/rust/lance-datafusion/src/substrait.rs
index 295da39d09c..db0cc261e4f 100644
--- a/rust/lance-datafusion/src/substrait.rs
+++ b/rust/lance-datafusion/src/substrait.rs
@@ -1,7 +1,7 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
-use arrow_schema::Schema as ArrowSchema;
+use arrow_schema::{DataType, Schema as ArrowSchema};
 use datafusion::{execution::SessionState, logical_expr::Expr};
 
 use crate::aggregate::Aggregate;
@@ -27,6 +27,33 @@ use snafu::location;
 use std::collections::HashMap;
 use std::sync::Arc;
 
+/// FixedSizeList has no Substrait producer support in datafusion-substrait.
+/// Other unsupported types (Null, Float16) are encoded as UserDefined and
+/// handled by `remove_extension_types` on the decode side.
+fn is_substrait_compatible(data_type: &DataType) -> bool {
+    match data_type {
+        DataType::FixedSizeList(_, _) => false,
+        DataType::List(inner) => is_substrait_compatible(inner.data_type()),
+        DataType::Struct(fields) => fields
+            .iter()
+            .all(|f| is_substrait_compatible(f.data_type())),
+        _ => true,
+    }
+}
+
+/// Removes top-level fields that contain data types that the Substrait
+/// producer cannot encode (currently only FixedSizeList).
+pub fn prune_schema_for_substrait(schema: &ArrowSchema) -> ArrowSchema {
+    ArrowSchema::new(
+        schema
+            .fields()
+            .iter()
+            .filter(|f| is_substrait_compatible(f.data_type()))
+            .cloned()
+            .collect::<Vec<_>>(),
+    )
+}
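+
+// A sketch of the intended effect (column names illustrative): given a schema
+// { id: Utf8, vector: FixedSizeList<Float32, 128>, name: Utf8 }, pruning keeps
+// only { id, name }, leaving a schema the Substrait producer can encode.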
+
 /// Convert a DF Expr into a Substrait ExtendedExpressions message
 ///
 /// The schema needs to contain all of the fields that are referenced in the expression.
@@ -824,6 +851,44 @@ mod tests {
         assert_substrait_roundtrip(schema, id_filter("test-id")).await;
     }
 
+    #[tokio::test]
+    async fn test_substrait_roundtrip_with_null_and_float16_columns() {
+        // Float16 and Null are encoded as UserDefined types in Substrait.
+        // The decode side (remove_extension_types) strips them and remaps
+        // field references, so filters on other columns still work.
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("embedding", DataType::Float16, true),
+            Field::new("empty", DataType::Null, true),
+            Field::new("name", DataType::Utf8, true),
+        ]);
+
+        assert_substrait_roundtrip(schema, id_filter("test-id")).await;
+    }
+
+    #[tokio::test]
+    async fn test_substrait_roundtrip_with_fixed_size_list_column() {
+        // FixedSizeList has no Substrait producer support, so it must be
+        // pruned from the schema before encoding. Verify that a schema with
+        // FSL columns works when the filter references a different column.
+        use crate::substrait::prune_schema_for_substrait;
+
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new(
+                "vector",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
+                true,
+            ),
+            Field::new("name", DataType::Utf8, true),
+        ]);
+
+        // Encoding with the full schema would fail, but pruning removes the FSL column
+        let pruned = prune_schema_for_substrait(&schema);
+        assert_eq!(pruned.fields().len(), 2); // id and name only
+        assert_substrait_roundtrip(pruned, id_filter("test-id")).await;
+    }
+
     // ==================== Aggregate parsing tests ====================
 
     use datafusion_substrait::substrait::proto::{
diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs
index ca6af3f5e9f..ae62214857f 100644
--- a/rust/lance/src/io/exec.rs
+++ b/rust/lance/src/io/exec.rs
@@ -7,6 +7,8 @@ mod filter;
 
 pub mod filtered_read;
+#[cfg(feature = "substrait")]
+pub mod filtered_read_proto;
 pub mod fts;
 pub(crate) mod knn;
 mod optimizer;
diff --git a/rust/lance/src/io/exec/filtered_read.rs b/rust/lance/src/io/exec/filtered_read.rs
index c7088c1cf3e..74be5330588 100644
--- a/rust/lance/src/io/exec/filtered_read.rs
+++ b/rust/lance/src/io/exec/filtered_read.rs
@@ -1751,6 +1751,11 @@ impl FilteredReadExec {
     pub fn index_input(&self) -> Option<&Arc<dyn ExecutionPlan>> {
         self.index_input.as_ref()
     }
+
+    /// Return the pre-computed plan if one exists, without triggering initialization.
+    pub fn plan(&self) -> Option<FilteredReadPlan> {
+        self.plan.get().map(|p| p.to_external_plan())
+    }
 }
 
 impl DisplayAs for FilteredReadExec {
diff --git a/rust/lance/src/io/exec/filtered_read_proto.rs b/rust/lance/src/io/exec/filtered_read_proto.rs
new file mode 100644
index 00000000000..ee2ae8f3257
--- /dev/null
+++ b/rust/lance/src/io/exec/filtered_read_proto.rs
@@ -0,0 +1,860 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Protobuf serialization for [`FilteredReadExec`] and related types.
+//!
+//! Proto message definitions live in `lance-datafusion` (see `pb`).
+//! Conversion functions live here because they need access to `FilteredReadExec`
+//! and `Dataset`, which are defined in this crate.
+//!
+//! A datafusion `PhysicalExtensionCodec` can call these functions in `try_encode`
+//! and `try_decode` to support distributed execution (planner → executor).
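+//!
+//! A minimal sketch of such a codec (the `LanceReadCodec` type, its stored
+//! dataset/state, and the `block_on` bridge are illustrative assumptions,
+//! not part of this crate):
+//!
+//! ```ignore
+//! use datafusion::common::DataFusionError;
+//! use datafusion::execution::FunctionRegistry;
+//! use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+//! use prost::Message;
+//!
+//! #[derive(Debug)]
+//! struct LanceReadCodec {
+//!     dataset: Arc<Dataset>,
+//!     state: SessionState,
+//! }
+//!
+//! impl PhysicalExtensionCodec for LanceReadCodec {
+//!     fn try_encode(
+//!         &self,
+//!         node: Arc<dyn ExecutionPlan>,
+//!         buf: &mut Vec<u8>,
+//!     ) -> datafusion::common::Result<()> {
+//!         let read = node
+//!             .as_any()
+//!             .downcast_ref::<FilteredReadExec>()
+//!             .ok_or_else(|| DataFusionError::Internal("not a FilteredReadExec".into()))?;
+//!         let proto = filtered_read_exec_to_proto(read, &self.state)
+//!             .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+//!         proto
+//!             .encode(buf)
+//!             .map_err(|e| DataFusionError::Execution(e.to_string()))
+//!     }
+//!
+//!     fn try_decode(
+//!         &self,
+//!         buf: &[u8],
+//!         inputs: &[Arc<dyn ExecutionPlan>],
+//!         _registry: &dyn FunctionRegistry,
+//!     ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+//!         let proto = pb::FilteredReadExecProto::decode(buf)
+//!             .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+//!         // Any index_input child arrives as the first deserialized input.
+//!         let exec = futures::executor::block_on(filtered_read_exec_from_proto(
+//!             proto,
+//!             self.dataset.clone(),
+//!             inputs.first().cloned(),
+//!             &self.state,
+//!         ))
+//!         .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+//!         Ok(Arc::new(exec))
+//!     }
+//! }
+//! ```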
+
+use std::collections::HashMap;
+use std::io::Cursor;
+use std::ops::Range;
+use std::sync::Arc;
+
+use arrow_schema::Schema as ArrowSchema;
+use datafusion::execution::SessionState;
+use datafusion::logical_expr::Expr;
+use datafusion::physical_plan::ExecutionPlan;
+use lance_core::datatypes::{BlobHandling, Projectable, Projection};
+use lance_core::utils::mask::RowAddrTreeMap;
+use lance_core::{Error, Result};
+use lance_datafusion::pb;
+use lance_datafusion::substrait::{encode_substrait, parse_substrait, prune_schema_for_substrait};
+use lance_table::format::Fragment;
+use prost::Message;
+use snafu::location;
+
+use crate::Dataset;
+
+use super::filtered_read::{
+    FilteredReadExec, FilteredReadOptions, FilteredReadPlan, FilteredReadThreadingMode,
+};
+
+// =============================================================================
+// TableIdentifier helpers (reusable by other execs)
+// =============================================================================
+
+/// Build a [`TableIdentifier`] from a [`Dataset`].
+///
+/// Default: lightweight mode (uri + version + etag only, no serialized manifest).
+pub fn table_identifier_from_dataset(dataset: &Dataset) -> pb::TableIdentifier {
+    pb::TableIdentifier {
+        uri: dataset.uri().to_string(),
+        version: dataset.manifest.version,
+        manifest_etag: dataset.manifest_location.e_tag.clone(),
+        serialized_manifest: None,
+    }
+}
+
+/// Build a [`TableIdentifier`] with serialized manifest bytes included.
+///
+/// Fast path: the remote executor skips the manifest read from storage.
+pub fn table_identifier_from_dataset_with_manifest(dataset: &Dataset) -> pb::TableIdentifier {
+    let manifest_proto = lance_table::format::pb::Manifest::from(dataset.manifest.as_ref());
+    pb::TableIdentifier {
+        uri: dataset.uri().to_string(),
+        version: dataset.manifest.version,
+        manifest_etag: dataset.manifest_location.e_tag.clone(),
+        serialized_manifest: Some(manifest_proto.encode_to_vec()),
+    }
+}
+
+// =============================================================================
+// FilteredReadExec <-> Proto
+// =============================================================================
+
+/// Convert a [`FilteredReadExec`] to proto for serialization.
+///
+/// Uses `table_identifier_from_dataset` by default (no manifest bytes).
+/// The caller can replace the `table` field with
+/// [`table_identifier_from_dataset_with_manifest`] if desired.
+pub fn filtered_read_exec_to_proto(
+    exec: &FilteredReadExec,
+    state: &SessionState,
+) -> Result<pb::FilteredReadExecProto> {
+    let table = table_identifier_from_dataset(exec.dataset());
+    // Use the pruned dataset schema for filter encoding — filters can reference columns
+    // outside the projection (e.g. SELECT name WHERE age > 10), and some dataset columns
+    // may have types that Substrait cannot serialize (e.g. FixedSizeList, Float16).
+    let filter_schema = Arc::new(prune_schema_for_substrait(&exec.dataset().schema().into()));
+    let options = fr_options_to_proto(exec.options(), &filter_schema, state)?;
+
+    let plan = match exec.plan() {
+        Some(plan) => Some(plan_to_proto(&plan, &filter_schema, state)?),
+        None => None,
+    };
+
+    Ok(pb::FilteredReadExecProto {
+        table: Some(table),
+        options: Some(options),
+        plan,
+    })
+}
+
+/// Reconstruct a [`FilteredReadExec`] from proto.
+///
+/// The `dataset` must already be opened on the remote side (using
+/// the `TableIdentifier` from the proto). The optional `index_input`
+/// child plan is provided by DataFusion's codec via its `inputs` parameter.
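+///
+/// A usage sketch (hedged: opening the dataset via `DatasetBuilder` is one
+/// option, not something this function prescribes):
+///
+/// ```ignore
+/// let table = proto.table.clone().expect("missing table identifier");
+/// let dataset = Arc::new(
+///     DatasetBuilder::from_uri(&table.uri)
+///         .with_version(table.version)
+///         .load()
+///         .await?,
+/// );
+/// let exec = filtered_read_exec_from_proto(proto, dataset, None, &state).await?;
+/// ```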
+pub async fn filtered_read_exec_from_proto(
+    proto: pb::FilteredReadExecProto,
+    dataset: Arc<Dataset>,
+    index_input: Option<Arc<dyn ExecutionPlan>>,
+    state: &SessionState,
+) -> Result<FilteredReadExec> {
+    let options_proto = proto.options.ok_or_else(|| Error::InvalidInput {
+        source: "Missing options in FilteredReadExecProto".into(),
+        location: location!(),
+    })?;
+
+    let options = fr_options_from_proto(options_proto, &dataset, state).await?;
+    let exec = FilteredReadExec::try_new(dataset.clone(), options, index_input)?;
+
+    // Apply pre-computed plan if present
+    if let Some(plan_proto) = proto.plan {
+        let plan = plan_from_proto(plan_proto, &dataset, state).await?;
+        exec.with_plan(plan).await
+    } else {
+        Ok(exec)
+    }
+}
+
+// =============================================================================
+// FilteredReadOptions <-> Proto
+// =============================================================================
+
+fn fr_options_to_proto(
+    options: &FilteredReadOptions,
+    filter_schema: &Arc<ArrowSchema>,
+    state: &SessionState,
+) -> Result<pb::FilteredReadOptionsProto> {
+    let refine_filter_substrait = match &options.refine_filter {
+        Some(expr) => Some(encode_substrait(
+            expr.clone(),
+            filter_schema.clone(),
+            state,
+        )?),
+        None => None,
+    };
+
+    let full_filter_substrait = match &options.full_filter {
+        Some(expr) => Some(encode_substrait(
+            expr.clone(),
+            filter_schema.clone(),
+            state,
+        )?),
+        None => None,
+    };
+
+    // Serialize the filter schema as Arrow IPC if we have filters
+    let filter_schema_ipc = if refine_filter_substrait.is_some() || full_filter_substrait.is_some()
+    {
+        Some(schema_to_bytes(filter_schema)?)
+    } else {
+        None
+    };
+
+    Ok(pb::FilteredReadOptionsProto {
+        scan_range_before_filter: options
+            .scan_range_before_filter
+            .as_ref()
+            .map(range_to_proto),
+        scan_range_after_filter: options.scan_range_after_filter.as_ref().map(range_to_proto),
+        with_deleted_rows: options.with_deleted_rows,
+        batch_size: options.batch_size,
+        fragment_readahead: options.fragment_readahead.map(|v| v as u64),
+        fragment_ids: options
+            .fragments
+            .as_ref()
+            .map(|frags| frags.iter().map(|f| f.id).collect())
+            .unwrap_or_default(),
+        projection: Some(projection_to_proto(&options.projection)),
+        refine_filter_substrait,
+        full_filter_substrait,
+        threading_mode: Some(threading_mode_to_proto(&options.threading_mode)),
+        io_buffer_size_bytes: options.io_buffer_size_bytes,
+        filter_schema_ipc,
+    })
+}
+
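+/// Rebuild [`FilteredReadOptions`] from proto against an already-opened dataset:
+/// fragment ids are resolved via the dataset's manifest and Substrait filters are
+/// decoded with the IPC-encoded `filter_schema_ipc`.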
+async fn fr_options_from_proto(
+    proto: pb::FilteredReadOptionsProto,
+    dataset: &Arc<Dataset>,
+    state: &SessionState,
+) -> Result<FilteredReadOptions> {
+    let projection = projection_from_proto(
+        proto.projection.as_ref(),
+        dataset.clone() as Arc<dyn Projectable>,
+    )?;
+    let mut options = FilteredReadOptions::new(projection);
+
+    // Fragments
+    if !proto.fragment_ids.is_empty() {
+        let fragments = fragments_from_proto(&proto.fragment_ids, dataset)?;
+        options = options.with_fragments(Arc::new(fragments));
+    }
+
+    // Scan ranges
+    if let Some(range) = proto.scan_range_before_filter {
+        options = options
+            .with_scan_range_before_filter(range_from_proto(&range))
+            .map_err(|e| Error::Internal {
+                message: e.to_string(),
+                location: location!(),
+            })?;
+    }
+    if let Some(range) = proto.scan_range_after_filter {
+        options = options
+            .with_scan_range_after_filter(range_from_proto(&range))
+            .map_err(|e| Error::Internal {
+                message: e.to_string(),
+                location: location!(),
+            })?;
+    }
+
+    // Deleted rows
+    if proto.with_deleted_rows {
+        options = options.with_deleted_rows().map_err(|e| Error::Internal {
+            message: e.to_string(),
+            location: location!(),
+        })?;
+    }
+
+    // Performance tuning
+    if let Some(batch_size) = proto.batch_size {
+        options = options.with_batch_size(batch_size);
+    }
+    if let Some(readahead) = proto.fragment_readahead {
+        options = options.with_fragment_readahead(readahead as usize);
+    }
+    if let Some(io_buffer) = proto.io_buffer_size_bytes {
+        options = options.with_io_buffer_size(io_buffer);
+    }
+    if let Some(mode) = proto.threading_mode {
+        options.threading_mode = threading_mode_from_proto(&mode)?;
+    }
+
+    // Filters — require filter_schema_ipc when filters are present
+    let has_filters =
+        proto.refine_filter_substrait.is_some() || proto.full_filter_substrait.is_some();
+    if has_filters {
+        let filter_schema =
+            schema_from_bytes(proto.filter_schema_ipc.as_ref().ok_or_else(|| {
+                Error::InvalidInput {
+                    source: "missing filter_schema_ipc but filters are present".into(),
+                    location: location!(),
+                }
+            })?)?;
+
+        if let Some(bytes) = &proto.refine_filter_substrait {
+            options.refine_filter =
+                Some(parse_substrait(bytes, filter_schema.clone(), state).await?);
+        }
+        if let Some(bytes) = &proto.full_filter_substrait {
+            options.full_filter = Some(parse_substrait(bytes, filter_schema, state).await?);
+        }
+    }
+
+    Ok(options)
+}
+
+// =============================================================================
+// FilteredReadPlan <-> Proto
+// =============================================================================
+
+/// Convert a [`FilteredReadPlan`] to proto.
+///
+/// Deduplicates filter expressions: many fragments often share the same `Arc<Expr>`.
+/// We detect sharing via `Arc::as_ptr()` and encode each unique expression only once.
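+///
+/// For example: if fragments 0 and 1 share one `Arc<Expr>`, the resulting proto
+/// has a single entry in `filter_expressions` and `fragment_filter_ids` maps
+/// both fragments to index 0 (see `test_plan_proto_roundtrip` below).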
+pub fn plan_to_proto(
+    plan: &FilteredReadPlan,
+    filter_schema: &Arc<ArrowSchema>,
+    state: &SessionState,
+) -> Result<pb::FilteredReadPlanProto> {
+    let mut buf = Vec::with_capacity(plan.rows.serialized_size());
+    plan.rows.serialize_into(&mut buf)?;
+
+    // Deduplicate filter expressions by Arc pointer identity.
+    let mut ptr_to_id: HashMap<*const Expr, u32> = HashMap::new();
+    let mut filter_expressions: Vec<Vec<u8>> = Vec::new();
+    let mut fragment_filter_ids: HashMap<u32, u32> = HashMap::new();
+
+    for (frag_id, expr) in &plan.filters {
+        let ptr = Arc::as_ptr(expr);
+        let id = match ptr_to_id.get(&ptr) {
+            Some(&id) => id,
+            None => {
+                let id = filter_expressions.len() as u32;
+                let encoded =
+                    encode_substrait(expr.as_ref().clone(), filter_schema.clone(), state)?;
+                filter_expressions.push(encoded);
+                ptr_to_id.insert(ptr, id);
+                id
+            }
+        };
+        fragment_filter_ids.insert(*frag_id, id);
+    }
+
+    let filter_schema_ipc = if fragment_filter_ids.is_empty() {
+        None
+    } else {
+        Some(schema_to_bytes(filter_schema)?)
+    };
+
+    Ok(pb::FilteredReadPlanProto {
+        row_addr_tree_map: buf,
+        scan_range_after_filter: plan.scan_range_after_filter.as_ref().map(range_to_proto),
+        filter_schema_ipc,
+        fragment_filter_ids,
+        filter_expressions,
+    })
+}
+
+async fn plan_from_proto(
+    proto: pb::FilteredReadPlanProto,
+    _dataset: &Arc<Dataset>,
+    state: &SessionState,
+) -> Result<FilteredReadPlan> {
+    let rows = RowAddrTreeMap::deserialize_from(Cursor::new(&proto.row_addr_tree_map))?;
+
+    let mut filters = HashMap::new();
+    if !proto.fragment_filter_ids.is_empty() {
+        let filter_schema =
+            schema_from_bytes(proto.filter_schema_ipc.as_ref().ok_or_else(|| {
+                Error::InvalidInput {
+                    source: "missing filter_schema_ipc but plan has filters".into(),
+                    location: location!(),
+                }
+            })?)?;
+
+        // Decode each unique expression once, then share via Arc.
+        let mut decoded: Vec<Arc<Expr>> = Vec::with_capacity(proto.filter_expressions.len());
+        for bytes in &proto.filter_expressions {
+            let expr = parse_substrait(bytes, filter_schema.clone(), state).await?;
+            decoded.push(Arc::new(expr));
+        }
+
+        for (frag_id, expr_id) in &proto.fragment_filter_ids {
+            let expr = decoded
+                .get(*expr_id as usize)
+                .ok_or_else(|| Error::InvalidInput {
+                    source: format!(
+                        "filter expression index {} out of bounds (have {})",
+                        expr_id,
+                        decoded.len()
+                    )
+                    .into(),
+                    location: location!(),
+                })?;
+            filters.insert(*frag_id, Arc::clone(expr));
+        }
+    }
+
+    Ok(FilteredReadPlan {
+        rows,
+        filters,
+        scan_range_after_filter: proto.scan_range_after_filter.map(|r| range_from_proto(&r)),
+    })
+}
+
+// =============================================================================
+// Projection <-> Proto
+// =============================================================================
+
+fn projection_to_proto(proj: &Projection) -> pb::ProjectionProto {
+    pb::ProjectionProto {
+        field_ids: proj.field_ids.iter().copied().collect(),
+        with_row_id: proj.with_row_id,
+        with_row_addr: proj.with_row_addr,
+        with_row_last_updated_at_version: proj.with_row_last_updated_at_version,
+        with_row_created_at_version: proj.with_row_created_at_version,
+        blob_handling: Some(blob_handling_to_proto(&proj.blob_handling)),
+    }
+}
+
+fn blob_handling_to_proto(bh: &BlobHandling) -> pb::BlobHandlingProto {
+    use pb::blob_handling_proto::Mode;
+    let mode = match bh {
+        BlobHandling::AllBinary => Some(Mode::AllBinary(true)),
+        BlobHandling::BlobsDescriptions => Some(Mode::BlobsDescriptions(true)),
+        BlobHandling::AllDescriptions => Some(Mode::AllDescriptions(true)),
+        BlobHandling::SomeBlobsBinary(ids) => Some(Mode::SomeBlobsBinary(pb::FieldIdSet {
+            field_ids: ids.iter().copied().collect(),
+        })),
+        BlobHandling::SomeBinary(ids) => Some(Mode::SomeBinary(pb::FieldIdSet {
+            field_ids: ids.iter().copied().collect(),
+        })),
+    };
+    pb::BlobHandlingProto { mode }
+}
+
+fn blob_handling_from_proto(proto: Option<&pb::BlobHandlingProto>) -> BlobHandling {
+    use pb::blob_handling_proto::Mode;
+    match proto.and_then(|p| p.mode.as_ref()) {
+        Some(Mode::AllBinary(_)) => BlobHandling::AllBinary,
+        Some(Mode::BlobsDescriptions(_)) => BlobHandling::BlobsDescriptions,
+        Some(Mode::AllDescriptions(_)) => BlobHandling::AllDescriptions,
+        Some(Mode::SomeBlobsBinary(ids)) => {
+            BlobHandling::SomeBlobsBinary(ids.field_ids.iter().copied().collect())
+        }
+        Some(Mode::SomeBinary(ids)) => {
+            BlobHandling::SomeBinary(ids.field_ids.iter().copied().collect())
+        }
+        // Default for backwards compatibility with protos that don't have blob_handling
+        None => BlobHandling::default(),
+    }
+}
+
+fn projection_from_proto(
+    proto: Option<&pb::ProjectionProto>,
+    base: Arc<dyn Projectable>,
+) -> Result<Projection> {
+    let proto = proto.ok_or_else(|| Error::InvalidInput {
+        source: "Missing projection in proto".into(),
+        location: location!(),
+    })?;
+
+    let mut projection = Projection::empty(base);
+    for field_id in &proto.field_ids {
+        projection.field_ids.insert(*field_id);
+    }
+    if proto.with_row_id {
+        projection = projection.with_row_id();
+    }
+    if proto.with_row_addr {
+        projection = projection.with_row_addr();
+    }
+    if proto.with_row_last_updated_at_version {
+        projection = projection.with_row_last_updated_at_version();
+    }
+    if proto.with_row_created_at_version {
+        projection = projection.with_row_created_at_version();
+    }
+    projection =
+        projection.with_blob_handling(blob_handling_from_proto(proto.blob_handling.as_ref()));
+    Ok(projection)
+}
+
+// =============================================================================
+// Threading mode <-> Proto
+// =============================================================================
+
+fn threading_mode_to_proto(mode: &FilteredReadThreadingMode) -> pb::FilteredReadThreadingModeProto {
+    let mode_oneof = match mode {
+        FilteredReadThreadingMode::OnePartitionMultipleThreads(n) => {
+            pb::filtered_read_threading_mode_proto::Mode::OnePartitionMultipleThreads(*n as u64)
+        }
+        FilteredReadThreadingMode::MultiplePartitions(n) => {
+            pb::filtered_read_threading_mode_proto::Mode::MultiplePartitions(*n as u64)
+        }
+    };
+    pb::FilteredReadThreadingModeProto {
+        mode: Some(mode_oneof),
+    }
+}
+
+fn threading_mode_from_proto(
+    proto: &pb::FilteredReadThreadingModeProto,
+) -> Result<FilteredReadThreadingMode> {
+    match &proto.mode {
+        Some(pb::filtered_read_threading_mode_proto::Mode::OnePartitionMultipleThreads(n)) => Ok(
+            FilteredReadThreadingMode::OnePartitionMultipleThreads(*n as usize),
+        ),
+        Some(pb::filtered_read_threading_mode_proto::Mode::MultiplePartitions(n)) => {
+            Ok(FilteredReadThreadingMode::MultiplePartitions(*n as usize))
+        }
+        None => Err(Error::InvalidInput {
+            source: "Missing threading mode in proto".into(),
+            location: location!(),
+        }),
+    }
+}
+
+// =============================================================================
+// Helpers
+// =============================================================================
+
+fn range_to_proto(range: &Range<u64>) -> pb::U64Range {
+    pb::U64Range {
+        start: range.start,
+        end: range.end,
+    }
+}
+
+fn range_from_proto(proto: &pb::U64Range) -> Range<u64> {
+    proto.start..proto.end
+}
+
+fn fragments_from_proto(fragment_ids: &[u64], dataset: &Arc<Dataset>) -> Result<Vec<Fragment>> {
+    fragment_ids
+        .iter()
+        .map(|id| {
+            dataset
+                .manifest
+                .fragments
+                .iter()
+                .find(|f| f.id == *id)
+                .cloned()
+                .ok_or_else(|| Error::InvalidInput {
+                    source: format!("Fragment {} not found in dataset", id).into(),
+                    location: location!(),
+                })
+        })
+        .collect()
+}
+
+fn schema_to_bytes(schema: &ArrowSchema) -> Result<Vec<u8>> {
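+    // Encode just the schema (no batches) as a raw Arrow IPC message.
+    // try_new args: alignment = 8 bytes, write_legacy_ipc_format = false, V5 metadata.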
+    let options =
+        arrow_ipc::writer::IpcWriteOptions::try_new(8, false, arrow_ipc::MetadataVersion::V5)
+            .map_err(|e| Error::Internal {
+                message: format!("Failed to create IPC write options: {}", e),
+                location: location!(),
+            })?;
+    let gen = arrow_ipc::writer::IpcDataGenerator::default();
+    let mut tracker = arrow_ipc::writer::DictionaryTracker::new(false);
+    let encoded = gen.schema_to_bytes_with_dictionary_tracker(schema, &mut tracker, &options);
+    Ok(encoded.ipc_message.to_vec())
+}
+
+fn schema_from_bytes(bytes: &[u8]) -> Result<Arc<ArrowSchema>> {
+    let message = arrow_ipc::root_as_message(bytes).map_err(|e| Error::Internal {
+        message: format!("Failed to parse IPC schema message: {}", e),
+        location: location!(),
+    })?;
+    let ipc_schema = message.header_as_schema().ok_or_else(|| Error::Internal {
+        message: "IPC message does not contain a schema".to_string(),
+        location: location!(),
+    })?;
+    let schema = arrow_ipc::convert::fb_to_schema(ipc_schema);
+    Ok(Arc::new(schema))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::types::UInt32Type;
+    use arrow_schema::{DataType, Field};
+    use datafusion::prelude::SessionContext;
+    use lance_core::datatypes::OnMissing;
+    use lance_core::utils::mask::RowAddrTreeMap;
+    use lance_datagen::{array, gen_batch};
+    use roaring::RoaringBitmap;
+    use std::collections::HashSet;
+
+    use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount};
+
+    #[test]
+    fn test_range_roundtrip() {
+        let range = 10u64..42u64;
+        let proto = range_to_proto(&range);
+        let back = range_from_proto(&proto);
+        assert_eq!(range, back);
+    }
+
+    #[test]
+    fn test_threading_mode_roundtrip() {
+        let mode = FilteredReadThreadingMode::OnePartitionMultipleThreads(8);
+        let proto = threading_mode_to_proto(&mode);
+        let back = threading_mode_from_proto(&proto).unwrap();
+        assert_eq!(mode, back);
+
+        let mode = FilteredReadThreadingMode::MultiplePartitions(4);
+        let proto = threading_mode_to_proto(&mode);
+        let back = threading_mode_from_proto(&proto).unwrap();
+        assert_eq!(mode, back);
+    }
+
+    #[test]
+    fn test_schema_roundtrip() {
+        let schema = ArrowSchema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+        let bytes = schema_to_bytes(&schema).unwrap();
+        let back = schema_from_bytes(&bytes).unwrap();
+        assert_eq!(schema, *back);
+    }
+
+    #[test]
+    fn test_projection_roundtrip() {
+        let schema = lance_core::datatypes::Schema::try_from(&ArrowSchema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true),
+            Field::new("c", DataType::Float64, true),
+        ]))
+        .unwrap();
+
+        let base: Arc<dyn Projectable> = Arc::new(schema);
+
+        let mut projection = Projection::empty(base.clone());
+        projection.field_ids = HashSet::from([0, 2]);
+        projection = projection
+            .with_row_id()
+            .with_row_addr()
+            .with_row_last_updated_at_version()
+            .with_row_created_at_version()
+            .with_blob_handling(BlobHandling::SomeBlobsBinary(HashSet::from([1, 3])));
+
+        let proto = projection_to_proto(&projection);
+        let back = projection_from_proto(Some(&proto), base).unwrap();
+
+        assert_eq!(projection.field_ids, back.field_ids);
+        assert_eq!(projection.with_row_id, back.with_row_id);
+        assert_eq!(projection.with_row_addr, back.with_row_addr);
+        assert_eq!(
+            projection.with_row_last_updated_at_version,
+            back.with_row_last_updated_at_version
+        );
+        assert_eq!(
+            projection.with_row_created_at_version,
+            back.with_row_created_at_version
+        );
+        assert_eq!(projection.blob_handling, back.blob_handling);
+    }
+
+    #[test]
+    fn test_table_identifier_without_manifest() {
+        let id = pb::TableIdentifier {
+            uri: "s3://bucket/table.lance".to_string(),
+            version: 42,
+            manifest_etag: Some("etag123".to_string()),
+            serialized_manifest: None,
+        };
+        let bytes = id.encode_to_vec();
+        let back = pb::TableIdentifier::decode(bytes.as_slice()).unwrap();
+        assert_eq!(id.uri, back.uri);
+        assert_eq!(id.version, back.version);
+        assert_eq!(id.manifest_etag, back.manifest_etag);
+        assert!(back.serialized_manifest.is_none());
+    }
+
+    #[test]
+    fn test_row_addr_tree_map_roundtrip_in_plan_proto() {
+        let mut rows = RowAddrTreeMap::new();
+        let mut bitmap = RoaringBitmap::new();
+        bitmap.insert_range(0..100);
+        rows.insert_bitmap(0, bitmap);
+        rows.insert_fragment(1); // Full fragment
+
+        let mut buf = Vec::with_capacity(rows.serialized_size());
+        rows.serialize_into(&mut buf).unwrap();
+        let back = RowAddrTreeMap::deserialize_from(Cursor::new(&buf)).unwrap();
+        assert_eq!(rows, back);
+    }
+
+    async fn make_test_dataset() -> Arc<Dataset> {
+        let dataset = gen_batch()
+            .col("x", array::step::<UInt32Type>())
+            .col("y", array::step::<UInt32Type>())
+            .into_ram_dataset(FragmentCount::from(2), FragmentRowCount::from(50))
+            .await
+            .unwrap();
+        Arc::new(dataset)
+    }
+
+    #[tokio::test]
+    async fn test_options_roundtrip_basic() {
+        let dataset = make_test_dataset().await;
+        let ctx = SessionContext::new();
+        let state = ctx.state();
+        let filter_schema = Arc::new(prune_schema_for_substrait(&dataset.schema().into()));
+
+        let options = FilteredReadOptions::basic_full_read(&dataset)
+            .with_scan_range_before_filter(10..90)
+            .unwrap()
+            .with_batch_size(64)
+            .with_fragment_readahead(4)
+            .with_io_buffer_size(1024 * 1024);
+
+        let proto = fr_options_to_proto(&options, &filter_schema, &state).unwrap();
+        let back = fr_options_from_proto(proto, &dataset, &state)
+            .await
+            .unwrap();
+
+        assert_eq!(
+            options.scan_range_before_filter,
+            back.scan_range_before_filter
+        );
+        assert_eq!(options.batch_size, back.batch_size);
+        assert_eq!(options.fragment_readahead, back.fragment_readahead);
+        assert_eq!(options.io_buffer_size_bytes, back.io_buffer_size_bytes);
+        assert_eq!(options.threading_mode, back.threading_mode);
+        assert_eq!(options.with_deleted_rows, back.with_deleted_rows);
+        assert_eq!(options.projection.field_ids, back.projection.field_ids);
+        assert_eq!(options.projection.with_row_id, back.projection.with_row_id);
+        assert_eq!(
+            options.projection.with_row_addr,
+            back.projection.with_row_addr
+        );
+    }
+
+    #[tokio::test]
+    async fn test_options_roundtrip_with_filter() {
+        let dataset = make_test_dataset().await;
+        let ctx = SessionContext::new();
+        let state = ctx.state();
+        let filter_schema = Arc::new(prune_schema_for_substrait(&dataset.schema().into()));
+
+        let filter_expr = datafusion_expr::col("x").gt(datafusion_expr::lit(5i32));
+        let refine_expr = datafusion_expr::col("x").lt(datafusion_expr::lit(100i32));
+        let projection = dataset
+            .empty_projection()
+            .union_column("x", OnMissing::Error)
+            .unwrap()
+            .with_row_id();
+        let mut options = FilteredReadOptions::new(projection)
+            .with_deleted_rows()
+            .unwrap();
+        options.full_filter = Some(filter_expr);
+        options.refine_filter = Some(refine_expr);
+        options.threading_mode = FilteredReadThreadingMode::MultiplePartitions(4);
+
+        let proto = fr_options_to_proto(&options, &filter_schema, &state).unwrap();
+
+        // Verify filter schema IPC was generated
+        assert!(proto.filter_schema_ipc.is_some());
+        assert!(proto.full_filter_substrait.is_some());
+        assert!(proto.refine_filter_substrait.is_some());
+
+        let back = fr_options_from_proto(proto, &dataset, &state)
+            .await
+            .unwrap();
+
+        assert!(back.full_filter.is_some());
+        assert!(back.refine_filter.is_some());
+        assert!(back.with_deleted_rows);
+        assert_eq!(options.threading_mode, back.threading_mode);
+        assert_eq!(options.projection.field_ids, back.projection.field_ids);
+        assert!(back.projection.with_row_id);
+    }
+
+    #[tokio::test]
+    async fn test_options_roundtrip_with_fragments() {
+        let dataset = make_test_dataset().await;
+        let ctx = SessionContext::new();
+        let state = ctx.state();
+        let filter_schema = Arc::new(prune_schema_for_substrait(&dataset.schema().into()));
+
+        let frags = dataset.get_fragments();
+        let first_frag = vec![frags[0].metadata().clone()];
+        let options =
+            FilteredReadOptions::basic_full_read(&dataset).with_fragments(Arc::new(first_frag));
+
+        let proto = fr_options_to_proto(&options, &filter_schema, &state).unwrap();
+        assert_eq!(proto.fragment_ids.len(), 1);
+
+        let back = fr_options_from_proto(proto, &dataset, &state)
+            .await
+            .unwrap();
+        assert!(back.fragments.is_some());
+        assert_eq!(back.fragments.as_ref().unwrap().len(), 1);
+        assert_eq!(
+            back.fragments.as_ref().unwrap()[0].id,
+            options.fragments.as_ref().unwrap()[0].id
+        );
+    }
+
+    #[tokio::test]
+    async fn test_exec_to_proto_roundtrip() {
+        let dataset = make_test_dataset().await;
+        let ctx = SessionContext::new();
+        let state = ctx.state();
+
+        let options = FilteredReadOptions::basic_full_read(&dataset)
+            .with_batch_size(32)
+            .with_scan_range_before_filter(0..50)
+            .unwrap();
+
+        let exec = FilteredReadExec::try_new(dataset.clone(), options, None).unwrap();
+
+        let proto = filtered_read_exec_to_proto(&exec, &state).unwrap();
+
+        // Check table identifier
+        let table = proto.table.as_ref().unwrap();
+        assert_eq!(table.uri, dataset.uri());
+        assert_eq!(table.version, dataset.manifest.version);
+        assert!(table.serialized_manifest.is_none());
+
+        // Roundtrip back
+        let back = filtered_read_exec_from_proto(proto, dataset.clone(), None, &state)
+            .await
+            .unwrap();
+
+        assert_eq!(exec.options().batch_size, back.options().batch_size);
+        assert_eq!(
+            exec.options().scan_range_before_filter,
+            back.options().scan_range_before_filter
+        );
+        assert_eq!(
+            exec.options().projection.field_ids,
+            back.options().projection.field_ids
+        );
+    }
+
+    #[tokio::test]
+    async fn test_table_identifier_with_manifest() {
+        let dataset = make_test_dataset().await;
+
+        let id = table_identifier_from_dataset_with_manifest(&dataset);
+        assert_eq!(id.uri, dataset.uri());
+        assert_eq!(id.version, dataset.manifest.version);
+        assert!(id.serialized_manifest.is_some());
+
+        // Verify the serialized manifest bytes decode
+        let manifest_bytes = id.serialized_manifest.unwrap();
+        let _manifest_proto =
+            lance_table::format::pb::Manifest::decode(manifest_bytes.as_slice()).unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_plan_proto_roundtrip() {
+        let dataset = make_test_dataset().await;
+        let ctx = SessionContext::new();
+        let state = ctx.state();
+
+        let mut rows = RowAddrTreeMap::new();
+        let mut bitmap0 = RoaringBitmap::new();
+        bitmap0.insert_range(0..25);
+        rows.insert_bitmap(0, bitmap0);
+        let mut bitmap1 = RoaringBitmap::new();
+        bitmap1.insert_range(0..30);
+        rows.insert_bitmap(1, bitmap1);
+
+        // Two fragments share the same Arc<Expr> — dedup should encode it once.
+        let shared_filter = Arc::new(datafusion_expr::col("x").gt(datafusion_expr::lit(10i32)));
+        let mut filters = HashMap::new();
+        filters.insert(0u32, Arc::clone(&shared_filter));
+        filters.insert(1u32, Arc::clone(&shared_filter));
+
+        let plan = FilteredReadPlan {
+            rows,
+            filters,
+            scan_range_after_filter: Some(5..20),
+        };
+
+        let filter_schema = Arc::new(prune_schema_for_substrait(&dataset.schema().into()));
+        let proto = plan_to_proto(&plan, &filter_schema, &state).unwrap();
+
+        // Verify dedup: 2 fragments but only 1 unique expression
+        assert_eq!(proto.fragment_filter_ids.len(), 2);
+        assert_eq!(
+            proto.filter_expressions.len(),
+            1,
+            "shared Arc<Expr> should be deduplicated into a single expression"
+        );
+
+        let back = plan_from_proto(proto, &dataset, &state).await.unwrap();
+
+        assert_eq!(plan.rows, back.rows);
+        assert_eq!(plan.scan_range_after_filter, back.scan_range_after_filter);
+        assert_eq!(back.filters.len(), 2);
+        assert!(back.filters.contains_key(&0));
+        assert!(back.filters.contains_key(&1));
+        // After roundtrip, the decoded expressions should be shared via Arc too
+        assert!(Arc::ptr_eq(&back.filters[&0], &back.filters[&1]));
+    }
+}