Skip to content

Commit

Permalink
Deprecate imbalanced hierarchical K-means from clustering and semantic indexing. (#151)
Browse files Browse the repository at this point in the history

* Deprecate imbalanced hierarchical kmeans from clustering and semantic indexing.

* Fix a style error.

Co-authored-by: Jyun-Yu Jiang <[email protected]>
  • Loading branch information
hallogameboy and Jyun-Yu Jiang authored Jul 12, 2022
1 parent 6723f61 commit b4a3d96
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 123 deletions.
16 changes: 0 additions & 16 deletions examples/pecos-xrlinear-jmlr22/xrl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,6 @@ def parse_arguments():
help="The max size of the leaf nodes of hierarchical 2-means clustering. If larger than total number of labels, One-Versus-All model will be trained. Default 100.",
)

parser.add_argument(
"--imbalanced-ratio",
type=float,
default=0.0,
metavar="FLOAT",
help="Value between 0.0 and 0.5 (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering L labels, the size of the output 2 clusters will be within approx imbalanced_ratio * 2 * L of each other. (default 0.0)",
)

parser.add_argument(
"--imbalanced-depth",
type=int,
default=100,
metavar="INT",
help="After hierarchical 2-means clustering has reached this depth, it will continue clustering as if --imbalanced-ratio is set to 0.0. (default 100)",
)

parser.add_argument(
"--spherical",
type=cli.str2bool,
Expand Down
16 changes: 0 additions & 16 deletions pecos/apps/text2text/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,6 @@ def parse_arguments(args):
help="number of splits used to construct hierarchy (a power of 2 is recommended, default 16)",
)

parser.add_argument(
"--imbalanced-ratio",
type=float,
default=0.0,
metavar="FLOAT",
help="Value between 0.0 and 0.5 (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering L labels, the size of the output 2 clusters will be within approx imbalanced_ratio * 2 * L of each other. (default 0.0)",
)

parser.add_argument(
"--imbalanced-depth",
type=int,
default=100,
metavar="INT",
help="After hierarchical 2-means clustering has reached this depth, it will continue clustering as if --imbalanced-ratio is set to 0.0. (default 100)",
)

parser.add_argument(
"--label-embed-type",
type=str,
Expand Down
90 changes: 36 additions & 54 deletions pecos/xmc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
XLINEAR_INFERENCE_MODEL_TYPES,
)
from pecos.utils import smat_util
from pecos.utils.cluster_util import ClusterChain, hierarchical_kmeans
from pecos.utils.cluster_util import ClusterChain
from sklearn.preprocessing import normalize

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,11 +86,9 @@ class HierarchicalKMeans(Indexer):
class TrainParams(pecos.BaseParams): # type: ignore
"""Training Parameters of Hierarchical K-means.
nr_splits (int, optional): The out-degree of each internal node of the tree. Ignored if `imbalanced_ratio != 0` because imbalanced clustering supports only 2-means. Default is `16`.
nr_splits (int, optional): The out-degree of each internal node of the tree. Default is `16`.
min_codes (int): The number of direct child nodes that the top level of the hierarchy should have.
max_leaf_size (int, optional): The maximum size of each leaf node of the tree. Default is `100`.
imbalanced_ratio (float, optional): Value between `0.0` and `0.5` (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering `L` labels, the size of the output 2 clusters will be within approx `imbalanced_ratio * 2 * L` of each other. Default is `0.0`.
imbalanced_depth (int, optional): Maximum depth of imbalanced clustering. After depth `imbalanced_depth` is reached, balanced clustering will be used. Default is `100`.
spherical (bool, optional): True will l2-normalize the centroids of k-means after each iteration. Default is `True`.
seed (int, optional): Random seed. Default is `0`.
kmeans_max_iter (int, optional): Maximum number of iterations for each k-means problem. Default is `20`.
Expand All @@ -100,8 +98,6 @@ class TrainParams(pecos.BaseParams): # type: ignore
nr_splits: int = 16
min_codes: int = None # type: ignore
max_leaf_size: int = 100
imbalanced_ratio: float = 0.0
imbalanced_depth: int = 100
spherical: bool = True
seed: int = 0
kmeans_max_iter: int = 20
Expand Down Expand Up @@ -138,59 +134,45 @@ def gen(
f"HierarchicalKMeans train_params: {json.dumps(train_params.to_dict(), indent=True)}"
)

# use optimized c++ clustering code if doing balanced clustering
if train_params.imbalanced_ratio == 0:
nr_instances = feat_mat.shape[0]
if train_params.max_leaf_size >= nr_instances:
# no-need to do clustering
return ClusterChain.from_partial_chain(
smat.csc_matrix(np.ones((nr_instances, 1), dtype=np.float32))
)

depth = max(1, int(math.ceil(math.log2(nr_instances / train_params.max_leaf_size))))
if (2**depth) > nr_instances:
raise ValueError(
f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {nr_instances} to avoid empty clusters"
)
nr_instances = feat_mat.shape[0]
if train_params.max_leaf_size >= nr_instances:
# no-need to do clustering
return ClusterChain.from_partial_chain(
smat.csc_matrix(np.ones((nr_instances, 1), dtype=np.float32))
)

algo = cls.SKMEANS if train_params.spherical else cls.KMEANS
depth = max(1, int(math.ceil(math.log2(nr_instances / train_params.max_leaf_size))))
if (2**depth) > nr_instances:
raise ValueError(
f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {nr_instances} to avoid empty clusters"
)

assert feat_mat.dtype == np.float32
if isinstance(feat_mat, (smat.csr_matrix, ScipyCsrF32)):
py_feat_mat = ScipyCsrF32.init_from(feat_mat)
elif isinstance(feat_mat, (np.ndarray, ScipyDrmF32)):
py_feat_mat = ScipyDrmF32.init_from(feat_mat)
else:
raise NotImplementedError(
"type(feat_mat) = {} is not supported.".format(type(feat_mat))
)
algo = cls.SKMEANS if train_params.spherical else cls.KMEANS

codes = np.zeros(py_feat_mat.rows, dtype=np.uint32)
codes = clib.run_clustering(
py_feat_mat,
depth,
algo,
train_params.seed,
codes=codes,
kmeans_max_iter=train_params.kmeans_max_iter,
threads=train_params.threads,
)
C = cls.convert_codes_to_csc_matrix(codes, depth)
cluster_chain = ClusterChain.from_partial_chain(
C, min_codes=train_params.min_codes, nr_splits=train_params.nr_splits
)
assert feat_mat.dtype == np.float32
if isinstance(feat_mat, (smat.csr_matrix, ScipyCsrF32)):
py_feat_mat = ScipyCsrF32.init_from(feat_mat)
elif isinstance(feat_mat, (np.ndarray, ScipyDrmF32)):
py_feat_mat = ScipyDrmF32.init_from(feat_mat)
else:
cluster_chain = hierarchical_kmeans(
feat_mat,
max_leaf_size=train_params.max_leaf_size,
imbalanced_ratio=train_params.imbalanced_ratio,
imbalanced_depth=train_params.imbalanced_depth,
spherical=train_params.spherical,
seed=train_params.seed,
kmeans_max_iter=train_params.kmeans_max_iter,
threads=train_params.threads,
raise NotImplementedError(
"type(feat_mat) = {} is not supported.".format(type(feat_mat))
)
cluster_chain = ClusterChain(cluster_chain)

codes = np.zeros(py_feat_mat.rows, dtype=np.uint32)
codes = clib.run_clustering(
py_feat_mat,
depth,
algo,
train_params.seed,
codes=codes,
kmeans_max_iter=train_params.kmeans_max_iter,
threads=train_params.threads,
)
C = cls.convert_codes_to_csc_matrix(codes, depth)
cluster_chain = ClusterChain.from_partial_chain(
C, min_codes=train_params.min_codes, nr_splits=train_params.nr_splits
)
return cluster_chain

@staticmethod
Expand Down
16 changes: 0 additions & 16 deletions pecos/xmc/xlinear/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,6 @@ def parse_arguments():
help="The max size of the leaf nodes of hierarchical 2-means clustering. If larger than total number of labels, One-Versus-All model will be trained. Default 100.",
)

parser.add_argument(
"--imbalanced-ratio",
type=float,
default=0.0,
metavar="FLOAT",
help="Value between 0.0 and 0.5 (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering L labels, the size of the output 2 clusters will be within approx imbalanced_ratio * 2 * L of each other. (default 0.0)",
)

parser.add_argument(
"--imbalanced-depth",
type=int,
default=100,
metavar="INT",
help="After hierarchical 2-means clustering has reached this depth, it will continue clustering as if --imbalanced-ratio is set to 0.0. (default 100)",
)

parser.add_argument(
"--spherical",
type=cli.str2bool,
Expand Down
14 changes: 0 additions & 14 deletions pecos/xmc/xtransformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,6 @@ def parse_arguments():
metavar="INT",
help="The max size of the leaf nodes of hierarchical clustering. If larger than the number of labels, OVA model will be trained. Default 100.",
)
parser.add_argument(
"--imbalanced-ratio",
type=float,
default=0.0,
metavar="FLOAT",
help="Value between 0.0 and 0.5 (inclusive). Indicates how relaxed the balancedness constraint of 2-means can be. Specifically, if an iteration of 2-means is clustering L labels, the size of the output 2 clusters will be within approx imbalanced_ratio * 2 * L of each other. (default 0.0)",
)
parser.add_argument(
"--imbalanced-depth",
type=int,
default=100,
metavar="INT",
help="After hierarchical 2-means clustering has reached this depth, it will continue clustering as if --imbalanced-ratio is set to 0.0. (default 100)",
)
# ========= matcher parameters ============
parser.add_argument(
"--max-match-clusters",
Expand Down
7 changes: 0 additions & 7 deletions test/pecos/xmc/test_xmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,13 @@ def test_hierarchicalkmeans():
smat.csr_matrix([[1, 0], [0.95, 0.05], [0.9, 0.1], [0, 1]], dtype=np.float32)
)
target_balanced = [0, 0, 1, 1]
target_imbalanced = [0, 0, 0, 1]

balanced_chain = Indexer.gen(feat_mat, max_leaf_size=3)
balanced_assignments = (balanced_chain[-1].todense() == [0, 1]).all(axis=1).A1
assert np.array_equal(balanced_assignments, target_balanced) or np.array_equal(
~balanced_assignments, target_balanced
)

imbalanced_chain = Indexer.gen(feat_mat, imbalanced_ratio=0.4, max_leaf_size=3)
imbalanced_assignments = (imbalanced_chain[-1].todense() == [0, 1]).all(axis=1).A1
assert np.array_equal(imbalanced_assignments, target_imbalanced) or np.array_equal(
~imbalanced_assignments, target_imbalanced
)

chain2 = Indexer.gen(feat_mat, max_leaf_size=1, nr_splits=2)
chain4 = Indexer.gen(feat_mat, max_leaf_size=1, nr_splits=4)

Expand Down

0 comments on commit b4a3d96

Please sign in to comment.