Skip to content

Commit

Permalink
refactor interface for hierarchical clustering and sampling (#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaushian authored Dec 17, 2022
1 parent 23ecdd1 commit 2ad6828
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 96 deletions.
11 changes: 5 additions & 6 deletions pecos/core/ann/quantizer_impl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,13 @@ namespace ann {

// fit HLT or flat-Kmeans for each sub-space
std::vector<index_type> assignments(sub_sample_points);
pecos::clustering::Tree hlt(std::log2(num_of_local_centroids));
int hlt_depth = std::log2(num_of_local_centroids);
pecos::clustering::Tree hlt(hlt_depth);
pecos::clustering::ClusteringParam clustering_param(0, hlt_depth, seed, max_iter, threads, 1.0, 1.0, 1.0);
hlt.run_clustering<pecos::drm_t, index_type>(
Xsub,
0,
seed,
assignments.data(),
max_iter,
threads);
&clustering_param,
assignments.data());

compute_centroids(Xsub, local_dimension, num_of_local_centroids, assignments.data(),
&original_local_codebooks[m * num_of_local_centroids * local_dimension], threads);
Expand Down
47 changes: 27 additions & 20 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ def __init__(self, base_vect_param_list, norm_p):
self.norm_p = c_int32(norm_p)


class ClusteringParam(ctypes.Structure):
"""
python class for handling struct ClusteringParam in clustering.hpp
"""

_fields_ = [
("partition_algo", c_uint64),
("depth", c_uint64),
("seed", c_int),
("kmeans_max_iter", c_uint64),
("threads", c_int),
("max_sample_rate", c_float),
("min_sample_rate", c_float),
("warmup_ratio", c_float),
]

def __init__(self, params):
name2type = dict(ClusteringParam._fields_)
for name in name2type:
setattr(self, name, name2type[name](getattr(params, name)))


class ScipyCscF32(ctypes.Structure):
"""
PyMatrix for scipy.sparse.csc_matrix
Expand Down Expand Up @@ -1250,11 +1272,7 @@ def link_clustering(self):
"""
arg_list = [
POINTER(ScipyCsrF32),
c_uint32,
c_uint32,
c_int,
c_uint32,
c_int,
POINTER(ClusteringParam),
POINTER(c_uint32),
]
corelib.fillprototype(
Expand All @@ -1267,24 +1285,16 @@ def link_clustering(self):
def run_clustering(
self,
py_feat_mat,
depth,
algo,
seed,
train_params,
codes=None,
kmeans_max_iter=20,
threads=-1,
):
"""
Run clustering with given label embedding matrix and parameters in C++.
Args:
py_feat_mat (ScipyCsrF32, ScipyDrmF32): label embedding matrix. (num_labels x num_features).
depth (int): Depth of K-means clustering N-nary tree.
algo (str): The algorithm for clustering, either `KMEANS` or `SKMEANS`.
seed (int): Randoms seed.
train_params (HierarchicalKMeans.TrainParams): Parameter class defined in pecos.xmc.base.HierarchicalKMeans.TrainParams.
codes (ndarray, optional): Label clustering results.
kmeans_max_iter (int, optional): Maximum number of iter for reordering each node based on score.
threads (int, optional): The number of threads. Default -1 to use all cores.
Return:
codes (ndarray): The clustering result.
Expand All @@ -1302,13 +1312,10 @@ def run_clustering(

if codes is None or len(codes) != py_feat_mat.shape[0] or codes.dtype != np.uint32:
codes = np.zeros(py_feat_mat.rows, dtype=np.uint32)
train_params = ClusteringParam(train_params)
run_clustering(
byref(py_feat_mat),
depth,
algo,
seed,
kmeans_max_iter,
threads,
train_params,
codes.ctypes.data_as(POINTER(c_uint32)),
)
return codes
Expand Down
10 changes: 3 additions & 7 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,11 @@ extern "C" {
#define C_RUN_CLUSTERING(SUFFIX, PY_MAT, C_MAT) \
void c_run_clustering ## SUFFIX( \
const PY_MAT* py_mat_ptr, \
uint32_t depth, \
uint32_t partition_algo, \
int seed, \
uint32_t max_iter, \
int threads, \
pecos::clustering::ClusteringParam* param_ptr, \
uint32_t* label_codes) { \
C_MAT feat_mat(py_mat_ptr); \
pecos::clustering::Tree tree(depth); \
tree.run_clustering(feat_mat, partition_algo, seed, label_codes, max_iter, threads); \
pecos::clustering::Tree tree(param_ptr->depth); \
tree.run_clustering(feat_mat, param_ptr, label_codes); \
}
C_RUN_CLUSTERING(_csr_f32, ScipyCsrF32, pecos::csr_t)
C_RUN_CLUSTERING(_drm_f32, ScipyDrmF32, pecos::drm_t)
Expand Down
111 changes: 55 additions & 56 deletions pecos/core/utils/clustering.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,45 @@ enum {
SKMEANS=5,
}; /* partition_algo */

enum {
CONSTANT_SAMPLE_SCHEDULE=0,
LINEAR_SAMPLE_SCHEDULE=1,
}; /* sample strategies */

extern "C" {
struct ClusteringSamplerParam {
int strategy;
float sample_rate;
float warmup_sample_rate;
float warmup_layer_rate;

ClusteringSamplerParam(
int strategy,
float sample_rate,
float warmup_sample_rate,
float warmup_layer_rate
): strategy(strategy),
sample_rate(sample_rate),
warmup_sample_rate(warmup_sample_rate),
warmup_layer_rate(warmup_layer_rate) {
if(sample_rate <= 0 || sample_rate > 1) {
throw std::invalid_argument("expect 0 < sample_rate <= 1.0");
struct ClusteringParam {
size_t partition_algo;
size_t depth;
int seed;
size_t kmeans_max_iter;
int threads;
float max_sample_rate;
float min_sample_rate;
float warmup_ratio;

ClusteringParam(
size_t partition_algo,
size_t depth,
int seed,
size_t kmeans_max_iter,
int threads,
float max_sample_rate,
float min_sample_rate,
float warmup_ratio
): partition_algo(partition_algo),
depth(depth),
seed(seed),
kmeans_max_iter(kmeans_max_iter),
threads(threads),
max_sample_rate(max_sample_rate),
min_sample_rate(min_sample_rate),
warmup_ratio(warmup_ratio) {
if(min_sample_rate <= 0 || min_sample_rate > 1) {
throw std::invalid_argument("expect 0 < min_sample_rate <= 1.0");
}
if(warmup_sample_rate <= 0 || warmup_sample_rate > 1) {
throw std::invalid_argument("expect 0 < warmup_sample_rate <= 1.0");
if(max_sample_rate <= 0 || max_sample_rate > 1) {
throw std::invalid_argument("expect 0 < max_sample_rate <= 1.0");
}
if(warmup_layer_rate < 0 || warmup_layer_rate > 1) {
throw std::invalid_argument("expect 0 <= warmup_layer_rate <= 1.0");
if(min_sample_rate > max_sample_rate) {
throw std::invalid_argument("expect min_sample_rate <= max_sample_rate");
}
if(warmup_ratio < 0 || warmup_ratio > 1) {
throw std::invalid_argument("expect 0 <= warmup_ratio <= 1.0");
}
}
};
Expand Down Expand Up @@ -113,28 +123,20 @@ struct Tree {

struct ClusteringSampler {
// scheduler for sampling
ClusteringSamplerParam* param_ptr;
ClusteringParam* param_ptr;
size_t warmup_layers;
size_t depth;

ClusteringSampler(ClusteringSamplerParam* param_ptr, size_t depth): param_ptr(param_ptr), depth(depth) {
warmup_layers = size_t(depth * param_ptr->warmup_layer_rate);
ClusteringSampler(ClusteringParam* param_ptr): param_ptr(param_ptr) {
warmup_layers = size_t(param_ptr->depth * param_ptr->warmup_ratio);
}

float get_sample_rate(size_t layer) const {
if(param_ptr->strategy == LINEAR_SAMPLE_SCHEDULE) {
return _get_linear_sample_rate(layer);
}
return param_ptr->sample_rate; // Constant strategy
}

float _get_linear_sample_rate(size_t layer) const {
// If input `layer` < `warmup_layers`, return `warmup_sample_rate`.
// Otherwise, linearly increase the current sample rate from `warmup_sample_rate` to `sample_rate` until the last layer.
// If input `layer` < `warmup_layers`, return `min_sample_rate`.
// Otherwise, linearly increase the current sample rate from `min_sample_rate` to `max_sample_rate` until the last layer.
if(layer < warmup_layers) {
return param_ptr->warmup_sample_rate;
return param_ptr->min_sample_rate;
}
return param_ptr->warmup_sample_rate + (param_ptr->sample_rate - param_ptr->warmup_sample_rate) * (layer + 1 - warmup_layers) / (depth - warmup_layers);
return param_ptr->min_sample_rate + (param_ptr->max_sample_rate - param_ptr->min_sample_rate) * (layer + 1 - warmup_layers) / (param_ptr->depth - warmup_layers);
}
};

Expand Down Expand Up @@ -360,19 +362,19 @@ struct Tree {
}

template<typename MAT, typename IND=unsigned>
void run_clustering(const MAT& feat_mat, int partition_algo, int seed=0, IND *label_codes=NULL, size_t max_iter=10, int threads=1, ClusteringSamplerParam* sample_param_ptr=NULL) {
void run_clustering(const MAT& feat_mat, ClusteringParam* param_ptr=NULL, IND *label_codes=NULL) {
size_t nr_elements = feat_mat.rows;
elements.resize(nr_elements);
previous_elements.resize(nr_elements);
for(size_t i = 0; i < nr_elements; i++) {
elements[i] = i;
}
rng_t rng(seed);
rng_t rng(param_ptr->seed);
for(size_t nid = 0; nid < nodes.size(); nid++) {
seed_for_nodes[nid] = rng.randint<unsigned>();
}

threads = set_threads(threads);
int threads = set_threads(param_ptr->threads);
center1.resize(threads, f32_sdvec_t(feat_mat.cols));
center2.resize(threads, f32_sdvec_t(feat_mat.cols));
scores.resize(feat_mat.rows, 0);
Expand All @@ -382,10 +384,7 @@ struct Tree {
// Allocate tmp arrays for parallel update center
center_tmp_thread.resize(threads, f32_sdvec_t(feat_mat.cols));

if(sample_param_ptr == NULL) {
sample_param_ptr = new ClusteringSamplerParam(CONSTANT_SAMPLE_SCHEDULE, 1.0, 1.0, 1.0); // no sampling for default constructor
}
ClusteringSampler sample_scheduler(sample_param_ptr, depth);
ClusteringSampler sample_scheduler(param_ptr);

// let's do it layer by layer so we can parallelize it
for(size_t d = 0; d < depth; d++) {
Expand All @@ -398,21 +397,21 @@ struct Tree {
rng_t rng(seed_for_nodes[nid]);
int local_threads = 1;
int thread_id = omp_get_thread_num();
if(partition_algo == KMEANS) {
partition_kmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate);
} else if(partition_algo == SKMEANS) {
partition_skmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate);
if(param_ptr->partition_algo == KMEANS) {
partition_kmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate);
} else if(param_ptr->partition_algo == SKMEANS) {
partition_skmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate);
}
}
} else {
for(size_t nid = layer_start; nid < layer_end; nid++) {
rng_t rng(seed_for_nodes[nid]);
int local_threads = threads;
int thread_id = 0;
if(partition_algo == KMEANS) {
partition_kmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate);
} else if(partition_algo == SKMEANS) {
partition_skmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate);
if(param_ptr->partition_algo == KMEANS) {
partition_kmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate);
} else if(param_ptr->partition_algo == SKMEANS) {
partition_skmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate);
}
}
}
Expand Down
31 changes: 24 additions & 7 deletions pecos/xmc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ class TrainParams(pecos.BaseParams): # type: ignore
seed (int, optional): Random seed. Default is `0`.
kmeans_max_iter (int, optional): Maximum number of iterations for each k-means problem. Default is `20`.
threads (int, optional): Number of threads to use. `-1` denotes all CPUs. Default is `-1`.
do_sample (bool, optional): Do sampling if is True. Default is False.
We use linear sampling strategy with warmup, which linearly increases sampling rate from `min_sample_rate` to `max_sample_rate`.
The top (total_layer * `warmup_ratio`) layers are warmup_layers which use a fixed sampling rate `min_sample_rate`.
The sampling rate for layer l is `min_sample_rate`+max(l+1-warmup_layer,0)*(`max_sample_rate`-min_sample_rate)/(total_layers-warmup_layers).
max_sample_rate (float, optional): the maximum samplng rate at the end of the linear sampling strategy. Default is `1.0`.
min_sample_rate (float, optional): the minimum sampling rate at the begining warmup stage of the linear sampling strategy. Default is `0.1`.
Note that 0 < min_sample_rate <= max_sample_rate <= 1.0.
warmup_ratio: (float, optional): The ratio of warmup layers. 0 <= warmup_ratio <= 1.0. Default is 0.4.
"""

nr_splits: int = 16
Expand All @@ -103,6 +111,12 @@ class TrainParams(pecos.BaseParams): # type: ignore
kmeans_max_iter: int = 20
threads: int = -1

# paramters for sampling of hierarchical clustering
do_sample: bool = False
max_sample_rate: float = 1.0
min_sample_rate: float = 0.1
warmup_ratio: float = 0.4

@classmethod
def gen(
cls,
Expand Down Expand Up @@ -130,6 +144,11 @@ def gen(
if train_params.min_codes is None:
train_params.min_codes = train_params.nr_splits

if not train_params.do_sample:
# set the min_sample_rate to be 1.0 so it doesn't do sampling
train_params.warmup_ratio = 1.0
train_params.min_sample_rate = 1.0

LOGGER.debug(
f"HierarchicalKMeans train_params: {json.dumps(train_params.to_dict(), indent=True)}"
)
Expand All @@ -146,8 +165,10 @@ def gen(
raise ValueError(
f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {nr_instances} to avoid empty clusters"
)
train_params.depth = depth

algo = cls.SKMEANS if train_params.spherical else cls.KMEANS
partition_algo = cls.SKMEANS if train_params.spherical else cls.KMEANS
train_params.partition_algo = partition_algo

assert feat_mat.dtype == np.float32
if isinstance(feat_mat, (smat.csr_matrix, ScipyCsrF32)):
Expand All @@ -162,12 +183,8 @@ def gen(
codes = np.zeros(py_feat_mat.rows, dtype=np.uint32)
codes = clib.run_clustering(
py_feat_mat,
depth,
algo,
train_params.seed,
codes=codes,
kmeans_max_iter=train_params.kmeans_max_iter,
threads=train_params.threads,
train_params,
codes,
)
C = cls.convert_codes_to_csc_matrix(codes, depth)
cluster_chain = ClusterChain.from_partial_chain(
Expand Down
27 changes: 27 additions & 0 deletions test/pecos/xmc/test_xmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,33 @@ def test_hierarchicalkmeans():
assert (chain2.chain[1].dot(chain2.chain[0]) - chain4.chain[0]).nnz == 0


def test_hierarchicalkmeans_sampling():
import numpy as np
import scipy.sparse as smat
from sklearn.preprocessing import normalize
from pecos.xmc import Indexer
from pecos.xmc.base import HierarchicalKMeans

# randomly sampling arbitary number of examples from the 4 following instances results in the same clustering results.
feat_mat = normalize(
smat.csr_matrix([[1, 0], [0.99, 0.02], [0.01, 1.03], [0, 1]], dtype=np.float32)
)
target_balanced = [0, 0, 1, 1]

# the clustering results are the same as long as min_sample_rate >= 0.25
train_params = HierarchicalKMeans.TrainParams(
do_sample=True,
min_sample_rate=0.75,
warmup_ratio=1.0,
max_leaf_size=2,
)
balanced_chain = Indexer.gen(feat_mat, train_params=train_params)
balanced_assignments = (balanced_chain[-1].todense() == [0, 1]).all(axis=1).A1
assert np.array_equal(balanced_assignments, target_balanced) or np.array_equal(
~balanced_assignments, target_balanced
)


def test_label_embedding():
import random
import numpy as np
Expand Down

0 comments on commit 2ad6828

Please sign in to comment.