Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,3 +1887,28 @@ def test_nested_field_vector_index(tmp_path):

# Verify total row count
assert dataset.count_rows() == num_rows + 50


def test_prewarm_index(tmp_path):
    """Prewarming an index by name should succeed and leave search working."""
    dataset = lance.write_dataset(create_table(), tmp_path, data_storage_version="2.1")
    dataset = dataset.create_index(
        "vector",
        name="vector_index",
        index_type="IVF_PQ",
        num_partitions=4,
        num_sub_vectors=16,
    )
    # Load the index contents into cache up front.
    dataset.prewarm_index("vector_index")

    # Append unindexed rows so the search covers both indexed and new data.
    appended = create_table(nvec=10)
    dataset = lance.write_dataset(appended, dataset.uri, mode="append")
    query_vec = appended["vector"][0].as_py()

    def check_results(rs: pa.Table):
        if "vector" not in rs:
            return
        assert rs["vector"][0].as_py() == query_vec

    run(dataset, q=np.array(query_vec), assert_func=check_results)
268 changes: 263 additions & 5 deletions rust/lance/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::stream;
use itertools::Itertools;
use lance_core::cache::{CacheKey, UnsizedCacheKey};
use lance_core::datatypes::Field;
use lance_core::datatypes::Schema as LanceSchema;
use lance_core::utils::address::RowAddress;
use lance_core::utils::parse::str_is_truthy;
use lance_core::utils::tracing::{
Expand Down Expand Up @@ -1391,12 +1393,11 @@ impl DatasetIndexInternalExt for Dataset {
})?;
let index_metadata: lance_index::IndexMetadata =
serde_json::from_str(index_metadata)?;
let field = self.schema().field(column).ok_or_else(|| Error::Index {
message: format!("Column {} does not exist in the schema", column),
location: location!(),
})?;

let (_, element_type) = get_vector_type(self.schema(), column)?;
// Resolve the column name and field
let (field_path, field) = resolve_index_column(self.schema(), &index_meta, column)?;

let (_, element_type) = get_vector_type(self.schema(), &field_path)?;

info!(target: TRACE_IO_EVENTS, index_uuid=uuid, r#type=IO_TYPE_OPEN_VECTOR, version="0.3", index_type=index_metadata.index_type);

Expand Down Expand Up @@ -1836,6 +1837,49 @@ impl DatasetIndexInternalExt for Dataset {
}
}

/// Resolves the column name and field for an index operation.
///
/// This function handles the case where the caller passes an index name instead of a column name.
/// It returns the full field path and the field reference.
fn resolve_index_column(
schema: &LanceSchema,
index_meta: &IndexMetadata,
column_arg: &str,
) -> Result<(String, Arc<Field>)> {
// First, try to find the column directly in the schema
if let Some(field) = schema.field(column_arg) {
// Column exists in schema, use it
return Ok((column_arg.to_string(), Arc::new(field.clone())));
}

// Column doesn't exist in schema, check if it's the index name
if column_arg == index_meta.name {
// Get the actual column from index metadata
if let Some(field_id) = index_meta.fields.first() {
let field = schema.field_by_id(*field_id).ok_or_else(|| Error::Index {
message: format!(
"Index '{}' references field with id {} which does not exist in schema",
index_meta.name, field_id
),
location: location!(),
})?;
let field_path = schema.field_path(*field_id)?;
return Ok((field_path, Arc::new(field.clone())));
} else {
return Err(Error::Index {
message: format!("Index '{}' has no fields", index_meta.name),
location: location!(),
});
}
}

// Column doesn't exist and is not the index name
Err(Error::Index {
message: format!("Column '{}' does not exist in the schema", column_arg),
location: location!(),
})
}

fn is_vector_field(data_type: DataType) -> bool {
match data_type {
DataType::FixedSizeList(_, _) => true,
Expand Down Expand Up @@ -4949,4 +4993,218 @@ mod tests {
);
assert!(found_count < num_rows, "Should not match all documents");
}

#[tokio::test]
async fn test_resolve_index_column() {
use lance_datagen::{array, BatchCount, RowCount};

// Create a test dataset with a vector column
let test_dir = tempfile::tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();

let reader = lance_datagen::gen_batch()
.col("id", array::step::<arrow_array::types::Int32Type>())
.col(
"vector",
array::rand_vec::<arrow_array::types::Float32Type>(32.into()),
)
.into_reader_rows(RowCount::from(100), BatchCount::from(1));

let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap();

// Create an index with a custom name
let params = crate::index::vector::VectorIndexParams::ivf_flat(
4,
lance_linalg::distance::MetricType::L2,
);
dataset
.create_index(
&["vector"],
IndexType::Vector,
Some("my_vector_index".to_string()),
&params,
false,
)
.await
.unwrap();

// Reload dataset to get the index metadata
let dataset = Dataset::open(test_uri).await.unwrap();
let indices = dataset.load_indices().await.unwrap();
assert_eq!(indices.len(), 1);
let index_meta = &indices[0];

// Test 1: Pass the actual column name
let (field_path, field) =
resolve_index_column(dataset.schema(), index_meta, "vector").unwrap();
assert_eq!(field_path, "vector");
assert_eq!(field.name, "vector");

// Test 2: Pass the index name (should resolve to the actual column)
let (field_path2, field2) =
resolve_index_column(dataset.schema(), index_meta, "my_vector_index").unwrap();
assert_eq!(field_path2, "vector");
assert_eq!(field2.name, "vector");

// Test 3: Pass a non-existent column name (should fail)
let result = resolve_index_column(dataset.schema(), index_meta, "nonexistent");
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("does not exist in the schema"));
}

#[tokio::test]
async fn test_resolve_index_column_error_cases() {
use lance_datagen::{array, BatchCount, RowCount};

// Create a test dataset
let test_dir = tempfile::tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();

let reader = lance_datagen::gen_batch()
.col("id", array::step::<arrow_array::types::Int32Type>())
.col(
"vector",
array::rand_vec::<arrow_array::types::Float32Type>(32.into()),
)
.into_reader_rows(RowCount::from(100), BatchCount::from(1));

let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap();

// Create an index
let params = crate::index::vector::VectorIndexParams::ivf_flat(
4,
lance_linalg::distance::MetricType::L2,
);
dataset
.create_index(
&["vector"],
IndexType::Vector,
Some("my_index".to_string()),
&params,
false,
)
.await
.unwrap();

// Reload dataset
let dataset = Dataset::open(test_uri).await.unwrap();
let indices = dataset.load_indices().await.unwrap();
let index_meta = &indices[0];

// Test: Pass a column that doesn't exist and is not the index name
let result = resolve_index_column(dataset.schema(), index_meta, "nonexistent_column");
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("does not exist in the schema"),
"Error message should mention column doesn't exist, got: {}",
err_msg
);
}

    #[tokio::test]
    async fn test_resolve_index_column_nested_field() {
        use arrow_array::{RecordBatch, StructArray};
        use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};

        // Build the dataset by hand (rather than with lance_datagen) because we
        // need a struct column wrapping the vector field.
        let test_dir = tempfile::tempdir().unwrap();
        let test_uri = test_dir.path().to_str().unwrap();

        // Schema with a nested vector column addressed as "data.vector":
        // id: Int32, data: Struct { vector: FixedSizeList<Float32, 8> }
        let vector_field = ArrowField::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(ArrowField::new("item", DataType::Float32, true)),
                8,
            ),
            false,
        );
        let struct_field = ArrowField::new(
            "data",
            DataType::Struct(vec![vector_field.clone()].into()),
            false,
        );
        let schema = Arc::new(ArrowSchema::new(vec![
            ArrowField::new("id", DataType::Int32, false),
            struct_field,
        ]));

        // Five rows of data.
        let id_array = arrow_array::Int32Array::from(vec![1, 2, 3, 4, 5]);

        // 5 vectors x 8 random float elements, flattened for the list values.
        let mut vector_values = Vec::new();
        for _ in 0..5 {
            for _ in 0..8 {
                vector_values.push(rand::random::<f32>());
            }
        }
        let vector_array = arrow_array::FixedSizeListArray::try_new_from_values(
            arrow_array::Float32Array::from(vector_values),
            8,
        )
        .unwrap();

        // Wrap the vector column in the "data" struct.
        let struct_array = StructArray::from(vec![(
            Arc::new(vector_field),
            Arc::new(vector_array) as arrow_array::ArrayRef,
        )]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(id_array), Arc::new(struct_array)],
        )
        .unwrap();

        let reader = Box::new(arrow_array::RecordBatchIterator::new(
            vec![Ok(batch)],
            schema,
        ));

        let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap();

        // Index the nested field via its dotted path, under a distinct index name.
        let params = crate::index::vector::VectorIndexParams::ivf_flat(
            2,
            lance_linalg::distance::MetricType::L2,
        );
        dataset
            .create_index(
                &["data.vector"],
                IndexType::Vector,
                Some("nested_vector_index".to_string()),
                &params,
                false,
            )
            .await
            .unwrap();

        // Reload dataset to get the committed index metadata
        let dataset = Dataset::open(test_uri).await.unwrap();
        let indices = dataset.load_indices().await.unwrap();
        assert_eq!(indices.len(), 1);
        let index_meta = &indices[0];

        // Test 1: the dotted field path resolves directly; the returned field
        // is the leaf ("vector"), not the struct parent.
        let (field_path, field) =
            resolve_index_column(dataset.schema(), index_meta, "data.vector").unwrap();
        assert_eq!(field_path, "data.vector");
        assert_eq!(field.name, "vector");

        // Test 2: the index name resolves back to the full nested field path.
        let (field_path2, field2) =
            resolve_index_column(dataset.schema(), index_meta, "nested_vector_index").unwrap();
        assert_eq!(field_path2, "data.vector");
        assert_eq!(field2.name, "vector");

        // Sanity: the resolved path is dotted, as required for nested access.
        assert!(
            field_path2.contains('.'),
            "Field path should contain '.' for nested field"
        );
    }
}