Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ impl Dataset {
Arc::new(
self.ds
.schema()
.project(&columns)
.project_preserve_system_columns(&columns)
.map_err(|err| PyValueError::new_err(err.to_string()))?,
)
} else {
Expand Down
63 changes: 59 additions & 4 deletions rust/lance-core/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ impl Schema {
}
}

fn do_project<T: AsRef<str>>(&self, columns: &[T], err_on_missing: bool) -> Result<Self> {
fn do_project<T: AsRef<str>>(
&self,
columns: &[T],
err_on_missing: bool,
preserve_system_columns: bool,
) -> Result<Self> {
let mut candidates: Vec<Field> = vec![];
for col in columns {
let split = parse_field_path(col.as_ref())?;
Expand All @@ -234,7 +239,17 @@ impl Schema {
} else {
candidates.push(projected_field)
}
} else if err_on_missing && first != ROW_ID && first != ROW_ADDR {
} else if first == ROW_ID || first == ROW_ADDR {
// Note: Other system columns like _rowoffset are handled differently
if preserve_system_columns {
// For now we only support _rowid and _rowaddr in projections
if first == ROW_ID {
candidates.push(Field::try_from(ROW_ID_FIELD.clone())?);
} else if first == ROW_ADDR {
candidates.push(Field::try_from(ROW_ADDR_FIELD.clone())?);
}
}
} else if err_on_missing {
return Err(Error::Schema {
message: format!("Column {} does not exist", col.as_ref()),
location: location!(),
Expand All @@ -255,12 +270,17 @@ impl Schema {
/// let projected = schema.project(&["col1", "col2.sub_col3.field4"])?;
/// ```
pub fn project<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
self.do_project(columns, true)
self.do_project(columns, true, false)
}

/// Project the columns over the schema, dropping unrecognized columns
pub fn project_or_drop<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
self.do_project(columns, false)
self.do_project(columns, false, false)
}

/// Project the columns over the schema, preserving system columns.
pub fn project_preserve_system_columns<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
Comment thread
hamersaw marked this conversation as resolved.
self.do_project(columns, true, true)
}

/// Check that the top level fields don't contain `.` in their names
Expand Down Expand Up @@ -1832,6 +1852,41 @@ mod tests {
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
}

#[test]
fn test_schema_projection_preserving_system_columns() {
let arrow_schema = ArrowSchema::new(vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f2", DataType::Boolean, false),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
),
ArrowField::new("c", DataType::Float64, false),
]);
let schema = Schema::try_from(&arrow_schema).unwrap();
let projected = schema
.project_preserve_system_columns(&["b.f1", "b.f3", "_rowid", "c"])
.unwrap();

let expected_arrow_schema = ArrowSchema::new(vec![
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
),
ArrowField::new("_rowid", DataType::UInt64, true),
ArrowField::new("c", DataType::Float64, false),
]);
assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
}

#[test]
fn test_schema_project_by_ids() {
let arrow_schema = ArrowSchema::new(vec![
Expand Down
50 changes: 5 additions & 45 deletions rust/lance/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ use crate::dataset::transaction::translate_schema_metadata_updates;
use crate::session::caches::{DSMetadataCache, ManifestKey, TransactionKey};
use crate::session::index_caches::DSIndexCache;
use itertools::Itertools;
use lance_core::datatypes::{
BlobVersion, Field, OnMissing, OnTypeMismatch, Projectable, Projection,
};
use lance_core::datatypes::{BlobVersion, OnMissing, OnTypeMismatch, Projectable, Projection};
use lance_core::traits::DatasetTakeRows;
use lance_core::utils::address::RowAddress;
use lance_core::utils::tracing::{
DATASET_CLEANING_EVENT, DATASET_DELETING_EVENT, DATASET_DROPPING_COLUMN_EVENT,
TRACE_DATASET_EVENTS,
};
use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID_FIELD};
use lance_core::ROW_ADDR;
use lance_datafusion::projection::ProjectionPlan;
use lance_file::datatypes::populate_schema_dictionary;
use lance_file::reader::FileReaderOptions;
Expand Down Expand Up @@ -334,47 +332,9 @@ impl ProjectionRequest {
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();

// Separate data columns from system columns
// System columns need to be added to the schema manually since Schema::project
// doesn't include them (they're virtual columns)
let mut data_columns = Vec::new();
let mut system_fields = Vec::new();

for col in &columns {
if lance_core::is_system_column(col) {
// For now we only support _rowid and _rowaddr in projections
if col == ROW_ID {
system_fields.push(Field::try_from(ROW_ID_FIELD.clone()).unwrap());
} else if col == ROW_ADDR {
system_fields.push(Field::try_from(ROW_ADDR_FIELD.clone()).unwrap());
}
// Note: Other system columns like _rowoffset are handled differently
} else {
data_columns.push(col.as_str());
}
}

// Project only the data columns
let mut schema = dataset_schema.project(&data_columns).unwrap();

// Add system fields in the order they appeared in the original columns list
// We need to reconstruct the proper order
let mut final_fields = Vec::new();
for col in &columns {
if lance_core::is_system_column(col) {
// Find and add the system field
if let Some(field) = system_fields.iter().find(|f| &f.name == col) {
final_fields.push(field.clone());
}
} else {
// Find and add the data field
if let Some(field) = schema.fields.iter().find(|f| &f.name == col) {
final_fields.push(field.clone());
}
}
}

schema.fields = final_fields;
let schema = dataset_schema
.project_preserve_system_columns(&columns)
.unwrap();
Self::Schema(Arc::new(schema))
}

Expand Down
Loading