Skip to content

Commit

Permalink
python interface for sampling of hierarchical clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
yaushian committed Dec 10, 2022
1 parent aad94c6 commit 1135506
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 2 deletions.
23 changes: 23 additions & 0 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ def __init__(self, base_vect_param_list, norm_p):
self.norm_p = c_int32(norm_p)


class ClusteringSamplerParam(ctypes.Structure):
"""
python class for handling struct ClusteringSamplerParam in clustering.hpp
"""

_fields_ = [
("strategy", c_int32),
("sample_rate", c_float),
("warmup_sample_rate", c_float),
("warmup_layer_rate", c_float),
]

def __init__(self, params):
name2type = dict(ClusteringSamplerParam._fields_)
for name in name2type:
setattr(self, name, name2type[name](getattr(params, name)))


class ScipyCscF32(ctypes.Structure):
"""
PyMatrix for scipy.sparse.csc_matrix
Expand Down Expand Up @@ -1256,6 +1274,7 @@ def link_clustering(self):
c_uint32,
c_int,
POINTER(c_uint32),
POINTER(ClusteringSamplerParam),
]
corelib.fillprototype(
self.clib_float32.c_run_clustering_csr_f32, None, [POINTER(ScipyCsrF32)] + arg_list[1:]
Expand All @@ -1273,6 +1292,7 @@ def run_clustering(
codes=None,
kmeans_max_iter=20,
threads=-1,
sample_params=None,
):
"""
Run clustering with given label embedding matrix and parameters in C++.
Expand Down Expand Up @@ -1302,6 +1322,8 @@ def run_clustering(

if codes is None or len(codes) != py_feat_mat.shape[0] or codes.dtype != np.uint32:
codes = np.zeros(py_feat_mat.rows, dtype=np.uint32)
if sample_params is not None:
sample_params = ClusteringSamplerParam(sample_params)
run_clustering(
byref(py_feat_mat),
depth,
Expand All @@ -1310,6 +1332,7 @@ def run_clustering(
kmeans_max_iter,
threads,
codes.ctypes.data_as(POINTER(c_uint32)),
sample_params,
)
return codes

Expand Down
5 changes: 3 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,11 @@ extern "C" {
int seed, \
uint32_t max_iter, \
int threads, \
uint32_t* label_codes) { \
uint32_t* label_codes, \
pecos::clustering::ClusteringSamplerParam* sample_param_ptr) { \
C_MAT feat_mat(py_mat_ptr); \
pecos::clustering::Tree tree(depth); \
tree.run_clustering(feat_mat, partition_algo, seed, label_codes, max_iter, threads); \
tree.run_clustering(feat_mat, partition_algo, seed, label_codes, max_iter, threads, sample_param_ptr); \
}
C_RUN_CLUSTERING(_csr_f32, ScipyCsrF32, pecos::csr_t)
C_RUN_CLUSTERING(_drm_f32, ScipyDrmF32, pecos::drm_t)
Expand Down
33 changes: 33 additions & 0 deletions pecos/xmc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class HierarchicalKMeans(Indexer):
KMEANS = 0 # KMEANS
SKMEANS = 5 # Spherical KMEANS

sample_scheduler_dict = {
"constant": 0,
"linear": 1,
} # Sampling strategy corresponding in clustering.cpp line 34.

@dc.dataclass
class TrainParams(pecos.BaseParams): # type: ignore
"""Training Parameters of Hierarchical K-means.
Expand All @@ -103,12 +108,30 @@ class TrainParams(pecos.BaseParams): # type: ignore
kmeans_max_iter: int = 20
threads: int = -1

@dc.dataclass
class SampleParams(pecos.BaseParams): # type: ignore
"""Parameters of Hierarchical K-means Sampling.
sample_rate (float, optional): Final Sample rate of a scheduler. Default is 1.
warmup_layer_rate (float, optional): The top (total_layers*warmup_layer_rate) layers are warmup_layers. Default is `0.4`.
warmup_sample_rate (float, optional): The fixed sample rate for the warmup layers.
strategy (string, optional): Sampling scheduler that increases sample rate from top layers to buttom layers. Options are ["constant", "linear"]. Default is `linear`.
`linear` strategy: the sample rate of layer l is `warmup_sample_rate`+max(l+1-warmup_layer,0)*(`sample_rate`-warmup_sample_rate)/(total_layers-warmup_layers).
`constant` strategy: the sample rate for layer l is `sample_rate`.
"""

strategy: str = "linear"
sample_rate: float = 1.0
warmup_sample_rate: float = 0.1
warmup_layer_rate: float = 0.4

@classmethod
def gen(
cls,
feat_mat,
train_params=None,
dtype=np.float32,
sample_params=None,
**kwargs,
):
"""Generate a cluster chain by using hierarchical k-means.
Expand All @@ -117,6 +140,7 @@ def gen(
feat_mat (numpy.ndarray or scipy.sparse.csr.csr_matrix): Matrix of label features.
train_params (HierarchicalKMeans.TrainParams, optional): training parameters for indexing.
dtype (type, optional): Data type for matrices. Default is `numpy.float32`.
sample_params (HierarchicalKMeans.SampleParams, optional): parameters for sampling. None denotes no sampling. Default is None.
**kwargs: Ignored.
Returns:
Expand Down Expand Up @@ -149,6 +173,14 @@ def gen(

algo = cls.SKMEANS if train_params.spherical else cls.KMEANS

if sample_params is not None:
sample_params = cls.SampleParams.from_dict(sample_params)
if sample_params.strategy not in cls.sample_scheduler_dict:
raise ValueError(
f"The sampling strategy is needed to be either 'constant' or 'linear'."
)
sample_params.strategy = cls.sample_scheduler_dict[sample_params.strategy]

assert feat_mat.dtype == np.float32
if isinstance(feat_mat, (smat.csr_matrix, ScipyCsrF32)):
py_feat_mat = ScipyCsrF32.init_from(feat_mat)
Expand All @@ -168,6 +200,7 @@ def gen(
codes=codes,
kmeans_max_iter=train_params.kmeans_max_iter,
threads=train_params.threads,
sample_params=sample_params,
)
C = cls.convert_codes_to_csc_matrix(codes, depth)
cluster_chain = ClusterChain.from_partial_chain(
Expand Down
19 changes: 19 additions & 0 deletions test/pecos/xmc/test_xmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def test_hierarchicalkmeans():

assert (chain2.chain[1].dot(chain2.chain[0]) - chain4.chain[0]).nnz == 0

# test sampling of hierarchical clustering.
# randomly sampling 3 out of the 4 following instances results in the same clustering results.
from pecos.xmc.base import HierarchicalKMeans

feat_mat = normalize(
smat.csr_matrix([[1, 0], [0.99, 0.02], [0.01, 1.03], [0, 1]], dtype=np.float32)
)
target_balanced = [0, 0, 1, 1]

# set the sampling rate to a fixed value 0.75
sample_params = HierarchicalKMeans.SampleParams(
strategy="linear", warmup_sample_rate=0.75, warmup_layer_rate=1.0
)
balanced_chain = Indexer.gen(feat_mat, max_leaf_size=2, sample_params=sample_params)
balanced_assignments = (balanced_chain[-1].todense() == [0, 1]).all(axis=1).A1
assert np.array_equal(balanced_assignments, target_balanced) or np.array_equal(
~balanced_assignments, target_balanced
)


def test_label_embedding():
import random
Expand Down

0 comments on commit 1135506

Please sign in to comment.