diff --git a/pecos/core/ann/quantizer_impl/common.hpp b/pecos/core/ann/quantizer_impl/common.hpp index c319bd3..0d79043 100644 --- a/pecos/core/ann/quantizer_impl/common.hpp +++ b/pecos/core/ann/quantizer_impl/common.hpp @@ -221,14 +221,13 @@ namespace ann { // fit HLT or flat-Kmeans for each sub-space std::vector assignments(sub_sample_points); - pecos::clustering::Tree hlt(std::log2(num_of_local_centroids)); + int hlt_depth = std::log2(num_of_local_centroids); + pecos::clustering::Tree hlt(hlt_depth); + pecos::clustering::ClusteringParam clustering_param(0, hlt_depth, seed, max_iter, threads, 1.0, 1.0, 1.0); hlt.run_clustering( Xsub, - 0, - seed, - assignments.data(), - max_iter, - threads); + &clustering_param, + assignments.data()); compute_centroids(Xsub, local_dimension, num_of_local_centroids, assignments.data(), &original_local_codebooks[m * num_of_local_centroids * local_dimension], threads); diff --git a/pecos/core/base.py b/pecos/core/base.py index e7d0e0f..00ab0a7 100644 --- a/pecos/core/base.py +++ b/pecos/core/base.py @@ -148,6 +148,28 @@ def __init__(self, base_vect_param_list, norm_p): self.norm_p = c_int32(norm_p) +class ClusteringParam(ctypes.Structure): + """ + python class for handling struct ClusteringParam in clustering.hpp + """ + + _fields_ = [ + ("partition_algo", c_uint64), + ("depth", c_uint64), + ("seed", c_int), + ("kmeans_max_iter", c_uint64), + ("threads", c_int), + ("max_sample_rate", c_float), + ("min_sample_rate", c_float), + ("warmup_ratio", c_float), + ] + + def __init__(self, params): + name2type = dict(ClusteringParam._fields_) + for name in name2type: + setattr(self, name, name2type[name](getattr(params, name))) + + class ScipyCscF32(ctypes.Structure): """ PyMatrix for scipy.sparse.csc_matrix @@ -1250,11 +1272,7 @@ def link_clustering(self): """ arg_list = [ POINTER(ScipyCsrF32), - c_uint32, - c_uint32, - c_int, - c_uint32, - c_int, + POINTER(ClusteringParam), POINTER(c_uint32), ] corelib.fillprototype( @@ -1267,24 +1285,16 @@ def link_clustering(self): def run_clustering( self, py_feat_mat, - depth, - algo, - seed, + train_params, codes=None, - kmeans_max_iter=20, - threads=-1, ): """ Run clustering with given label embedding matrix and parameters in C++. Args: py_feat_mat (ScipyCsrF32, ScipyDrmF32): label embedding matrix. (num_labels x num_features). - depth (int): Depth of K-means clustering N-nary tree. - algo (str): The algorithm for clustering, either `KMEANS` or `SKMEANS`. - seed (int): Randoms seed. + train_params (HierarchicalKMeans.TrainParams): Parameter class defined in pecos.xmc.base.HierarchicalKMeans.TrainParams. codes (ndarray, optional): Label clustering results. - kmeans_max_iter (int, optional): Maximum number of iter for reordering each node based on score. - threads (int, optional): The number of threads. Default -1 to use all cores. Return: codes (ndarray): The clustering result. @@ -1302,13 +1312,10 @@ def run_clustering( if codes is None or len(codes) != py_feat_mat.shape[0] or codes.dtype != np.uint32: codes = np.zeros(py_feat_mat.rows, dtype=np.uint32) + train_params = ClusteringParam(train_params) run_clustering( byref(py_feat_mat), - depth, - algo, - seed, - kmeans_max_iter, - threads, + train_params, codes.ctypes.data_as(POINTER(c_uint32)), ) return codes diff --git a/pecos/core/libpecos.cpp b/pecos/core/libpecos.cpp index 912d96e..071caaa 100644 --- a/pecos/core/libpecos.cpp +++ b/pecos/core/libpecos.cpp @@ -257,15 +257,11 @@ extern "C" { #define C_RUN_CLUSTERING(SUFFIX, PY_MAT, C_MAT) \ void c_run_clustering ## SUFFIX( \ const PY_MAT* py_mat_ptr, \ - uint32_t depth, \ - uint32_t partition_algo, \ - int seed, \ - uint32_t max_iter, \ - int threads, \ + pecos::clustering::ClusteringParam* param_ptr, \ uint32_t* label_codes) { \ C_MAT feat_mat(py_mat_ptr); \ - pecos::clustering::Tree tree(depth); \ - tree.run_clustering(feat_mat, partition_algo, seed, label_codes, max_iter, threads); \ + pecos::clustering::Tree tree(param_ptr->depth); \ + tree.run_clustering(feat_mat, param_ptr, label_codes); \ } C_RUN_CLUSTERING(_csr_f32, ScipyCsrF32, pecos::csr_t) C_RUN_CLUSTERING(_drm_f32, ScipyDrmF32, pecos::drm_t) diff --git a/pecos/core/utils/clustering.hpp b/pecos/core/utils/clustering.hpp index 04777d3..fa7061a 100644 --- a/pecos/core/utils/clustering.hpp +++ b/pecos/core/utils/clustering.hpp @@ -31,35 +31,45 @@ enum { SKMEANS=5, }; /* partition_algo */ -enum { - CONSTANT_SAMPLE_SCHEDULE=0, - LINEAR_SAMPLE_SCHEDULE=1, -}; /* sample strategies */ - extern "C" { - struct ClusteringSamplerParam { - int strategy; - float sample_rate; - float warmup_sample_rate; - float warmup_layer_rate; - - ClusteringSamplerParam( - int strategy, - float sample_rate, - float warmup_sample_rate, - float warmup_layer_rate - ): strategy(strategy), - sample_rate(sample_rate), - warmup_sample_rate(warmup_sample_rate), - warmup_layer_rate(warmup_layer_rate) { - if(sample_rate <= 0 || sample_rate > 1) { - throw std::invalid_argument("expect 0 < sample_rate <= 1.0"); + struct ClusteringParam { + size_t partition_algo; + size_t depth; + int seed; + size_t kmeans_max_iter; + int threads; + float max_sample_rate; + float min_sample_rate; + float warmup_ratio; + + ClusteringParam( + size_t partition_algo, + size_t depth, + int seed, + size_t kmeans_max_iter, + int threads, + float max_sample_rate, + float min_sample_rate, + float warmup_ratio + ): partition_algo(partition_algo), + depth(depth), + seed(seed), + kmeans_max_iter(kmeans_max_iter), + threads(threads), + max_sample_rate(max_sample_rate), + min_sample_rate(min_sample_rate), + warmup_ratio(warmup_ratio) { + if(min_sample_rate <= 0 || min_sample_rate > 1) { + throw std::invalid_argument("expect 0 < min_sample_rate <= 1.0"); } - if(warmup_sample_rate <= 0 || warmup_sample_rate > 1) { - throw std::invalid_argument("expect 0 < warmup_sample_rate <= 1.0"); + if(max_sample_rate <= 0 || max_sample_rate > 1) { + throw std::invalid_argument("expect 0 < max_sample_rate <= 1.0"); } - if(warmup_layer_rate < 0 || warmup_layer_rate > 1) { - throw std::invalid_argument("expect 0 <= warmup_layer_rate <= 1.0"); + if(min_sample_rate > max_sample_rate) { + throw std::invalid_argument("expect min_sample_rate <= max_sample_rate"); + } + if(warmup_ratio < 0 || warmup_ratio > 1) { + throw std::invalid_argument("expect 0 <= warmup_ratio <= 1.0"); } } }; @@ -113,28 +123,20 @@ struct Tree { struct ClusteringSampler { // scheduler for sampling - ClusteringSamplerParam* param_ptr; + ClusteringParam* param_ptr; size_t warmup_layers; - size_t depth; - ClusteringSampler(ClusteringSamplerParam* param_ptr, size_t depth): param_ptr(param_ptr), depth(depth) { - warmup_layers = size_t(depth * param_ptr->warmup_layer_rate); + ClusteringSampler(ClusteringParam* param_ptr): param_ptr(param_ptr) { + warmup_layers = size_t(param_ptr->depth * param_ptr->warmup_ratio); } float get_sample_rate(size_t layer) const { - if(param_ptr->strategy == LINEAR_SAMPLE_SCHEDULE) { - return _get_linear_sample_rate(layer); - } - return param_ptr->sample_rate; // Constant strategy - } - - float _get_linear_sample_rate(size_t layer) const { - // If input `layer` < `warmup_layers`, return `warmup_sample_rate`. - // Otherwise, linearly increase the current sample rate from `warmup_sample_rate` to `sample_rate` until the last layer. + // If input `layer` < `warmup_layers`, return `min_sample_rate`. + // Otherwise, linearly increase the current sample rate from `min_sample_rate` to `max_sample_rate` until the last layer. if(layer < warmup_layers) { - return param_ptr->warmup_sample_rate; + return param_ptr->min_sample_rate; } - return param_ptr->warmup_sample_rate + (param_ptr->sample_rate - param_ptr->warmup_sample_rate) * (layer + 1 - warmup_layers) / (depth - warmup_layers); + return param_ptr->min_sample_rate + (param_ptr->max_sample_rate - param_ptr->min_sample_rate) * (layer + 1 - warmup_layers) / (param_ptr->depth - warmup_layers); } }; @@ -360,19 +362,19 @@ struct Tree { } template - void run_clustering(const MAT& feat_mat, int partition_algo, int seed=0, IND *label_codes=NULL, size_t max_iter=10, int threads=1, ClusteringSamplerParam* sample_param_ptr=NULL) { + void run_clustering(const MAT& feat_mat, ClusteringParam* param_ptr=NULL, IND *label_codes=NULL) { size_t nr_elements = feat_mat.rows; elements.resize(nr_elements); previous_elements.resize(nr_elements); for(size_t i = 0; i < nr_elements; i++) { elements[i] = i; } - rng_t rng(seed); + rng_t rng(param_ptr->seed); for(size_t nid = 0; nid < nodes.size(); nid++) { seed_for_nodes[nid] = rng.randint(); } - threads = set_threads(threads); + int threads = set_threads(param_ptr->threads); center1.resize(threads, f32_sdvec_t(feat_mat.cols)); center2.resize(threads, f32_sdvec_t(feat_mat.cols)); scores.resize(feat_mat.rows, 0); @@ -382,10 +384,7 @@ struct Tree { // Allocate tmp arrays for parallel update center center_tmp_thread.resize(threads, f32_sdvec_t(feat_mat.cols)); - if(sample_param_ptr == NULL) { - sample_param_ptr = new ClusteringSamplerParam(CONSTANT_SAMPLE_SCHEDULE, 1.0, 1.0, 1.0); // no sampling for default constructor - } - ClusteringSampler sample_scheduler(sample_param_ptr, depth); + ClusteringSampler sample_scheduler(param_ptr); // let's do it layer by layer so we can parallelize it for(size_t d = 0; d < depth; d++) { @@ -398,10 +397,10 @@ struct Tree { rng_t rng(seed_for_nodes[nid]); int local_threads = 1; int thread_id = omp_get_thread_num(); - if(partition_algo == KMEANS) { - partition_kmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate); - } else if(partition_algo == SKMEANS) { - partition_skmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate); + if(param_ptr->partition_algo == KMEANS) { + partition_kmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate); + } else if(param_ptr->partition_algo == SKMEANS) { + partition_skmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate); } } } else { @@ -409,10 +408,10 @@ struct Tree { rng_t rng(seed_for_nodes[nid]); int local_threads = threads; int thread_id = 0; - if(partition_algo == KMEANS) { - partition_kmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate); - } else if(partition_algo == SKMEANS) { - partition_skmeans(nid, d, feat_mat, rng, max_iter, local_threads, thread_id, cur_sample_rate); + if(param_ptr->partition_algo == KMEANS) { + partition_kmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate); + } else if(param_ptr->partition_algo == SKMEANS) { + partition_skmeans(nid, d, feat_mat, rng, param_ptr->kmeans_max_iter, local_threads, thread_id, cur_sample_rate); } } } diff --git a/pecos/xmc/base.py b/pecos/xmc/base.py index edad390..49b6a8e 100644 --- a/pecos/xmc/base.py +++ b/pecos/xmc/base.py @@ -93,6 +93,14 @@ class TrainParams(pecos.BaseParams): # type: ignore seed (int, optional): Random seed. Default is `0`. kmeans_max_iter (int, optional): Maximum number of iterations for each k-means problem. Default is `20`. threads (int, optional): Number of threads to use. `-1` denotes all CPUs. Default is `-1`. + do_sample (bool, optional): Do sampling if is True. Default is False. + We use linear sampling strategy with warmup, which linearly increases sampling rate from `min_sample_rate` to `max_sample_rate`. + The top (total_layer * `warmup_ratio`) layers are warmup_layers which use a fixed sampling rate `min_sample_rate`. + The sampling rate for layer l is `min_sample_rate`+max(l+1-warmup_layer,0)*(`max_sample_rate`-min_sample_rate)/(total_layers-warmup_layers). + max_sample_rate (float, optional): the maximum samplng rate at the end of the linear sampling strategy. Default is `1.0`. + min_sample_rate (float, optional): the minimum sampling rate at the begining warmup stage of the linear sampling strategy. Default is `0.1`. + Note that 0 < min_sample_rate <= max_sample_rate <= 1.0. + warmup_ratio: (float, optional): The ratio of warmup layers. 0 <= warmup_ratio <= 1.0. Default is 0.4. """ nr_splits: int = 16 @@ -103,6 +111,12 @@ class TrainParams(pecos.BaseParams): # type: ignore kmeans_max_iter: int = 20 threads: int = -1 + # paramters for sampling of hierarchical clustering + do_sample: bool = False + max_sample_rate: float = 1.0 + min_sample_rate: float = 0.1 + warmup_ratio: float = 0.4 + @classmethod def gen( cls, @@ -130,6 +144,11 @@ def gen( if train_params.min_codes is None: train_params.min_codes = train_params.nr_splits + if not train_params.do_sample: + # set the min_sample_rate to be 1.0 so it doesn't do sampling + train_params.warmup_ratio = 1.0 + train_params.min_sample_rate = 1.0 + LOGGER.debug( f"HierarchicalKMeans train_params: {json.dumps(train_params.to_dict(), indent=True)}" ) @@ -146,8 +165,10 @@ def gen( raise ValueError( f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {nr_instances} to avoid empty clusters" ) + train_params.depth = depth - algo = cls.SKMEANS if train_params.spherical else cls.KMEANS + partition_algo = cls.SKMEANS if train_params.spherical else cls.KMEANS + train_params.partition_algo = partition_algo assert feat_mat.dtype == np.float32 if isinstance(feat_mat, (smat.csr_matrix, ScipyCsrF32)): @@ -162,12 +183,8 @@ def gen( codes = np.zeros(py_feat_mat.rows, dtype=np.uint32) codes = clib.run_clustering( py_feat_mat, - depth, - algo, - train_params.seed, - codes=codes, - kmeans_max_iter=train_params.kmeans_max_iter, - threads=train_params.threads, + train_params, + codes, ) C = cls.convert_codes_to_csc_matrix(codes, depth) cluster_chain = ClusterChain.from_partial_chain( diff --git a/test/pecos/xmc/test_xmc.py b/test/pecos/xmc/test_xmc.py index 151596c..fed3d00 100644 --- a/test/pecos/xmc/test_xmc.py +++ b/test/pecos/xmc/test_xmc.py @@ -45,6 +45,33 @@ def test_hierarchicalkmeans(): assert (chain2.chain[1].dot(chain2.chain[0]) - chain4.chain[0]).nnz == 0 +def test_hierarchicalkmeans_sampling(): + import numpy as np + import scipy.sparse as smat + from sklearn.preprocessing import normalize + from pecos.xmc import Indexer + from pecos.xmc.base import HierarchicalKMeans + + # randomly sampling arbitary number of examples from the 4 following instances results in the same clustering results. + feat_mat = normalize( + smat.csr_matrix([[1, 0], [0.99, 0.02], [0.01, 1.03], [0, 1]], dtype=np.float32) + ) + target_balanced = [0, 0, 1, 1] + + # the clustering results are the same as long as min_sample_rate >= 0.25 + train_params = HierarchicalKMeans.TrainParams( + do_sample=True, + min_sample_rate=0.75, + warmup_ratio=1.0, + max_leaf_size=2, + ) + balanced_chain = Indexer.gen(feat_mat, train_params=train_params) + balanced_assignments = (balanced_chain[-1].todense() == [0, 1]).all(axis=1).A1 + assert np.array_equal(balanced_assignments, target_balanced) or np.array_equal( + ~balanced_assignments, target_balanced + ) + + def test_label_embedding(): import random import numpy as np