diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index 6e590651313..48c61bcdbe4 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -758,6 +758,7 @@ impl KMeans { id: usize, indices: Vec, centroid: Vec, + finalized: bool, } impl Eq for Cluster {} @@ -770,8 +771,15 @@ impl KMeans { impl Ord for Cluster { fn cmp(&self, other: &Self) -> Ordering { - // Max heap: larger clusters first - self.indices.len().cmp(&other.indices.len()) + // Non-finalized clusters should always have higher priority than finalized ones + match (self.finalized, other.finalized) { + (false, true) => Ordering::Greater, + (true, false) => Ordering::Less, + _ => { + // Max heap: larger clusters first + self.indices.len().cmp(&other.indices.len()) + } + } } } @@ -838,6 +846,7 @@ impl KMeans { id: next_cluster_id, indices: cluster_indices, centroid, + finalized: false, }); next_cluster_id += 1; } @@ -846,17 +855,22 @@ impl KMeans { // Iteratively split largest clusters until we have target_k clusters while heap.len() < target_k { // Get the largest cluster - let largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError( + let mut largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError( "No cluster can be further split".to_string(), ))?; - // Skip if cluster has only 1 point + // If this cluster is already finalized, no further split is possible; stop splitting + if largest_cluster.finalized { + log::warn!("Cluster {} is already finalized, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1); + heap.push(largest_cluster); + break; + } + + // Because the clusters are sorted by size, if the cluster has only 1 point, no further split is possible; stop splitting if largest_cluster.indices.len() <= 1 { + log::warn!("Cluster {} has only 1 point, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1); heap.push(largest_cluster); - if heap.iter().all(|c| c.indices.len() <= 1) { - break; // No more splits possible - } - continue; + break; } let cluster_size = largest_cluster.indices.len(); @@ -881,17 +895,17 @@ impl KMeans { }; // Create sub-dataset for this cluster using indices - let cluster_fsl = Self::create_array_from_indices::( + let sub_data = Self::create_array_from_indices::( &largest_cluster.indices, data_values, dimension, )?; // Run kmeans on this cluster - let sub_kmeans = Self::train_kmeans::(&cluster_fsl, cluster_k, params)?; + let sub_kmeans = Self::train_kmeans::(&sub_data, cluster_k, params)?; // Get membership for points in the sub-cluster - let sub_data = cluster_fsl.values().as_primitive::().values(); + let sub_data = sub_data.values().as_primitive::().values(); let (sub_membership, _, _) = Algo::compute_membership_and_loss( sub_kmeans.centroids.as_primitive::().values(), sub_data, @@ -902,31 +916,65 @@ impl KMeans { None, ); - // Create new sub-clusters and add to heap - let sub_centroids = sub_kmeans.centroids.as_primitive::().values(); - for i in 0..cluster_k { - let mut new_cluster_indices = Vec::new(); - for (local_idx, &sub_cluster_id) in sub_membership.iter().enumerate() { - if let Some(sid) = sub_cluster_id { - if sid as usize == i { - let global_idx = largest_cluster.indices[local_idx]; - new_cluster_indices.push(global_idx); - } + // Build per-cluster membership while checking whether the split is effective + let approx_cluster_capacity = if cluster_k > 0 { + largest_cluster.indices.len().div_ceil(cluster_k) + } else { + 0 + }; + let mut cluster_assignments: Vec> = (0..cluster_k) + .map(|_| Vec::with_capacity(approx_cluster_capacity)) + .collect(); + + let mut first_sid: Option = None; + let mut all_same = true; + for (local_idx, &membership) in sub_membership.iter().enumerate() { + let Some(sub_cluster_id) = membership else { + continue; + }; + + if let Some(first) = first_sid { + if sub_cluster_id != first { + all_same = false; } + } else { + first_sid = Some(sub_cluster_id); + } + + let sub_cluster_id = sub_cluster_id as usize; + if let Some(indices) = cluster_assignments.get_mut(sub_cluster_id) { + indices.push(largest_cluster.indices[local_idx]); + } else { + // Unexpected assignment outside [0, cluster_k); treat as ineffective split. + all_same = false; } + } - if !new_cluster_indices.is_empty() { - let centroid_start = i * dimension; - let centroid_end = centroid_start + dimension; - let centroid = sub_centroids[centroid_start..centroid_end].to_vec(); + // If all memberships are identical, the split is ineffective; finalize the original cluster + if all_same { + largest_cluster.finalized = true; + heap.push(largest_cluster); + continue; + } - heap.push(Cluster { - id: next_cluster_id, - indices: new_cluster_indices, - centroid, - }); - next_cluster_id += 1; + // Create new sub-clusters and add to heap + let sub_centroids = sub_kmeans.centroids.as_primitive::().values(); + for (i, new_cluster_indices) in cluster_assignments.into_iter().enumerate() { + if new_cluster_indices.is_empty() { + continue; } + + let centroid_start = i * dimension; + let centroid_end = centroid_start + dimension; + let centroid = sub_centroids[centroid_start..centroid_end].to_vec(); + + heap.push(Cluster { + id: next_cluster_id, + indices: new_cluster_indices, + centroid, + finalized: false, + }); + next_cluster_id += 1; } log::debug!(