Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path):
"metric_type": "l2",
"nbits": 8,
"num_sub_vectors": 8,
"transposed": True,
"transposed": False,
Comment thread
yanghua marked this conversation as resolved.
Outdated
},
"index_file_version": IndexFileVersion.V3,
}
Expand Down
61 changes: 59 additions & 2 deletions rust/lance/src/index/vector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ use super::{
utils::{self, get_vector_type},
};

/// Stably sort a RecordBatch by the ROW_ID column in ascending order.
///
/// If the batch has no ROW_ID column or has fewer than 2 rows, it is
/// returned unchanged. When sorting, the relative order of rows with the
/// same ROW_ID is preserved.
fn stable_sort_batch_by_row_id(batch: &RecordBatch) -> Result<RecordBatch> {
if let Some(row_id_col) = batch.column_by_name(ROW_ID) {
let row_ids = row_id_col.as_primitive::<UInt64Type>();
if row_ids.len() > 1 {
let mut order: Vec<usize> = (0..row_ids.len()).collect();
// Vec::sort_by is stable, so equal ROW_IDs keep their
// original relative order.
order.sort_by(|&i, &j| row_ids.value(i).cmp(&row_ids.value(j)));
let indices = UInt32Array::from_iter_values(order.into_iter().map(|i| i as u32));
return Ok(batch.take(&indices)?);
}
}
Ok(batch.clone())
}

// the number of partitions to evaluate for reassigning
const REASSIGN_RANGE: usize = 64;

Expand Down Expand Up @@ -935,6 +955,15 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
}
_ => {}
}

// Normalize each batch for this partition to be stably sorted by ROW_ID.
for batch in part_batches.iter_mut() {
if batch.num_rows() == 0 {
continue;
}
*batch = stable_sort_batch_by_row_id(batch)?;
}

batches.extend(part_batches);
}

Expand All @@ -958,6 +987,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
.map(|s| s.parse::<f64>().unwrap_or(0.0))
.unwrap_or(0.0);
let batch = batch.drop_column(PART_ID_COLUMN)?;
let batch = stable_sort_batch_by_row_id(&batch)?;
batches.push(batch);
}
}
Expand All @@ -981,6 +1011,8 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
));
};

let is_pq = Q::quantization_type() == QuantizationType::Product;

// prepare the final writers
let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME);
let index_path = self.index_dir.child(INDEX_FILE_NAME);
Expand Down Expand Up @@ -1024,7 +1056,32 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
storage_ivf.add_partition(0);
} else {
let batches = storage.to_batches()?.collect::<Vec<_>>();
let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;
let mut batch =
arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?;

if is_pq && batch.column_by_name(PQ_CODE_COLUMN).is_some() {
// The PQ storage keeps codes in a transposed layout (bytes grouped
// across all rows). Convert them back to per-row layout so that a
// stable ROW_ID sort moves PQ_CODE_COLUMN together with ROW_ID.
let codes_fsl = batch
.column_by_name(PQ_CODE_COLUMN)
.unwrap()
.as_fixed_size_list();
let num_rows = batch.num_rows();
let bytes_per_code = codes_fsl.value_length() as usize;
let codes = codes_fsl.values().as_primitive::<datatypes::UInt8Type>();
let original_codes = transpose(codes, bytes_per_code, num_rows);
let original_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
original_codes,
bytes_per_code as i32,
)?);
batch = batch.replace_column_by_name(PQ_CODE_COLUMN, original_fsl)?;
}

// Enforce a stable ROW_ID ordering for all auxiliary batches so that the
// PQ code column moves together with ROW_ID.
batch = stable_sort_batch_by_row_id(&batch)?;

storage_writer.write_batch(&batch).await?;
storage_ivf.add_partition(batch.num_rows() as u32);
}
Expand Down Expand Up @@ -1071,7 +1128,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
let mut metadata = quantizer.metadata(Some(QuantizationMetadata {
codebook_position: Some(0),
codebook: None,
transposed: true,
transposed: !is_pq,
}));
if let Some(extra_metadata) = metadata.extra_metadata()? {
let idx = storage_writer.add_global_buffer(extra_metadata).await?;
Expand Down
20 changes: 13 additions & 7 deletions rust/lance/src/index/vector/ivf/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,26 @@ pub(super) async fn write_pq_partitions(
location: location!(),
})?;
if let Some(pq_code) = pq_index.code.as_ref() {
let original_pq_codes = transpose(
pq_code,
pq_index.pq.num_sub_vectors,
pq_code.len() / pq_index.pq.code_dim(),
);
let row_ids = pq_index.row_ids.as_ref().unwrap();
let num_vectors = row_ids.len();
if num_vectors == 0 || pq_code.is_empty() {
continue;
}
if pq_code.len() % num_vectors != 0 {
continue;
}
let num_bytes_per_code = pq_code.len() / num_vectors;
let original_pq_codes = transpose(pq_code, num_bytes_per_code, num_vectors);
let fsl = Arc::new(
FixedSizeListArray::try_new_from_values(
original_pq_codes,
pq_index.pq.code_dim() as i32,
num_bytes_per_code as i32,
)
.unwrap(),
);

pq_array.push(fsl);
row_id_array.push(pq_index.row_ids.as_ref().unwrap().clone());
row_id_array.push(row_ids.clone());
}
}
}
Expand Down
Loading