add JSON api for distributed XR-Linear training (#208)
jiong-zhang authored Feb 17, 2023
1 parent 4a22d26 commit cef885f
Showing 5 changed files with 61 additions and 132 deletions.
pecos/distributed/xmc/base.py (15 changes: 9 additions & 6 deletions)
@@ -12,6 +12,7 @@
 import numpy as np
 import dataclasses as dc
 import itertools
+import pecos
 from pecos.utils import smat_util
 from pecos.xmc import Indexer, LabelEmbeddingFactory
 from pecos.xmc.base import HierarchicalKMeans
@@ -279,22 +280,24 @@ class DistClustering(object):
     """

     @dc.dataclass
-    class ClusterParams(object):
+    class TrainParams(pecos.BaseParams):
         """Clustering parameters of Distributed Cluster Chain
         Parameters:
             indexer_params (HierarchicalKMeans.TrainParams): Params for indexing
-            meta_label_embedding_method (str): Meta-tree cluster label embedding method
+            meta_label_embedding_method (str): Meta-tree cluster label embedding method.
+                Default pifa
             sub_label_embedding_method (str): Sub-tree cluster label embedding method
+                Default pifa
         """

-        indexer_params: HierarchicalKMeans.TrainParams  # type: ignore
-        meta_label_embedding_method: str
-        sub_label_embedding_method: str
+        indexer_params: HierarchicalKMeans.TrainParams = None  # type: ignore
+        meta_label_embedding_method: str = "pifa"
+        sub_label_embedding_method: str = "pifa"

     def __init__(self, dist_comm, cluster_params):
         assert isinstance(dist_comm, DistComm), type(dist_comm)
-        assert isinstance(cluster_params, self.ClusterParams), type(cluster_params)
+        assert isinstance(cluster_params, self.TrainParams), type(cluster_params)
         assert cluster_params.meta_label_embedding_method in (
             "pii",
             "pifa",
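Because DistClustering.TrainParams now derives from pecos.BaseParams and every field has a default, it can be built from a partial or even empty dict, which is what the JSON skeleton machinery in train.py below relies on. A minimal sketch, assuming a PECOS build that includes this commit (the override values are illustrative only):

from pecos.distributed.xmc.base import DistClustering

# An empty dict is now valid: every field falls back to its declared default.
cluster_params = DistClustering.TrainParams.from_dict({}, recursive=True)
assert cluster_params.meta_label_embedding_method == "pifa"

# With recursive=True, nested dicts are promoted to their params classes,
# so indexer_params arrives as a HierarchicalKMeans.TrainParams instance.
cluster_params = DistClustering.TrainParams.from_dict(
    {"indexer_params": {"nr_splits": 16, "max_leaf_size": 100}},
    recursive=True,
)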
pecos/distributed/xmc/xlinear/model.py (4 changes: 2 additions & 2 deletions)
@@ -756,15 +756,15 @@ def train(cls, dist_comm, X, Y, cluster_params, train_params, pred_params, dist_
             dist_comm (DistComm): Distributed communicator.
             X (csr_matrix(float32)): instance feature matrix of shape (nr_inst, nr_feat).
             Y (csc_matrix(float32)): label matrix of shape (nr_inst, nr_labels).
-            cluster_params (DistClustering.ClusterParams): Clustering parameters.
+            cluster_params (DistClustering.TrainParams): Clustering parameters.
             train_params (cls.TrainParams): Training parameters.
             pred_params (cls.PredParams): Prediction parameters.
             dist_params (cls.DistParams): Distributed parameters.
         """
         assert isinstance(dist_comm, DistComm), type(dist_comm)
         assert isinstance(X, csr_matrix), type(X)
         assert isinstance(Y, csc_matrix), type(Y)
-        assert isinstance(cluster_params, DistClustering.ClusterParams), type(cluster_params)
+        assert isinstance(cluster_params, DistClustering.TrainParams), type(cluster_params)
         assert isinstance(train_params, cls.TrainParams), type(train_params)
         assert isinstance(pred_params, cls.PredParams), type(pred_params)
         assert isinstance(dist_params, cls.DistParams), type(dist_params)
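For Python-API callers, only the params type changes in this signature. A hedged sketch of a full train call under MPI, assuming this commit's PECOS is installed; the file names are placeholders, and the loader helpers are borrowed from the single-machine XLinearModel API rather than shown in this diff:

from pecos.distributed.comm.mpi_comm import MPIComm
from pecos.distributed.xmc.base import DistClustering
from pecos.distributed.xmc.xlinear.model import DistributedCPUXLinearModel as DistXLM
from pecos.xmc.xlinear.model import XLinearModel

mpi_comm = MPIComm()
X = XLinearModel.load_feature_matrix("X.trn.npz")  # csr_matrix, shape (nr_inst, nr_feat)
Y = XLinearModel.load_label_matrix("Y.trn.npz", for_training=True)  # csc_matrix

# All four params objects accept fully-defaulted construction (see base.py above).
xlm = DistXLM.train(
    dist_comm=mpi_comm,
    X=X,
    Y=Y,
    cluster_params=DistClustering.TrainParams.from_dict({}, recursive=True),
    train_params=DistXLM.TrainParams.from_dict({}, recursive=True),
    pred_params=DistXLM.PredParams.from_dict({}, recursive=True),
    dist_params=DistXLM.DistParams.from_dict({}),
)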
pecos/distributed/xmc/xlinear/train.py (170 changes: 48 additions & 122 deletions)
@@ -8,50 +8,21 @@
 # or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 # OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
-import argparse
 import logging
+import json
 from pecos.distributed.xmc.base import DistClustering
-from pecos.distributed.xmc.xlinear.model import DistributedCPUXLinearModel
+from pecos.distributed.xmc.xlinear.model import DistributedCPUXLinearModel as DistXLM
 from pecos.distributed.comm.mpi_comm import MPIComm
 from pecos.utils import logging_util
 from pecos.utils.profile_util import MemInfo
-from pecos.xmc import PostProcessor
 from pecos.xmc.xlinear.model import XLinearModel
-from pecos.xmc.base import HierarchicalKMeans
+from pecos.xmc.xlinear.train import parse_arguments


 LOGGER = logging.getLogger(__name__)


-def parse_arguments():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "-x",
-        "--inst-path",
-        type=str,
-        required=True,
-        metavar="PATH",
-        help="path to the CSR npz or Row-majored npy file of the feature matrix (nr_insts * nr_feats)",
-    )
-
-    parser.add_argument(
-        "-y",
-        "--label-path",
-        type=str,
-        required=True,
-        metavar="PATH",
-        help="path to the CSR npz file of the label matrix (nr_insts * nr_labels)",
-    )
-    parser.add_argument(
-        "-m",
-        "--model-folder",
-        type=str,
-        required=True,
-        metavar="DIR",
-        help="path to the model folder.",
-    )
-
+def add_dist_arguments(parser):
     parser.add_argument(
         "-nst",
         "--min-n-sub-tree",
@@ -60,13 +31,6 @@ def parse_arguments():
         metavar="INT",
         help="the minimum number of sub-trees in training step, should be more than number of distributed machines.",
     )
-    parser.add_argument(
-        "--nr-splits",
-        type=int,
-        default=16,
-        metavar="INT",
-        help="number of splits used to construct hierarchy (a power of 2 is recommended)",
-    )
     parser.add_argument(
         "-mle",
         "--meta-label-embedding-method",
@@ -81,60 +45,6 @@ def parse_arguments():
         default="pifa",
         help="label embedding method for sub-tree",
     )
-    # Prediction kwargs
-    parser.add_argument(
-        "-k",
-        "--only-topk",
-        type=int,
-        default=None,
-        metavar="INT",
-        help="the default number of top labels used in the prediction",
-    )
-    parser.add_argument(
-        "-b",
-        "--beam-size",
-        type=int,
-        default=None,
-        metavar="INT",
-        help="the default size of beam search used in the prediction",
-    )
-    parser.add_argument(
-        "--max-leaf-size",
-        type=int,
-        default=100,
-        metavar="INT",
-        help="The max size of the leaf nodes of hierarchical 2-means clustering. Multiple values (separated by comma) are supported and will lead to different individual models for ensembling. (default [100])",
-    )
-    parser.add_argument(
-        "-pp",
-        "--post-processor",
-        type=str,
-        choices=PostProcessor.valid_list(),
-        default=None,
-        metavar="STR",
-        help="the default post processor used in the prediction",
-    )
-    parser.add_argument(
-        "--seed", type=int, default=0, metavar="INT", help="random seed (default 0)"
-    )
-    parser.add_argument(
-        "--bias", type=float, default=1.0, metavar="VAL", help="bias term (default 1.0)"
-    )
-    parser.add_argument(
-        "--max-iter",
-        type=int,
-        default=20,
-        metavar="INT",
-        help="max iterations for indexer (default 20)",
-    )
-    parser.add_argument(
-        "-n",
-        "--threads",
-        type=int,
-        default=-1,
-        metavar="INT",
-        help="number of threads to use (default -1 to denote all the CPUs)",
-    )
     parser.add_argument(
         "-mwf",
         "--main-workload-factor",
@@ -143,42 +53,58 @@ def parse_arguments():
         metavar="FLOAT",
         help="main node vs worker node workload factor in distributed model training",
     )
-    parser.add_argument(
-        "-t",
-        "--threshold",
-        type=float,
-        default=0.1,
-        metavar="VAL",
-        help="threshold to sparsify the model weights.",
-    )
-    parser.add_argument(
-        "--verbose-level",
-        type=int,
-        choices=logging_util.log_levels.keys(),
-        default=2,
-        metavar="INT",
-        help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}. Default 2",
-    )

     return parser


 def do_train(args):
     """Distributed CPU training and saving XLinear model"""

-    # Distributed communicator
-    mpi_comm = MPIComm()
+    params = dict()
+    if args.generate_params_skeleton:
+        params["train_params"] = DistXLM.TrainParams.from_dict({}, recursive=True).to_dict()
+        params["pred_params"] = DistXLM.PredParams.from_dict({}, recursive=True).to_dict()
+        params["dist_params"] = DistXLM.DistParams.from_dict({}, recursive=True).to_dict()
+        params["cluster_params"] = DistClustering.TrainParams.from_dict(
+            {}, recursive=True
+        ).to_dict()
+        print(f"{json.dumps(params, indent=True)}")
+        return
+
+    if args.params_path:
+        with open(args.params_path, "r") as fin:
+            params = json.load(fin)

     # Parse args
     args_dict = {k: v for k, v in vars(args).items() if v is not None}
-    train_params = DistributedCPUXLinearModel.TrainParams.from_dict(args_dict, recursive=True)
-    cluster_params = DistClustering.ClusterParams(
-        indexer_params=HierarchicalKMeans.TrainParams.from_dict(args_dict),
-        meta_label_embedding_method=args.meta_label_embedding_method,
-        sub_label_embedding_method=args.sub_label_embedding_method,
-    )
-    pred_params = DistributedCPUXLinearModel.PredParams.from_dict(args_dict, recursive=True)
-    dist_params = DistributedCPUXLinearModel.DistParams.from_dict(args_dict)
+
+    train_params = params.get("train_params", None)
+    pred_params = params.get("pred_params", None)
+    dist_params = params.get("dist_params", None)
+    cluster_params = params.get("cluster_params", None)
+
+    if train_params is not None:
+        train_params = DistXLM.TrainParams.from_dict(train_params)
+    else:
+        train_params = DistXLM.TrainParams.from_dict(args_dict, recursive=True)
+
+    if pred_params is not None:
+        pred_params = DistXLM.PredParams.from_dict(pred_params)
+    else:
+        pred_params = DistXLM.PredParams.from_dict(args_dict, recursive=True)
+
+    if dist_params is not None:
+        dist_params = DistXLM.DistParams.from_dict(dist_params)
+    else:
+        dist_params = DistXLM.DistParams.from_dict(args_dict)
+
+    if cluster_params is not None:
+        cluster_params = DistClustering.TrainParams.from_dict(cluster_params)
+    else:
+        cluster_params = DistClustering.TrainParams.from_dict(args_dict, recursive=True)
+
+    # Distributed communicator
+    mpi_comm = MPIComm()

     # Prepare data
     LOGGER.info(f"Started loading data on Rank {mpi_comm.get_rank()} ... {MemInfo.mem_info()}")
@@ -187,7 +113,7 @@ def do_train(args):
     LOGGER.info(f"Done loading data on Rank {mpi_comm.get_rank()}. {MemInfo.mem_info()}")

     # Train Distributed XLinearModel
-    xlm = DistributedCPUXLinearModel.train(
+    xlm = DistXLM.train(
         dist_comm=mpi_comm,
         X=X,
         Y=Y,
@@ -205,7 +131,7 @@


 if __name__ == "__main__":
-    parser = parse_arguments()
+    parser = add_dist_arguments(parse_arguments())
    args = parser.parse_args()
    logging_util.setup_logging_config(level=args.verbose_level)
    do_train(args)
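Taken together, this gives the distributed trainer the same JSON parameter workflow as the single-machine script: dump a fully-defaulted skeleton, edit it, and feed it back, with any section missing from the JSON still resolved from CLI arguments. A hedged sketch of the programmatic half, mirroring the skeleton code in do_train above; the flag spellings --generate-params-skeleton and --params-path come from the shared single-machine parser and, like the params.json name and the edited value, are assumptions here:

import json

from pecos.distributed.xmc.base import DistClustering
from pecos.distributed.xmc.xlinear.model import DistributedCPUXLinearModel as DistXLM

# Build the same skeleton that the skeleton flag prints: empty dicts expand
# into fully-defaulted, JSON-serializable parameter trees.
params = {
    "train_params": DistXLM.TrainParams.from_dict({}, recursive=True).to_dict(),
    "pred_params": DistXLM.PredParams.from_dict({}, recursive=True).to_dict(),
    "dist_params": DistXLM.DistParams.from_dict({}, recursive=True).to_dict(),
    "cluster_params": DistClustering.TrainParams.from_dict({}, recursive=True).to_dict(),
}

# Edit any field; "pii" is among the methods accepted by DistClustering (see base.py).
params["cluster_params"]["meta_label_embedding_method"] = "pii"
with open("params.json", "w") as fout:
    json.dump(params, fout, indent=True)

Training would then be launched under MPI along the lines of mpiexec -n <n_machines> python3 -m pecos.distributed.xmc.xlinear.train -x X.trn.npz -y Y.trn.npz -m ./model --params-path params.json (the module invocation is assumed from the repository layout).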
test/pecos/distributed/xmc/test_dist_xmc.py (2 changes: 1 addition & 1 deletion)
@@ -136,7 +136,7 @@ def test_dist_clustering():
     nr_label = 10

     dummy_comm = DummyComm()
-    cluster_params = DistClustering.ClusterParams(
+    cluster_params = DistClustering.TrainParams(
         indexer_params=HierarchicalKMeans.TrainParams(
             nr_splits=2, max_leaf_size=2, threads=1, seed=0
         ),
test/pecos/distributed/xmc/xlinear/test_dist_xlinear.py (2 changes: 1 addition & 1 deletion)
@@ -80,7 +80,7 @@ def test_dist_training():

     dummy_comm = DummyComm()

-    cluster_params = DistClustering.ClusterParams(
+    cluster_params = DistClustering.TrainParams(
         indexer_params=HierarchicalKMeans.TrainParams(
             nr_splits=2, max_leaf_size=2, threads=1, seed=0
         ),
