Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions rust/lance-index/src/vector/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ impl KMeans {
id: usize,
indices: Vec<usize>,
centroid: Vec<N>,
finalized: bool,
}

impl<N> Eq for Cluster<N> {}
Expand All @@ -770,8 +771,15 @@ impl KMeans {

impl<N> Ord for Cluster<N> {
fn cmp(&self, other: &Self) -> Ordering {
// Max heap: larger clusters first
self.indices.len().cmp(&other.indices.len())
// Non-finalized clusters should always have higher priority than finalized ones
match (self.finalized, other.finalized) {
(false, true) => Ordering::Greater,
(true, false) => Ordering::Less,
_ => {
// Max heap: larger clusters first
self.indices.len().cmp(&other.indices.len())
}
}
}
}

Expand Down Expand Up @@ -838,6 +846,7 @@ impl KMeans {
id: next_cluster_id,
indices: cluster_indices,
centroid,
finalized: false,
});
next_cluster_id += 1;
}
Expand All @@ -846,17 +855,22 @@ impl KMeans {
// Iteratively split largest clusters until we have target_k clusters
while heap.len() < target_k {
// Get the largest cluster
let largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError(
let mut largest_cluster = heap.pop().ok_or(ArrowError::InvalidArgumentError(
"No cluster can be further split".to_string(),
))?;

// Skip if cluster has only 1 point
// If this cluster is already finalized, no further split is possible; stop splitting
if largest_cluster.finalized {
log::warn!("Cluster {} is already finalized, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1);
heap.push(largest_cluster);
break;
}

// Because the clusters are sorted by size, if the cluster has only 1 point, no further split is possible; stop splitting
if largest_cluster.indices.len() <= 1 {
log::warn!("Cluster {} has only 1 point, no further split is possible, finish with {} clusters", largest_cluster.id, heap.len()+ 1);
heap.push(largest_cluster);
if heap.iter().all(|c| c.indices.len() <= 1) {
break; // No more splits possible
}
continue;
break;
}

let cluster_size = largest_cluster.indices.len();
Expand All @@ -881,17 +895,17 @@ impl KMeans {
};

// Create sub-dataset for this cluster using indices
let cluster_fsl = Self::create_array_from_indices::<T>(
let sub_data = Self::create_array_from_indices::<T>(
&largest_cluster.indices,
data_values,
dimension,
)?;

// Run kmeans on this cluster
let sub_kmeans = Self::train_kmeans::<T, Algo>(&cluster_fsl, cluster_k, params)?;
let sub_kmeans = Self::train_kmeans::<T, Algo>(&sub_data, cluster_k, params)?;

// Get membership for points in the sub-cluster
let sub_data = cluster_fsl.values().as_primitive::<T>().values();
let sub_data = sub_data.values().as_primitive::<T>().values();
let (sub_membership, _, _) = Algo::compute_membership_and_loss(
sub_kmeans.centroids.as_primitive::<T>().values(),
sub_data,
Expand All @@ -902,31 +916,65 @@ impl KMeans {
None,
);

// Create new sub-clusters and add to heap
let sub_centroids = sub_kmeans.centroids.as_primitive::<T>().values();
for i in 0..cluster_k {
let mut new_cluster_indices = Vec::new();
for (local_idx, &sub_cluster_id) in sub_membership.iter().enumerate() {
if let Some(sid) = sub_cluster_id {
if sid as usize == i {
let global_idx = largest_cluster.indices[local_idx];
new_cluster_indices.push(global_idx);
}
// Build per-cluster membership while checking whether the split is effective
let approx_cluster_capacity = if cluster_k > 0 {
largest_cluster.indices.len().div_ceil(cluster_k)
} else {
0
};
let mut cluster_assignments: Vec<Vec<usize>> = (0..cluster_k)
.map(|_| Vec::with_capacity(approx_cluster_capacity))
.collect();

let mut first_sid: Option<u32> = None;
let mut all_same = true;
for (local_idx, &membership) in sub_membership.iter().enumerate() {
let Some(sub_cluster_id) = membership else {
continue;
};

if let Some(first) = first_sid {
if sub_cluster_id != first {
all_same = false;
}
} else {
first_sid = Some(sub_cluster_id);
}

let sub_cluster_id = sub_cluster_id as usize;
if let Some(indices) = cluster_assignments.get_mut(sub_cluster_id) {
indices.push(largest_cluster.indices[local_idx]);
} else {
// Unexpected assignment outside [0, cluster_k); treat as ineffective split.
all_same = false;
}
}

if !new_cluster_indices.is_empty() {
let centroid_start = i * dimension;
let centroid_end = centroid_start + dimension;
let centroid = sub_centroids[centroid_start..centroid_end].to_vec();
// If all memberships are identical, the split is ineffective; finalize the original cluster
if all_same {
largest_cluster.finalized = true;
heap.push(largest_cluster);
continue;
}

heap.push(Cluster {
id: next_cluster_id,
indices: new_cluster_indices,
centroid,
});
next_cluster_id += 1;
// Create new sub-clusters and add to heap
let sub_centroids = sub_kmeans.centroids.as_primitive::<T>().values();
for (i, new_cluster_indices) in cluster_assignments.into_iter().enumerate() {
if new_cluster_indices.is_empty() {
continue;
}

let centroid_start = i * dimension;
let centroid_end = centroid_start + dimension;
let centroid = sub_centroids[centroid_start..centroid_end].to_vec();

heap.push(Cluster {
id: next_cluster_id,
indices: new_cluster_indices,
centroid,
finalized: false,
});
next_cluster_id += 1;
}

log::debug!(
Expand Down
Loading