diff --git a/examples/xr-transformer-neurips21/run.sh b/examples/xr-transformer-neurips21/run.sh index 29f22f4..92c3deb 100644 --- a/examples/xr-transformer-neurips21/run.sh +++ b/examples/xr-transformer-neurips21/run.sh @@ -5,7 +5,7 @@ if [ ${data} == "eurlex-4k" ]; then models=(bert roberta xlnet) ens_method=softmax_average elif [ ${data} == "wiki10-31k" ]; then - models=(bert roberta xlnet) + models=(bert) ens_method=rank_average elif [ ${data} == "amazoncat-13k" ]; then models=(bert roberta xlnet) diff --git a/pecos/distributed/diagnostic_util/__init__.py b/pecos/distributed/diagnostic_tools/__init__.py similarity index 100% rename from pecos/distributed/diagnostic_util/__init__.py rename to pecos/distributed/diagnostic_tools/__init__.py diff --git a/pecos/distributed/diagnostic_tools/deepspeed_comm.py b/pecos/distributed/diagnostic_tools/deepspeed_comm.py new file mode 100644 index 0000000..cc644ae --- /dev/null +++ b/pecos/distributed/diagnostic_tools/deepspeed_comm.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 -u +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import argparse +import deepspeed +import os +import tempfile +import torch +import torch.nn as nn +import torch.nn.functional as F +import logging +import socket +import signal +from pecos.utils import logging_util +from pecos.distributed.xmc.xtransformer.module import DeepSpeedUtils + +logging.getLogger(torch.__name__).setLevel(logging.WARNING) +logging.getLogger("DeepSpeed").setLevel(logging.WARNING) + +LOGGER = logging.getLogger(__name__) + + +def parse_arguments(): + """Parse evaluation arguments""" + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--shared-workdir", + type=str, + metavar="PATH", + default=None, + help="the shared workdir which can be accessed by each worker. Default None to disable check", + ) + + parser.add_argument( + "--timeout", + type=int, + default=60, + metavar="INT", + help=f"timeout in seconds for the diagnostic check. Default 60", + ) + + parser.add_argument( + "--verbose-level", + type=int, + choices=logging_util.log_levels.keys(), + default=2, + metavar="INT", + help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}. Default 2", + ) + + parser.add_argument( + "--local_rank", + type=int, + default=0, + metavar="RANK", + help="local rank passed from torch distributed launcher", + ) + return parser + + +def distributed_cluster_check(workdir=None, timeout=60): + """ + Check deepspeed distributed setup using a dummy model + + Args: + timeout (int): number of seconds to wait before raising exception. + Default 60. 
+ """ + + class TimeOutException(Exception): + pass + + def alarm_handler(signum, frame): + raise TimeOutException() + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + + def forward(self, x): + return F.relu(self.conv1(x)) + + signal.signal(signal.SIGALRM, alarm_handler) + signal.alarm(timeout) + + try: + model = DummyModel() + ds_config = DeepSpeedUtils.get_config_from_params() + _ = deepspeed.initialize( + model=model, + config_params=ds_config, + model_parameters=model.parameters(), + ) + torch_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + ip = socket.gethostbyname(socket.gethostname()) + LOGGER.info(f"Report from {ip}: torch_rank={torch_rank}, world_size={world_size}") + + if workdir is not None: + workdir = os.path.abspath(workdir) + + if torch_rank == 0: + master_stamp = tempfile.TemporaryDirectory(dir=workdir) + master_stamp_name = [f"{master_stamp.name}/_MASTER_STAMP"] + open(master_stamp_name[0], "w").close() + else: + master_stamp_name = [None] + + torch.distributed.broadcast_object_list(master_stamp_name, 0) + + if not os.path.isfile(master_stamp_name[0]): + raise ValueError(f"Rank{torch_rank} not able to access workdir at {workdir}") + else: + LOGGER.info(f"Rank{torch_rank} is able to access workdir at {workdir}") + torch.distributed.barrier() + + except TimeOutException as ex: + raise ex + signal.alarm(0) + + +if __name__ == "__main__": + """ + Sanity check for deepspeed distributed + + Usage: + deepspeed --hostfile [PATH_TO_HOSTFILE] --module pecos.distributed.diagnostic_tools.deepspeed_comm + """ + parser = parse_arguments() + args = parser.parse_args() + logging_util.setup_logging_config(level=args.verbose_level) + distributed_cluster_check(workdir=args.shared_workdir, timeout=args.timeout) diff --git a/pecos/distributed/diagnostic_util/mpi_comm_test.py b/pecos/distributed/diagnostic_tools/mpi_comm.py similarity index 100% rename from pecos/distributed/diagnostic_util/mpi_comm_test.py rename to pecos/distributed/diagnostic_tools/mpi_comm.py diff --git a/pecos/distributed/diagnostic_util/test_util.py b/pecos/distributed/diagnostic_tools/test_util.py similarity index 100% rename from pecos/distributed/diagnostic_util/test_util.py rename to pecos/distributed/diagnostic_tools/test_util.py diff --git a/pecos/distributed/xmc/xlinear/train.py b/pecos/distributed/xmc/xlinear/train.py index 2b363c4..12154cf 100644 --- a/pecos/distributed/xmc/xlinear/train.py +++ b/pecos/distributed/xmc/xlinear/train.py @@ -63,7 +63,7 @@ def parse_arguments(): parser.add_argument( "--nr-splits", type=int, - default=2, + default=16, metavar="INT", help="number of splits used to construct hierarchy (a power of 2 is recommended)", ) @@ -86,7 +86,7 @@ def parse_arguments(): "-k", "--only-topk", type=int, - default=20, + default=None, metavar="INT", help="the default number of top labels used in the prediction", ) @@ -94,7 +94,7 @@ def parse_arguments(): "-b", "--beam-size", type=int, - default=10, + default=None, metavar="INT", help="the default size of beam search used in the prediction", ) @@ -110,7 +110,7 @@ def parse_arguments(): "--post-processor", type=str, choices=PostProcessor.valid_list(), - default="l3-hinge", + default=None, metavar="STR", help="the default post processor used in the prediction", ) @@ -170,7 +170,7 @@ def do_train(args): mpi_comm = MPIComm() # Parse args - args_dict = vars(args) + args_dict = {k: v for k, v in vars(args).items() if v is not None} train_params = 
DistributedCPUXLinearModel.TrainParams.from_dict(args_dict, recursive=True) cluster_params = DistClustering.ClusterParams( indexer_params=HierarchicalKMeans.TrainParams.from_dict(args_dict), diff --git a/pecos/distributed/xmc/xtransformer/__init__.py b/pecos/distributed/xmc/xtransformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pecos/distributed/xmc/xtransformer/dist_trainer.py b/pecos/distributed/xmc/xtransformer/dist_trainer.py new file mode 100644 index 0000000..df26353 --- /dev/null +++ b/pecos/distributed/xmc/xtransformer/dist_trainer.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 -u +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import argparse +import deepspeed +import os +import json +import time +import torch +import logging +import random + +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader, RandomSampler, TensorDataset + +from pecos.utils import cli, logging_util +from pecos.xmc.xtransformer.matcher import TransformerMatcher +from pecos.xmc.xtransformer.module import XMCTextDataset + +from .module import AllInOneForXMCModel, DeepSpeedUtils + +LOGGER = logging.getLogger(__name__) + + +def parse_arguments(): + """Parse arguments""" + + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--data-path", + type=str, + required=True, + metavar="PATH", + help="path to the training data (XMCTextDataset) to load", + ) + + parser.add_argument( + "-m", + "--model-path", + type=str, + required=True, + metavar="PATH", + help="path to the TransformerMatcher model to fine-tune", + ) + + parser.add_argument( + "-o", + "--output-path", + type=str, + required=True, + metavar="PATH", + help="path to save the output checkpoint", + ) + + parser.add_argument( + "-p", + "--params-path", + type=str, + default=None, + metavar="PARAMS_PATH", + help="Json file for params (default None)", + ) + + parser.add_argument( + "--verbose-level", + type=int, + choices=logging_util.log_levels.keys(), + default=2, + metavar="INT", + help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}. 
Default 2", + ) + + parser.add_argument( + "--fp16", + type=cli.str2bool, + metavar="[true/false]", + default=True, + help="If true, do half-precision training", + ) + + parser.add_argument( + "--shard-scheme", + type=str, + choices=["synchronized", "ordered"], + metavar="STR", + default="synchronized", + help="access scheme for training data shards", + ) + + parser.add_argument( + "--local_rank", + type=int, + default=0, + metavar="RANK", + help="local rank passed from torch distributed launcher", + ) + return parser + + +def dist_fine_tune(args): + """Fine tune on a single XMC task + + Args: + args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` + """ + # env set by deepspeed launcher + world_size = int(os.getenv("WORLD_SIZE", "1")) + + # === get fine-tuning params === + if args.params_path: + with open(args.params_path, "r") as fin: + params = json.load(fin) + + train_params = params.get("train_params", None) + pred_params = params.get("pred_params", None) + + train_params = TransformerMatcher.TrainParams.from_dict(train_params) + train_params.fp16 = args.fp16 + + pred_params = TransformerMatcher.PredParams.from_dict(pred_params) + + # === load training data and model === + data_stats = XMCTextDataset.get_data_stats(args.data_path) + num_instances = data_stats["num_instances"] + num_shards = data_stats["num_shards"] + LOGGER.info(f"Train data info loaded from {args.data_path}") + + model = AllInOneForXMCModel.load(args.model_path) + loss_function = TransformerMatcher.get_loss_function(train_params.loss_function) + LOGGER.info(f"Model loaded from {args.model_path}") + + # === compute stopping criteria === + total_batch_size = world_size * train_params.batch_size + batches_per_epoch = (num_instances + total_batch_size - 1) // total_batch_size + steps_per_epoch = batches_per_epoch // train_params.gradient_accumulation_steps + + if train_params.max_steps > 0: + t_total = train_params.max_steps + train_params.num_train_epochs = ( + train_params.max_steps + steps_per_epoch - 1 + ) // steps_per_epoch + else: + t_total = steps_per_epoch * train_params.num_train_epochs + train_params.max_steps = t_total + + train_params.save_steps = min(train_params.save_steps, t_total) + train_params.logging_steps = min(train_params.logging_steps, t_total) + + # === setup deepspeed config and engin === + ds_config = DeepSpeedUtils.get_config_from_params(train_params) + params_to_optmize = model.prepare_params(train_params.weight_decay) + ds_engine, _, _, scheduler = deepspeed.initialize( + model=model, + config_params=ds_config, + model_parameters=params_to_optmize, + training_data=TensorDataset(torch.zeros((num_instances,))), + ) + + # === start fine-tuning === + global_rank = torch.distributed.get_rank() + if world_size > num_shards and args.shard_scheme != "synchronized": + LOGGER.warning(f"more workers than shards, fall back to synchronized shard access") + args.shard_scheme = "synchronized" + + if global_rank == 0: + LOGGER.info("***** Running training *****") + LOGGER.info(" Num examples = %d", num_instances) + LOGGER.info(" Num labels = %d", model.nr_labels) + LOGGER.info(" Num Epochs = %d", train_params.num_train_epochs) + LOGGER.info(" Learning Rate Schedule = %s", train_params.lr_schedule) + LOGGER.info(" Batch size = %d", total_batch_size) + LOGGER.info(" Gradient Accumulation steps = %d", train_params.gradient_accumulation_steps) + LOGGER.info(" Num batches per epoch = %d", batches_per_epoch) + LOGGER.info(" Total optimization steps = %d", train_params.max_steps) + 
LOGGER.info(" Dist World Size = %d", world_size) + LOGGER.info(" Data shard access scheme = %s", args.shard_scheme) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + total_train_time, logging_elapsed = 0.0, 0.0 + logging_steps = train_params.logging_steps + + for epoch in range(1, int(train_params.num_train_epochs) + 1): + shard_access_order = list(range(num_shards)) + random.shuffle(shard_access_order) + + if args.shard_scheme == "ordered": + shard_load_per_epoch = (num_shards + world_size - 1) // world_size + shard_access_order = [ + (global_rank + kk * world_size) % num_shards for kk in range(shard_load_per_epoch) + ] + + ds_engine.module.train() + + # load next shard + for shard_id in shard_access_order: + train_data = XMCTextDataset.load(args.data_path, shard=shard_id) + + # data loader and actual max_steps + if args.shard_scheme == "synchronized": + sampler = DistributedSampler( + train_data, + num_replicas=world_size, + rank=global_rank, + ) + else: + sampler = RandomSampler(train_data) + + train_dataloader = DataLoader( + train_data, + sampler=sampler, + pin_memory=False, + batch_size=train_params.batch_size, + num_workers=train_params.batch_gen_workers // max(world_size, 1), + ) + LOGGER.debug( + f"Rank{global_rank}/{world_size}: Training data({shard_id}/{num_shards}) loaded, num_batches={len(train_dataloader)}" + ) + + for batch_cnt, batch in enumerate(train_dataloader): + start_time = time.time() + batch = tuple(t.to(ds_engine.device) for t in batch) + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "token_type_ids": batch[2], + "instance_number": batch[3], + "label_values": batch[4], + "label_indices": batch[-1] if train_data.has_ns else None, + } + logits = ds_engine(**inputs) + loss = loss_function(logits, inputs["label_values"]).mean() + ds_engine.backward(loss) + ds_engine.step() + + tr_loss += loss.detach().item() + global_step += 1 + logging_elapsed += time.time() - start_time + total_train_time += time.time() - start_time + + if global_rank == 0: + if logging_steps > 0 and global_step % logging_steps == 0: + cur_loss = (tr_loss - logging_loss) / logging_steps + + # incase .step() hasn't been called (fp16 could skip steps) + try: + cur_lr = scheduler.get_last_lr()[0] + except AssertionError: + cur_lr = 0 + + LOGGER.info( + "| [{:4d}/{:4d}][{:6d}/{:6d}] | {:4d}/{:4d} batches | ms/batch {:5.4f} | train_loss {:6e} | lr {:.6e}".format( + int(epoch), + int(train_params.num_train_epochs), + int(global_step), + int(train_params.max_steps), + int(batch_cnt), + batches_per_epoch, + logging_elapsed * 1000.0 / logging_steps, + cur_loss, + cur_lr, + ) + ) + logging_loss = tr_loss + logging_elapsed = 0 + + if global_step % train_params.save_steps == 0: + ds_engine.module.save(args.output_path) + + if global_step >= train_params.max_steps: + # within shard + break + if global_step >= train_params.max_steps: + # within epoch + break + if global_step >= train_params.max_steps: + # outmost loop + break + if global_rank == 0: + ds_engine.module.save(args.output_path) + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + parser = parse_arguments() + args = parser.parse_args() + logging_util.setup_logging_config(level=args.verbose_level) + dist_fine_tune(args) diff --git a/pecos/distributed/xmc/xtransformer/encode.py b/pecos/distributed/xmc/xtransformer/encode.py new file mode 100644 index 0000000..8f42569 --- /dev/null +++ b/pecos/distributed/xmc/xtransformer/encode.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 -u +# Copyright 2021 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import argparse +import deepspeed +import os +import torch +import logging +import itertools + +from pecos.utils import smat_util, logging_util +from pecos.xmc.xtransformer.model import XTransformer + +LOGGER = logging.getLogger(__name__) + + +def parse_arguments(): + """Parse evaluation arguments""" + + parser = argparse.ArgumentParser() + + parser.add_argument( + "-t", + "--text-path", + type=str, + required=True, + metavar="PATH", + help="Path to the instance text file.", + ) + parser.add_argument( + "-m", + "--model-folder", + type=str, + required=True, + metavar="PATH", + help="Path to load x-transformer model.", + ) + parser.add_argument( + "-o", + "--save-emb-folder", + type=str, + required=True, + metavar="PATH", + help="The folder in which the embeddings will be written (in WORLD_SIZE shards)", + ) + + # ======= Other parameters ======== + parser.add_argument( + "--batch-size", + default=32, + type=int, + metavar="INT", + help="Batch size per GPU.", + ) + parser.add_argument( + "--batch-gen-workers", + type=int, + metavar="INT", + default=4, + help="number of CPUs to use for batch generation", + ) + parser.add_argument( + "--truncate-length", + default=None, + type=int, + metavar="INT", + help="max number of tokens to encode", + ) + parser.add_argument( + "--max-pred-chunk", + default=10**7, + metavar="INT", + type=int, + help="Max number of instances to predict on at once, set to avoid OOM. Set to None to predict on all instances at once. Default 10^7", + ) + parser.add_argument( + "--verbose-level", + type=int, + choices=logging_util.log_levels.keys(), + default=1, + metavar="INT", + help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}, default 1", + ) + + parser.add_argument( + "--local_rank", + type=int, + default=0, + metavar="RANK", + help="local rank passed from torch distributed launcher", + ) + return parser + + +def dist_encode(args): + """Encode with XTransformer model in distributed fashion. 
+ Each worker will encode an exclusive chunk and save the result to args.save_emb_folder/X.emb.[WORKER_RANK].npy + + Args: + args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` + """ + os.makedirs(args.save_emb_folder, exist_ok=True) + + deepspeed.init_distributed(dist_backend="nccl") + global_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + LOGGER.info( + f"Initialized device for rank={args.local_rank}, global_rank={global_rank}, world_size={world_size}" + ) + xtf = XTransformer.load(args.model_folder) + + # get number of lines without reading all in memory + nr_inst = sum(1 for line in open(args.text_path, "r")) + chunk_size = (nr_inst + world_size - 1) // world_size + start = global_rank * chunk_size + end = min(nr_inst, start + chunk_size) + + with open(args.text_path, "r") as fin: + X = [] + for line in itertools.islice(fin, start, end): + X.append(line.strip()) + LOGGER.info(f"Rank{global_rank}/{world_size} will encode {start} to {end}") + + pred_params = xtf.get_pred_params() + for i in range(len(pred_params.matcher_params_chain)): + if args.truncate_length: + pred_params.matcher_params_chain[i].truncate_length = args.truncate_length + + X_emb = xtf.encode( + X, + batch_size=args.batch_size, + batch_gen_workers=args.batch_gen_workers, + device_id=args.local_rank, + pred_params=pred_params, + max_pred_chunk=args.max_pred_chunk, + ) + + local_tgt = os.path.join(args.save_emb_folder, f"X.emb.{global_rank}.npy") + smat_util.save_matrix(local_tgt, X_emb) + LOGGER.info(f"Rank{global_rank}/{world_size} saved embedding {X_emb.shape} to {local_tgt}") + + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + parser = parse_arguments() + args = parser.parse_args() + logging_util.setup_logging_config(level=args.verbose_level) + dist_encode(args) diff --git a/pecos/distributed/xmc/xtransformer/model.py b/pecos/distributed/xmc/xtransformer/model.py new file mode 100644 index 0000000..79b939d --- /dev/null +++ b/pecos/distributed/xmc/xtransformer/model.py @@ -0,0 +1,332 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import json +import logging +import os +import gc +import tempfile + +import numpy as np +import scipy.sparse as smat +import dataclasses as dc +import pecos +from pecos.core import clib +from pecos.utils.cluster_util import ClusterChain +from pecos.xmc.base import HierarchicalMLModel +from pecos.xmc.xtransformer.model import XTransformer + +from pecos.xmc.xtransformer.matcher import TransformerMatcher +from pecos.xmc.xtransformer.module import MLProblemWithText +from pecos.xmc.xtransformer.network import TransformerLinearXMCHead +from . import dist_trainer +from .module import DeepSpeedUtils as ds_utils + + +LOGGER = logging.getLogger(__name__) + + +class XTransformerDistTrainer(XTransformer): + """Distributed Trainer for Hierarchical-XTransformer for XMC.""" + + @dc.dataclass + class DistParams(pecos.BaseParams): + """Distributed Parameters of XTransformer. 
+ + fp16 (bool): whether to use half-precision in model fine-tuning + max_shard_size (int): max number of instances within a data shard. Default 10 ** 7 + shard_scheme (str): data shard access scheme. [synchronized, ordered]. + Default synchronized that every worker loads the same shard + hostfile (str): path to the hostfile for distributed training. + shared_workdir (str): a shared workdir by all workers to cache and read data/params. + """ + + fp16: bool = False + max_shard_size: int = 10**7 + shard_scheme: str = "synchronized" + hostfile: str = "" + shared_workdir: str = "." + + @staticmethod + def get_label_hierarchy(Y, clustering): + """Get label hierarchy for multi-resolution training + + Args: + Y (csr_matrix): the label matrix with shape = (nr_inst, nr_labels) + clustering (ClusterChain): Hierarchical label tree with depth D + + Returns: + YC_list (list): list of length D. YC_list[d] is the label matrix at level d, shape=(nr_inst, K^(d)) + """ + + YC_list = [Y] + for cur_C in reversed(clustering[1:]): + Y_t = clib.sparse_matmul(YC_list[-1], cur_C, threads=min(32, os.cpu_count())).tocsr() + YC_list.append(Y_t) + YC_list.reverse() + return YC_list + + @staticmethod + def get_pretrained(model_shortcut, num_labels=None, hidden_dropout_prob=0.1): + """Get pretrained transformer model for fine-tuning + + Args: + model_shortcut (str, TransformerMatcher): + if TransformerMatcher, do nothing + if local path to serialized TransformerMatcher, load from disk + else if model identifier, download pre-trained model frim huggingface repo. + num_labels (int): the number of label embeddings the model is expected to have. + If existing num_labels is inconsistent with given value, new label embeddings will be randomly initialized. + hidden_dropout_prob (float): training dropout probabilities for the hidden vectors. + Only used when downloading the model from huggingface repo. + + Returns: + TransformerMatcher with num_labels label embeddings + + """ + if isinstance(model_shortcut, TransformerMatcher): + parent_model = model_shortcut + elif os.path.exists(model_shortcut): + # load from local + parent_model = TransformerMatcher.load(model_shortcut) + LOGGER.info("Loaded model from {}.".format(model_shortcut)) + else: + # download from huggingface repo + parent_model = TransformerMatcher.download_model( + model_shortcut, + num_labels, + hidden_dropout_prob=hidden_dropout_prob, + ) + LOGGER.info("Downloaded {} model from s3.".format(model_shortcut)) + + if num_labels is not None and num_labels != parent_model.nr_labels: + LOGGER.warning( + f"Got mismatch nr_labels (expected {num_labels} but got {parent_model.nr_labels}), text_model reinitialized!" + ) + parent_model.text_model = TransformerLinearXMCHead( + parent_model.text_encoder.config.hidden_size, num_labels + ) + parent_model.text_encoder.config.num_labels = num_labels + return parent_model + + @classmethod + def train( + cls, + prob, + clustering=None, + train_params=None, + pred_params=None, + dist_params=None, + **kwargs, + ): + """Train the XR-Transformer model with the given input data. + + Args: + prob (MLProblemWithText): ML problem to solve. + clustering (ClusterChain, optional): preliminary hierarchical label tree, + where transformer is fine-tuned on. 
+ train_params (XTransformer.TrainParams): training parameters for XTransformer + pred_params (XTransformer.PredParams): pred parameters for XTransformer + dist_params (XTransformerDistTrainer.DistParams): distributed parameters + for XTransformerDistTrainer + kwargs: + beam_size (int, optional): overrides only_topk for models except + bottom layer one + + Returns: + XTransformer + """ + + # construct train_params + if train_params is None: + # fill all BaseParams class with their default value + train_params = cls.TrainParams.from_dict(dict(), recursive=True) + else: + train_params = cls.TrainParams.from_dict(train_params) + # construct pred_params + if pred_params is None: + # fill all BaseParams with their default value + pred_params = cls.PredParams.from_dict(dict(), recursive=True) + else: + pred_params = cls.PredParams.from_dict(pred_params) + # construct dist_params + if dist_params is None: + # fill all BaseParams with their default value + dist_params = cls.DistParams.from_dict(dict(), recursive=True) + else: + dist_params = cls.DistParams.from_dict(dist_params) + + if not train_params.only_encoder: + LOGGER.warning( + f"Distributed fine-tuning is only for encoder, fall back to only_encoder=True" + ) + train_params.only_encoder = True + + if not train_params.do_fine_tune: + LOGGER.warning(f"do_fine_tune is set to False, override to do_fine_tune=True..") + train_params.do_fine_tune = True + + # 1. Constructing primary Hierarchial Label Tree + if clustering is None: + clustering = smat.csc_matrix(np.ones((prob.nr_labels, 1)), dtype=np.float32) + + clustering = ClusterChain(clustering) + if clustering[-1].shape[0] != prob.nr_labels: + raise ValueError("nr_labels mismatch!") + + prelim_hierarchiy = [cc.shape[0] for cc in clustering] + LOGGER.info("Hierarchical label tree: {}".format(prelim_hierarchiy)) + + # 1.1 Get the fine-tuning task numbers + nr_transformers = sum(i <= train_params.max_match_clusters for i in prelim_hierarchiy) + + LOGGER.info( + "Fine-tune Transformers with nr_labels={}".format( + [cc.shape[0] for cc in clustering[:nr_transformers]] + ) + ) + + # 1.2 construct fields with chain now we know the depth + train_params = HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain( + train_params, cls.TrainParams, nr_transformers + ) + + LOGGER.debug( + f"XTransformer train_params: {json.dumps(train_params.to_dict(), indent=True)}" + ) + + pred_params = HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain( + pred_params, cls.PredParams, nr_transformers + ) + pred_params = pred_params.override_with_kwargs(kwargs) + + LOGGER.debug(f"XTransformer pred_params: {json.dumps(pred_params.to_dict(), indent=True)}") + LOGGER.debug( + f"XTransformerDistTrainer dist_params: {json.dumps(dist_params.to_dict(), indent=True)}" + ) + # construct label chain for training and validation set + YC_list = cls.get_label_hierarchy(prob.Y, clustering) + + parent_model = train_params.matcher_params_chain[0].model_shortcut + if train_params.matcher_params_chain[0].init_model_dir: + parent_model = train_params.matcher_params_chain[0].init_model_dir + + prev_head = None + for i in range(nr_transformers): + cur_train_params = train_params.matcher_params_chain[i] + cur_pred_params = pred_params.matcher_params_chain[i] + + # only support TFN + M = YC_list[i - 1] if i > 0 else None + + cur_prob = MLProblemWithText( + prob.X_text, + YC_list[i], + X_feat=None, + C=clustering[i], + M=M, + ) + + avr_trn_labels = ( + float(cur_prob.M.nnz) / YC_list[i].shape[0] + if cur_prob.M is not None 
+ else YC_list[i].shape[1] + ) + LOGGER.info( + "Fine-tuning XR-Transformer with {} at level {}, nr_labels={}, avr_M_nnz={}".format( + "tfn", i, YC_list[i].shape[1], avr_trn_labels + ) + ) + + # construct TransformerMatcher instance for this layer + parent_model = cls.get_pretrained( + parent_model, + num_labels=cur_prob.Y.shape[1], + hidden_dropout_prob=train_params.matcher_params_chain[0].hidden_dropout_prob, + ) + parent_model.C = cur_prob.C + parent_model.train_params = cur_train_params + parent_model.pred_params = cur_pred_params + if cur_train_params.bootstrap_method == "inherit" and i > 0: + parent_model.text_model.inherit(prev_head, cur_prob.C, sparse=False) + LOGGER.info("Initialized transformer text_model from parent layer!") + elif cur_train_params.bootstrap_method == "no-bootstrap" or i == 0: + parent_model.text_model.random_init(sparse=False) + LOGGER.info("Randomly initialized transformer text_model!") + else: + raise ValueError( + f"bootstrap_method={cur_train_params.bootstrap_method} not supported in distributed training!" + ) + + if cur_train_params.pre_tokenize: + if not prob.is_tokenized: + prob.X_text = parent_model.text_to_tensor( + prob.X_text, + max_length=cur_pred_params.truncate_length, + ) + + # temp folder in workdir + temp_dir = tempfile.TemporaryDirectory(dir=dist_params.shared_workdir) + temp_data_dir = temp_dir.name + temp_data_path = f"{temp_data_dir}/train_data" + temp_params_path = f"{temp_data_dir}/param.json" + + if cur_train_params.checkpoint_dir: + temp_encoder_path = cur_train_params.checkpoint_dir + else: + temp_encoder_path = f"{temp_data_path}/encoder" + + # construct dataset and save into shards + train_data = TransformerMatcher.prepare_data( + cur_prob, + label_padding_idx=parent_model.text_model.label_padding_idx, + pre_tensorize_labels=cur_train_params.pre_tensorize_labels, + input_transform=None if cur_prob.is_tokenized else parent_model._tokenize, + ) + num_shards = ( + len(train_data) + dist_params.max_shard_size - 1 + ) // dist_params.max_shard_size + train_data.save(temp_data_path, num_shards=num_shards) + del train_data + gc.collect() + LOGGER.info(f"Cached train_data to {temp_data_path} with {num_shards}") + + temp_params = { + "train_params": cur_train_params.to_dict(), + "pred_params": cur_pred_params.to_dict(), + "dist_params": dist_params.to_dict(), + } + with open(temp_params_path, "w") as fout: + fout.write(json.dumps(temp_params)) + LOGGER.info(f"Cached params: {json.dumps(temp_params)}") + + parent_model.save(temp_encoder_path) + + # start distributed training + ds_utils.cli_launcher( + dist_trainer.__name__, + hostfile=dist_params.hostfile, + module_args={ + "data_path": temp_data_path, + "model_path": temp_encoder_path, + "output_path": temp_encoder_path, + "params_path": temp_params_path, + "fp16": 1 if dist_params.fp16 else 0, + "shard_scheme": dist_params.shard_scheme, + }, + ) + + LOGGER.info("Reload the best checkpoint from {}".format(temp_encoder_path)) + parent_model = TransformerMatcher.load(temp_encoder_path) + parent_model.clear_cuda() + prev_head = parent_model.text_model + + return cls(parent_model, None) diff --git a/pecos/distributed/xmc/xtransformer/module.py b/pecos/distributed/xmc/xtransformer/module.py new file mode 100644 index 0000000..59c3994 --- /dev/null +++ b/pecos/distributed/xmc/xtransformer/module.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 -u +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. + +import os +import torch +import logging + +from pecos.xmc.xtransformer.matcher import TransformerMatcher + +LOGGER = logging.getLogger(__name__) + + +class AllInOneForXMCModel(torch.nn.Module): + """Wrapper class to pack transformer encoder and label embeddings + + Args: + encoder (BertForXMC, RobertaForXMC, XLMRobertaForXMC, XLNetForXMC, DistilBertForXMC) + label_embedding (TransformerLinearXMCHead) + + """ + + def __init__(self, encoder, label_embedding): + super().__init__() + self.encoder = encoder + self.label_embedding = label_embedding + + @classmethod + def load(cls, load_dir): + matcher = TransformerMatcher.load(load_dir) + return cls(matcher.text_encoder, matcher.text_model) + + def save(self, save_dir): + encoder_dir = os.path.join(save_dir, "text_encoder") + os.makedirs(encoder_dir, exist_ok=True) + # this creates config.json, pytorch_model.bin + self.encoder.save_pretrained(encoder_dir) + text_model_dir = os.path.join(save_dir, "text_model") + torch.save(self.label_embedding, text_model_dir) + + @property + def nr_labels(self): + return self.label_embedding.num_labels + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + instance_number=None, + label_values=None, + label_indices=None, + ): + pooled_output = self.encoder( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + )["pooled_output"] + + W_act, b_act = self.label_embedding( + output_indices=label_indices, + num_device=len(self.device_ids) if hasattr(self, "device_ids") else 1, + ) + logits = (pooled_output.unsqueeze(1) * W_act).sum(dim=-1) + b_act.squeeze(2) + return logits + + def prepare_params(self, weight_decay=0): + no_decay = ["bias", "LayerNorm.weight"] + all_params = [ + { + "params": [ + p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + + return all_params + + +class DeepSpeedUtils(object): + """Utility functions to use DeepSpeed""" + + @staticmethod + def cli_launcher(module_name, hostfile=None, module_args={}): + """Deepspeed launcher method + + Args: + module_name (str): the python module to launch + hostfile (str, optional): hostfile to launch distributed job. + Default None to use local multi-gpu distribution. + module_args (dict, optional): arguments and corresponding values to pass to the module. 
+ + Example: + module_args = {'arg_name_1': v1, 'arg_name_2': v2} + Actual command + deepspeed --hostfile [hostfile] --module [module_name] \ + --arg-name-1 v1 --arg-name-2 v2 + """ + + worker_cmd = [f"deepspeed"] + if hostfile: + worker_cmd += [f"--hostfile {hostfile}"] + else: + worker_cmd += [f"--num_gpus {torch.cuda.device_count()}"] + + worker_cmd += [f"--module {module_name}"] + + for k, v in module_args.items(): + worker_cmd += [f"--{k.replace('_', '-')} {v}"] + worker_cmd = [" ".join(worker_cmd)] + LOGGER.info(f"Actual command: {worker_cmd}") + + import subprocess + + subprocess.check_call( + worker_cmd, + shell=True, + encoding="utf-8", + universal_newlines=True, + bufsize=1, + ) + + @staticmethod + def get_config_from_params(train_params=None): + """Construct DeepSpeed config from TransformerMatcher.TrainParams""" + + if train_params is None: + train_params = TransformerMatcher.TrainParams() + + ds_config = { + "fp16": {"enabled": train_params.fp16 if hasattr(train_params, "fp16") else False}, + "bf16": {"enabled": False}, + "zero_optimization": { + "stage": 1, + "overlap_comm": True, + "contiguous_gradients": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": train_params.learning_rate, + "betas": [0.9, 0.999], + "eps": train_params.adam_epsilon, + "weight_decay": train_params.weight_decay, + }, + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": train_params.max_steps, + "warmup_min_lr": 0, + "warmup_max_lr": train_params.learning_rate, + "warmup_num_steps": train_params.warmup_steps, + "warmup_type": "linear", + }, + }, + "steps_per_print": train_params.logging_steps, + "sparse_gradients": True, + "gradient_clipping": train_params.max_grad_norm, + "train_micro_batch_size_per_gpu": train_params.batch_size, + "gradient_accumulation_steps": train_params.gradient_accumulation_steps, + "wall_clock_breakdown": False, + "dump_state": False, + } + return ds_config diff --git a/pecos/distributed/xmc/xtransformer/train.py b/pecos/distributed/xmc/xtransformer/train.py new file mode 100644 index 0000000..177496e --- /dev/null +++ b/pecos/distributed/xmc/xtransformer/train.py @@ -0,0 +1,189 @@ +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance +# with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions +# and limitations under the License. 
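To illustrate the command line that DeepSpeedUtils.cli_launcher assembles, a small stand-alone sketch that mirrors its string building (the module_args values and the hostfile path are hypothetical):

# Mirrors cli_launcher's command assembly; prints the command instead of running it.
module_name = "pecos.distributed.xmc.xtransformer.dist_trainer"
module_args = {"data_path": "/shared/tmp/train_data", "fp16": 1}   # hypothetical values
cmd = ["deepspeed", "--hostfile /shared/hostfile", f"--module {module_name}"]
cmd += [f"--{k.replace('_', '-')} {v}" for k, v in module_args.items()]
print(" ".join(cmd))
# deepspeed --hostfile /shared/hostfile --module pecos.distributed.xmc.xtransformer.dist_trainer --data-path /shared/tmp/train_data --fp16 1
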
+import json +import logging +import os + +import numpy as np +from pecos.utils import cli, logging_util, smat_util, torch_util +from pecos.utils.cluster_util import ClusterChain + +from pecos.utils.featurization.text.preprocess import Preprocessor +from pecos.xmc.xtransformer.model import XTransformer +from pecos.xmc.xtransformer.module import MLProblemWithText +from pecos.xmc.xtransformer.train import parse_arguments +from pecos.distributed.diagnostic_tools import deepspeed_comm as ds_diagnose +from .model import XTransformerDistTrainer +from .module import DeepSpeedUtils as ds_utils + +LOGGER = logging.getLogger(__name__) + + +def add_dist_arguments(parser): + """Add distributed training arguments""" + # ========= train data paths ============ + parser.add_argument( + "--hostfile", + type=str, + metavar="PATH", + default="", + help="path to the hostfile", + ) + parser.add_argument( + "--fp16", + type=cli.str2bool, + metavar="[true/false]", + default=True, + help="If true, do half-precision training", + ) + + parser.add_argument( + "--shard-scheme", + type=str, + choices=["synchronized", "ordered"], + metavar="STR", + default="synchronized", + help="access scheme for training data shards", + ) + + parser.add_argument( + "--shared-workdir", + type=str, + metavar="PATH", + default=".", + help="the shared workdir for distributed training", + ) + parser.add_argument( + "--max-shard-size", + type=int, + default=10**7, + metavar="INT", + help="max number of instances in each data shard", + ) + return parser + + +def do_train(args): + """Train and save XR-Transformer model. + + Args: + args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` + """ + params = dict() + if args.generate_params_skeleton: + params["train_params"] = XTransformer.TrainParams.from_dict({}, recursive=True).to_dict() + params["pred_params"] = XTransformer.PredParams.from_dict({}, recursive=True).to_dict() + params["dist_params"] = XTransformerDistTrainer.DistParams.from_dict( + {}, recursive=True + ).to_dict() + print(f"{json.dumps(params, indent=True)}") + return + + if args.params_path: + with open(args.params_path, "r") as fin: + params = json.load(fin) + + train_params = params.get("train_params", None) + pred_params = params.get("pred_params", None) + dist_params = params.get("dist_params", None) + + if train_params is not None: + train_params = XTransformer.TrainParams.from_dict(train_params) + else: + train_params = XTransformer.TrainParams.from_dict( + {k: v for k, v in vars(args).items() if v is not None}, + recursive=True, + ) + + if pred_params is not None: + pred_params = XTransformer.PredParams.from_dict(pred_params) + else: + pred_params = XTransformer.PredParams.from_dict( + {k: v for k, v in vars(args).items() if v is not None}, + recursive=True, + ) + + if dist_params is not None: + dist_params = XTransformerDistTrainer.DistParams.from_dict(dist_params) + else: + dist_params = XTransformerDistTrainer.DistParams.from_dict( + {k: v for k, v in vars(args).items() if v is not None}, + recursive=True, + ) + + # health check for distributed cluster and shared dir + ds_utils.cli_launcher( + ds_diagnose.__name__, + hostfile=dist_params.hostfile, + module_args={ + "timeout": 60, + "shared_workdir": dist_params.shared_workdir, + "verbose_level": 2, + }, + ) + + torch_util.set_seed(args.seed) + LOGGER.info("Setting random seed {}".format(args.seed)) + + # Load training feature + if args.trn_feat_path: + LOGGER.warning(f"Numerical features are ignored in current distributed implementation!") 
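To make the --shard-scheme option concrete, here is a small sketch of the "ordered" shard assignment computed in dist_trainer.dist_fine_tune (the shard and worker counts are illustrative):

# Shards each worker visits per epoch under the "ordered" scheme.
num_shards, world_size = 10, 4
shard_load_per_epoch = (num_shards + world_size - 1) // world_size   # 3
for global_rank in range(world_size):
    order = [(global_rank + kk * world_size) % num_shards for kk in range(shard_load_per_epoch)]
    print(global_rank, order)
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6, 0], rank 3 -> [3, 7, 1]
# Under "synchronized", every rank loads the same shard and splits it via DistributedSampler.
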
+ + # Load training labels + Y_trn = smat_util.load_matrix(args.trn_label_path, dtype=np.float32) + LOGGER.info("Loaded training label matrix with shape={}".format(Y_trn.shape)) + + # Load training texts + trn_corpus = Preprocessor.load_data_from_file( + args.trn_text_path, + label_text_path=None, + text_pos=0, + )["corpus"] + LOGGER.info("Loaded {} training sequences".format(len(trn_corpus))) + + # Load test feature if given + if args.tst_feat_path or args.tst_label_path or args.tst_text_path: + LOGGER.warning(f"Validation set is ignored in distributed training") + + # load cluster chain + if os.path.exists(args.code_path): + cluster_chain = ClusterChain.from_partial_chain( + smat_util.load_matrix(args.code_path), + min_codes=train_params.preliminary_indexer_params.min_codes, + nr_splits=train_params.preliminary_indexer_params.nr_splits, + ) + LOGGER.info("Loaded from code-path: {}".format(args.code_path)) + else: + cluster_chain = None + LOGGER.warning( + "Label partition not provided, falling back to one-versue-all training. \ + For multi-resolution training, provide label partition with --code-path" + ) + + trn_prob = MLProblemWithText(trn_corpus, Y_trn) + + xtf = XTransformerDistTrainer.train( + trn_prob, + clustering=cluster_chain, + train_params=train_params, + pred_params=pred_params, + dist_params=dist_params, + beam_size=args.beam_size, + ) + + xtf.save(args.model_dir) + + +if __name__ == "__main__": + parser = add_dist_arguments(parse_arguments()) + args = parser.parse_args() + logging_util.setup_logging_config(level=args.verbose_level) + do_train(args) diff --git a/pecos/utils/torch_util.py b/pecos/utils/torch_util.py index da68b1a..0398228 100644 --- a/pecos/utils/torch_util.py +++ b/pecos/utils/torch_util.py @@ -16,27 +16,34 @@ LOGGER = logging.getLogger(__name__) -def setup_device(use_gpu_if_available=True): +def setup_device(use_gpu_if_available=True, device_id=-1): """Setup device for pytorch. Args: use_gpu_if_available (bool, optional): whether to use GPU if available. Default True + device_id (int, optional): GPU id to use. 
Default -1 to use all Returns: device (torch.device): torch device n_active_gpu (int): number of GPUs available for torch.cuda """ - if use_gpu_if_available: # use all that available - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - n_active_gpu = torch.cuda.device_count() - if not torch.cuda.is_available(): - LOGGER.warning("CUDA is not available, will fall back to CPU.") + if use_gpu_if_available and torch.cuda.is_available(): + if device_id >= 0: + # use specified device + device = torch.device("cuda", device_id) + n_active_gpu = 1 + else: + # regular dataparallel + device = torch.device("cuda") + n_active_gpu = torch.cuda.device_count() else: - device = torch.device("cpu") + if use_gpu_if_available: + LOGGER.warning("CUDA is not available, will fall back to CPU.") if torch.cuda.is_available(): LOGGER.warning("CUDA is available but disabled, will only use CPU.") + device = torch.device("cpu") n_active_gpu = 0 - LOGGER.info("Setting device to {}, number of active GPUs: {}".format(device, n_active_gpu)) + LOGGER.info(f"Setting device to {device}, number of active GPUs: {n_active_gpu}") return device, n_active_gpu diff --git a/pecos/xmc/xtransformer/matcher.py b/pecos/xmc/xtransformer/matcher.py index b5827c8..00d5ae7 100644 --- a/pecos/xmc/xtransformer/matcher.py +++ b/pecos/xmc/xtransformer/matcher.py @@ -28,7 +28,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers import AdamW, AutoConfig, get_scheduler, BatchEncoding -from .module import XMCTensorDataset, XMCTextDataset +from .module import XMCLabelTensorizer, XMCTextTensorizer, XMCTextDataset from .network import ENCODER_CLASSES, HingeLoss, TransformerLinearXMCHead logging.getLogger(transformers.__name__).setLevel(logging.WARNING) @@ -131,7 +131,8 @@ class TrainParams(pecos.BaseParams): # type: ignore save_steps: int = 100 cost_sensitive_ranker: bool = False - pre_tokenize: bool = False + pre_tokenize: bool = True + pre_tensorize_labels: bool = True use_gpu: bool = True eval_by_true_shorlist: bool = False @@ -188,7 +189,7 @@ def __init__( self, text_encoder, text_tokenizer, - text_model, + text_model=None, C=None, concat_model=None, train_params=None, @@ -214,7 +215,16 @@ def __init__( self.text_tokenizer = text_tokenizer self.C = C - self.text_model = text_model + if text_model is None: + self.text_model = TransformerLinearXMCHead( + text_encoder.config.hidden_size, text_encoder.config.num_labels + ) + LOGGER.warning( + f"XMC text_model of {text_encoder.__class__.__name__} not initialized from pre-trained model." 
+ ) + else: + self.text_model = text_model + self.concat_model = concat_model self.train_params = self.TrainParams.from_dict(train_params) @@ -246,6 +256,8 @@ def clear_cuda(self): """Clear CUDA memory""" if hasattr(self.text_encoder, "module"): self.text_encoder = self.text_encoder.module + if hasattr(self.text_model, "module"): + self.text_model = self.text_model.module self.text_encoder.to(torch.device("cpu")) self.text_model.to(torch.device("cpu")) torch.cuda.empty_cache() @@ -288,7 +300,10 @@ def nr_features(self): @property def nr_labels(self): """Get the number of labels""" - return self.text_model.num_labels + if hasattr(self.text_model, "module"): + return self.text_model.module.num_labels + else: + return self.text_model.num_labels @property def model_type(self): @@ -323,8 +338,6 @@ def save(self, save_dir): with open(os.path.join(save_dir, "param.json"), "w", encoding="utf-8") as f: f.write(json.dumps(param, indent=True)) - smat_util.save_matrix(os.path.join(save_dir, "C.npz"), self.C) - encoder_dir = os.path.join(save_dir, "text_encoder") os.makedirs(encoder_dir, exist_ok=True) # this creates config.json, pytorch_model.bin @@ -334,12 +347,19 @@ def save(self, save_dir): os.makedirs(tokenizer_dir, exist_ok=True) self.text_tokenizer.save_pretrained(tokenizer_dir) + # this creates C.npz + if self.C is not None: + smat_util.save_matrix(os.path.join(save_dir, "C.npz"), self.C) # this creates text_model text_model_dir = os.path.join(save_dir, "text_model") - torch.save(self.text_model, text_model_dir) + if self.text_model is not None: + head_to_save = ( + self.text_model.module if hasattr(self.text_model, "module") else self.text_model + ) + torch.save(head_to_save, text_model_dir) # save the concat_model concat_model_dir = os.path.join(save_dir, "concat_model") - if self.concat_model: + if self.concat_model is not None: self.concat_model.save(concat_model_dir) @classmethod @@ -390,18 +410,14 @@ def load(cls, load_dir): if os.path.exists(text_model_dir): text_model = torch.load(text_model_dir) else: - text_model = TransformerLinearXMCHead( - encoder_config.hidden_size, encoder_config.num_labels - ) - LOGGER.warning( - f"XMC text_model of {text_encoder.__class__.__name__} not initialized from pre-trained model." 
- ) + text_model = None # load C C_path = os.path.join(load_dir, "C.npz") if not os.path.exists(C_path): - raise ValueError(f"Cluster code does not exist at {C_path}") - C = smat_util.load_matrix(C_path) + C = smat.csr_matrix(np.ones((encoder_config.num_labels, 1), dtype=np.float32)) + else: + C = smat_util.load_matrix(C_path) # load concat_model concat_model_dir = os.path.join(load_dir, "concat_model") @@ -412,7 +428,7 @@ def load(cls, load_dir): return cls( text_encoder, text_tokenizer, - text_model, + text_model=text_model, C=C, concat_model=concat_model, train_params=train_params, @@ -478,6 +494,12 @@ def _get_tokenizer_config(**kwargs): } return {**convert_kwargs, **kwargs} + def _tokenize(self, text): + return self.text_tokenizer( + text=text, + **self._get_tokenizer_config(max_length=self.pred_params.truncate_length), + ) + def text_to_tensor(self, corpus, max_length=None): """Convert input text corpus into padded tensors @@ -509,95 +531,6 @@ def text_to_tensor(self, corpus, max_length=None): LOGGER.info("***** Finished with time cost={} *****".format(time.time() - t_start)) return feature_tensors - @staticmethod - def _get_label_tensors(M, Y, idx_padding=-1, max_labels=None): - """ - Given matching matrix M and label matrix Y, construct label tensors for XMC training - The non-zero indices of Y are seen as positive labels and therefore all - included in the result. - - Example: - M = smat.csr_matrix([[1, 1, 0, 0], - [0, 0, 1, 1]]) - Y = smat.csr_matrix([[0, 1, 0, 2], - [0, 0, 0, 3]]) - then the returned values will be: - label_indices = torch.IntTensor([[1, 3, 0], [3, 2, -1]]) - label_values = torch.FloatTensor([[1., 2., 0.], [3., 0., 0.]]) - - Args: - M (csr_matrix or None): matching matrix, shape = (nr_inst, nr_labels) - It's indices are the candidate label indices to consider - It's values will not be used - Y (csr_matrix or None): label matrix, shape = (nr_inst, nr_labels) - It's non-zero indices are positive labels and will always be - included. - idx_padding (int, optional): the index used to pad all label_indices - to the same length. Default -1 - max_labels (int, optional): max number of labels considered for each - instance, will subsample from existing label indices if need to. - Default None to use max row nnz of M. - - Returns: - label_indices (torch.IntTensor or None): containing label indices with - shape = (nr_inst, max_labels). Return None if M is None - label_values (torch.FloatTensor or None): containing label values - with shape = (nr_inst, max_labels). If Y is None, return None - """ - if M is None and Y is None: - return None, None - elif M is None and Y is not None: - # if M is None, taking all labels into account - return None, torch.FloatTensor(Y.toarray()) - - if Y is not None: - if Y.shape != M.shape: - raise ValueError("Y and M shape mismatch: {} and {}".format(Y.shape, M.shape)) - label_lower_bound = max(Y.indptr[1:] - Y.indptr[:-1]) - # make sure all positive labels are included - M1 = smat_util.binarized(M) + smat_util.binarized(Y) - else: - M1 = M - label_lower_bound = 0 - - label_upper_bound = max(M1.indptr[1:] - M1.indptr[:-1]) - if max_labels is None: - max_labels = label_upper_bound - else: - max_labels = min(max_labels, label_upper_bound) - if max_labels < label_lower_bound: - max_labels = label_lower_bound - LOGGER.warning( - f"Increasing max_labels to {label_lower_bound} to accommodate all positive labels." 
- ) - - nr_inst = M1.shape[0] - label_indices = np.zeros((nr_inst, max_labels), dtype=np.int64) + idx_padding - if Y is not None: - label_values = np.zeros((nr_inst, max_labels), dtype=np.float32) - - for i in range(nr_inst): - offset = 0 - neg_samples = M1.indices[M1.indptr[i] : M1.indptr[i + 1]] - # fill with positive samples first - if Y is not None: - y_nnz = Y.indptr[i + 1] - Y.indptr[i] - rng = slice(Y.indptr[i], Y.indptr[i + 1]) - label_indices[i, :y_nnz] = Y.indices[rng] - label_values[i, :y_nnz] = Y.data[rng] - offset += y_nnz - neg_samples = neg_samples[np.invert(np.isin(neg_samples, Y.indices[rng]))] - # fill the rest slots with negative samples - if neg_samples.size > max_labels - offset: - # random sample negative labels - neg_samples = np.random.choice(neg_samples, max_labels - offset) - - label_indices[i, offset : offset + neg_samples.size] = neg_samples - - label_indices = torch.IntTensor(label_indices) - - return label_indices, None if Y is None else torch.FloatTensor(label_values) - @staticmethod def ensemble_prediction(transformer_pred_csr, concat_pred_csr, only_topk, ens_method): """Generate micro ensemble of concat predictions and transformer predictions @@ -779,6 +712,7 @@ def _predict( """ batch_gen_workers = kwargs.get("batch_gen_workers", 4) only_embeddings = kwargs.get("only_embeddings", False) + label_padding_idx = self.text_model.label_padding_idx if csr_codes is not None: # need to keep explicit zeros in csr_codes_next @@ -800,20 +734,23 @@ def _predict( ) else: csr_codes_next = None - LOGGER.info("Predict on input text tensors({})".format(X_text["input_ids"].shape)) + LOGGER.info( + "Predict on input text tensors({}) in OVA mode".format(X_text["input_ids"].shape) + ) - label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( - csr_codes_next, None, idx_padding=self.text_model.label_pad + input_tensorizer = XMCTextTensorizer( + X_text, + feature_keys=["input_ids", "attention_mask", "token_type_ids", "instance_number"], ) - data = XMCTensorDataset( - X_text["input_ids"], - X_text["attention_mask"], - X_text["token_type_ids"], - X_text["instance_number"], - label_values=label_values_pt, - label_indices=label_indices_pt, + lbl_tensorizer = XMCLabelTensorizer( + Y=None, + M=csr_codes_next, + label_padding_idx=label_padding_idx, + pre_compute=True, ) + data = XMCTextDataset(input_tensorizer, lbl_tensorizer) + # since number of active labels may vary # using pinned memory will slow down data loading dataloader = DataLoader( @@ -881,7 +818,7 @@ def _predict( val = c_pred.cpu().numpy().flatten() val = val[ np.argwhere( - inputs["label_indices"].cpu().flatten() != self.text_model.label_pad + inputs["label_indices"].cpu().flatten() != label_padding_idx ) ].flatten() val = PostProcessor.get(pred_params.post_processor).transform( @@ -950,6 +887,56 @@ def concat_features(X_feat, X_emb, normalize_emb=True): raise TypeError(f"Expected CSR or ndarray, got {type(X_feat)}") return X_cat + @staticmethod + def prepare_data( + prob, label_padding_idx=None, pre_tensorize_labels=False, input_transform=None + ): + """Prepare data for transformer encoder from given problem + + Args: + prob (MLProblemWithText): problem to build data for + label_padding_idx (int, optional): padding index in label embeddings. + Default None to use the number of columns of prob.Y + pre_tensorize_labels (bool, optional): use pre-tensorize or realtime + tensorization of label tensors. 
Default False + input_transform (function, optional): transformation function for + input text tokenization. + + Returns: + XMCTextDataset + """ + if label_padding_idx is None: + label_padding_idx = prob.Y.shape[1] + + if prob.M is not None: + # need to keep explicit zeros in csr_codes_next + # therefore do not pass it through constructor + if not isinstance(prob.M, smat.csr_matrix): + raise TypeError(f"Got type={type(prob.M)} for M!") + # getting the result in csr by computing csr * csr + M_next = clib.sparse_matmul( + prob.M, + prob.C.T, + eliminate_zeros=False, + ) + else: + M_next = None + + input_tensorizer = XMCTextTensorizer( + prob.X_text, + feature_keys=["input_ids", "attention_mask", "token_type_ids", "instance_number"], + input_transform=input_transform, + ) + lbl_tensorizer = XMCLabelTensorizer( + Y=prob.Y, + M=M_next, + label_padding_idx=label_padding_idx, + pre_compute=pre_tensorize_labels, + ) + + train_data = XMCTextDataset(input_tensorizer, lbl_tensorizer) + return train_data + def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): """Fine tune the transformer text_encoder @@ -971,65 +958,20 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): self.device ) - max_act_labels = train_params.max_active_matching_labels logging_steps = train_params.logging_steps max_steps = train_params.max_steps max_no_improve_cnt = train_params.max_no_improve_cnt - if prob.M is not None: - # need to keep explicit zeros in csr_codes_next - # therefore do not pass it through constructor - if not isinstance(prob.M, smat.csr_matrix): - raise TypeError(f"Got type={type(prob.M)} for M!") - # getting the result in csr by computing csr * csr - M_next = clib.sparse_matmul( - prob.M, - self.C.T, - eliminate_zeros=False, - threads=train_params.batch_gen_workers, - ) - - do_resample = max_act_labels is not None and max_act_labels < max( - M_next.indptr[1:] - M_next.indptr[:-1] - ) - else: - M_next = None - do_resample = False if prob.M is None or train_params.max_num_labels_in_gpu >= self.nr_labels: # put text_model to GPU self.text_model.to(self.device) - if prob.is_tokenized: - LOGGER.info("Using XMCTensorDataset for tokenized inputs!") - label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( - M_next, - prob.Y, - idx_padding=self.text_model.label_pad, - max_labels=max_act_labels, - ) - train_data = XMCTensorDataset( - prob.X_text["input_ids"], - prob.X_text["attention_mask"], - prob.X_text["token_type_ids"], - prob.X_text["instance_number"], - label_values=label_values_pt, - label_indices=label_indices_pt, - ) - else: - LOGGER.info("Using XMCTextDataset for text inputs!") - os.environ["TOKENIZERS_PARALLELISM"] = "false" - train_data = XMCTextDataset( - prob.X_text, - lambda x: self.text_tokenizer( - text=x, - **self._get_tokenizer_config(max_length=pred_params.truncate_length), - ), - feature_keys=["input_ids", "attention_mask", "token_type_ids", "instance_number"], - Y=prob.Y, - M=M_next, - idx_padding=self.text_model.label_pad, - max_labels=max_act_labels, - ) + train_data = self.prepare_data( + prob, + label_padding_idx=self.text_model.label_padding_idx, + pre_tensorize_labels=train_params.pre_tensorize_labels, + input_transform=None if prob.is_tokenized else self._tokenize, + ) # since number of active labels may vary # using pinned memory will slow down data loading @@ -1087,7 +1029,7 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): ) sparse_parameters = list(self.text_model.parameters()) - if prob.M is not None: + if 
self.text_model.is_sparse: emb_optimizer = torch.optim.SparseAdam( sparse_parameters, lr=train_params.learning_rate, @@ -1109,9 +1051,9 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): # Start Batch Training LOGGER.info("***** Running training *****") - LOGGER.info(" Num examples = %d", prob.nr_inst) + LOGGER.info(" Num examples = %d", len(train_data)) LOGGER.info(" Num labels = %d", self.nr_labels) - if prob.M is not None: + if train_data.has_ns: LOGGER.info(" Num active labels per instance = %d", train_data.num_active_labels) LOGGER.info(" Num Epochs = %d", train_params.num_train_epochs) LOGGER.info(" Learning Rate Schedule = %s", train_params.lr_schedule) @@ -1130,19 +1072,6 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): self.text_encoder.zero_grad() self.text_model.zero_grad() for epoch in range(1, int(train_params.num_train_epochs) + 1): - if ( - isinstance(train_data, XMCTensorDataset) and do_resample and epoch > 1 - ): # redo subsample negative labels - label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( - M_next, - prob.Y, - idx_padding=self.text_model.label_pad, - max_labels=train_params.max_active_matching_labels, - ) - train_data.refresh_labels( - label_values=label_values_pt, - label_indices=label_indices_pt, - ) for batch_cnt, batch in enumerate(train_dataloader): self.text_encoder.train() self.text_model.train() @@ -1155,7 +1084,7 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): "token_type_ids": batch[2], "instance_number": batch[3], "label_values": batch[4], - "label_indices": batch[-1] if prob.M is not None else None, + "label_indices": batch[-1] if train_data.has_ns else None, } text_model_W_seq, text_model_b_seq = self.text_model( output_indices=inputs["label_indices"], @@ -1329,7 +1258,7 @@ def train( kwargs: saved_trn_pt (str): path to the tokenized trn text. If given, will skip tokenization saved_val_pt (str): path to the tokenized val text. If given, will skip tokenization - bootstrapping (tuple): (init_encoder, init_embeddings) the + bootstrapping (tuple): (init_encoder, init_embeddings, prev_head) the text_encoder and corresponding instance embeddings generated by it. Used for bootstrap current text_encoder and text_model. 
Default None to ignore @@ -1431,13 +1360,13 @@ def train( M=prob.M, R=prob.Y if "weighted" in train_params.bootstrap_method else None, ) - matcher.text_model.bootstrap(bootstrap_prob) + matcher.text_model.bootstrap(bootstrap_prob, sparse=True) LOGGER.info("Initialized transformer text_model with xlinear!") elif train_params.bootstrap_method == "inherit": - matcher.text_model.inherit(prev_head, prob.C) + matcher.text_model.inherit(prev_head, prob.C, sparse=True) LOGGER.info("Initialized transformer text_model form parent layer!") elif train_params.bootstrap_method == "no-bootstrap": - matcher.text_model.random_init() + matcher.text_model.random_init(sparse=True) LOGGER.info("Randomly initialized transformer text_model!") else: raise ValueError(f"Unknown bootstrap_method: {train_params.bootstrap_method}") @@ -1548,9 +1477,9 @@ def train( if return_dict: return { "matcher": matcher, - "trn_pred": P_trn if return_pred_on_trn else None, + "trn_pred": P_trn, "val_pred": P_val, - "trn_embeddings": inst_embeddings if return_embed_on_trn else None, + "trn_embeddings": inst_embeddings, "val_embeddings": val_inst_embeddings, } else: diff --git a/pecos/xmc/xtransformer/model.py b/pecos/xmc/xtransformer/model.py index 36102b8..71814f3 100644 --- a/pecos/xmc/xtransformer/model.py +++ b/pecos/xmc/xtransformer/model.py @@ -101,7 +101,10 @@ def override_with_kwargs(self, pred_kwargs): overridden_beam_size = pred_kwargs.get("beam_size", None) overridden_only_topk = pred_kwargs.get("only_topk", None) overridden_post_processor = pred_kwargs.get("post_processor", None) - depth = len(self.matcher_params_chain) + if isinstance(self.matcher_params_chain, list): + depth = len(self.matcher_params_chain) + else: + depth = 1 for d in range(depth): if overridden_beam_size: if d == depth - 1: @@ -178,11 +181,12 @@ def save(self, save_dir): LOGGER.info("Model saved to {}".format(save_dir)) @classmethod - def load(cls, load_dir): + def load(cls, load_dir, **xrl_kwargs): """Load X-Transformer model from file Args: load_dir (str): dir to load the model + xrl_kwargs: kwargs to pass to XLinearModel.load to load concat_model Returns: XTransformer @@ -191,7 +195,10 @@ def load(cls, load_dir): raise ValueError(f"load dir does not exist at: {load_dir}") text_encoder = TransformerMatcher.load(os.path.join(load_dir, "text_encoder")) try: - concat_model = XLinearModel.load(os.path.join(load_dir, "concat_model")) + concat_model = XLinearModel.load( + os.path.join(load_dir, "concat_model"), + **xrl_kwargs, + ) LOGGER.info("Full model loaded from {}".format(load_dir)) except FileNotFoundError: concat_model = None @@ -263,19 +270,19 @@ def train( if not train_params.do_fine_tune: if isinstance(train_params.matcher_params_chain, list): - matcher_train_params = train_params.matcher_params_chain[-1] + matcher_train_params = train_params.matcher_params_chain[0] else: matcher_train_params = train_params.matcher_params_chain if isinstance(train_params.matcher_params_chain, list): - matcher_pred_params = pred_params.matcher_params_chain[-1] + matcher_pred_params = pred_params.matcher_params_chain[0] else: matcher_pred_params = pred_params.matcher_params_chain device, n_gpu = torch_util.setup_device(matcher_train_params.use_gpu) if matcher_train_params.init_model_dir: - parent_model = cls.load(train_params.init_model_dir) + parent_model = TransformerMatcher.load(matcher_train_params.init_model_dir) LOGGER.info("Loaded encoder from {}.".format(matcher_train_params.init_model_dir)) else: parent_model = TransformerMatcher.download_model( @@ 
-536,6 +543,7 @@ def predict( use_gpu (bool, optional): use GPU if available. Default True max_pred_chunk (int, optional): max number of instances to predict at once. Set to None to ignore. Default 10^7 + device_id (int, optional): GPU id to use. Default -1 to use all threads (int, optional): the number of threads to use for linear model prediction. Returns: @@ -546,9 +554,11 @@ def predict( batch_size = kwargs.get("batch_size", 8) batch_gen_workers = kwargs.get("batch_gen_workers", 4) - use_gpu = kwargs.get("use_gpu", True) max_pred_chunk = kwargs.get("max_pred_chunk", 10**7) - device, n_gpu = torch_util.setup_device(use_gpu) + device, n_gpu = torch_util.setup_device( + use_gpu_if_available=kwargs.get("use_gpu", True), + device_id=kwargs.get("device_id", -1), + ) # get the override pred_params if pred_params is None: @@ -612,15 +622,18 @@ def encode( use_gpu (bool, optional): use GPU if available. Default True max_pred_chunk (int, optional): max number of instances to predict at once. Set to None to ignore. Default 10^7 + device_id (int, optional): GPU id to use. Default -1 to use all Returns: embeddings (ndarray): instance embedding on training data, shape = (nr_inst, hidden_dim). """ batch_size = kwargs.get("batch_size", 8) batch_gen_workers = kwargs.get("batch_gen_workers", 4) - use_gpu = kwargs.get("use_gpu", True) max_pred_chunk = kwargs.get("max_pred_chunk", 10**7) - device, n_gpu = torch_util.setup_device(use_gpu) + device, n_gpu = torch_util.setup_device( + use_gpu_if_available=kwargs.get("use_gpu", True), + device_id=kwargs.get("device_id", -1), + ) # get the override pred_params if pred_params is None: diff --git a/pecos/xmc/xtransformer/module.py b/pecos/xmc/xtransformer/module.py index ab35ff1..35cc2f2 100644 --- a/pecos/xmc/xtransformer/module.py +++ b/pecos/xmc/xtransformer/module.py @@ -9,11 +9,13 @@ # OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions # and limitations under the License. import logging +import os +import json import numpy as np import torch import scipy.sparse as smat from pecos.utils import smat_util -from torch.utils.data import Dataset, TensorDataset +from torch.utils.data import Dataset from transformers import BatchEncoding LOGGER = logging.getLogger(__name__) @@ -74,117 +76,165 @@ def nr_inst(self): return self.Y.shape[0] -class XMCTensorDataset(Dataset): - """Dataset to hold feature and label tensors for XMC training and prediction. +class XMCTextTensorizer(object): - Args: - *features (tensors): feature tensors, required to have same first - dimension: nr_inst - label_values (tensor or None): label values with shape = (nr_inst, num_active_labels) - label_indices (tensor or None): label indices with shape = (nr_inst, num_active_labels) - - Return values depend on the label_values and label_indices: - if label_values is None and label_indices is not None: - data[i] = (feature[0][i], feature[1][i], ..., label_values[i], label_indices[i]) - elif label_values is not None: - data[i] = (feature[0][i], feature[1][i], ..., label_values[i]) - elif label_indices is not None: - data[i] = (feature[0][i], feature[1][i], ..., label_indices[i]) - else: - data[i] = (feature[0][i], feature[1][i], ...) 
- """ + DEFAULT_FEATURE_KEYS = ["input_ids", "attention_mask", "token_type_ids", "instance_number"] - def __init__(self, *features, label_values=None, label_indices=None): - self.nr_inst = features[0].size(0) - self.data = TensorDataset(*features) - if label_values is not None and label_values.size(0) != self.nr_inst: - raise ValueError("First dimension mismatch between features and label_values") - if label_indices is not None and label_indices.size(0) != self.nr_inst: - raise ValueError("First dimension mismatch between features and label_indices") + def __init__(self, text, feature_keys=None, input_transform=None): + self.text = text + self.feature_keys = feature_keys + self.input_transform = input_transform - self.label_values = label_values - self.label_indices = label_indices + if self.feature_keys is None: + self.feature_keys = self.DEFAULT_FEATURE_KEYS - @property - def num_active_labels(self): - if self.label_indices is None: - return None + if self.is_tokenized: + for k in self.feature_keys: + if k not in self.text: + raise KeyError(f"Missing key ({k}) from tokenized inputs") else: - return self.label_indices.shape[1] - - def __getitem__(self, index): - if self.label_values is not None and self.label_indices is not None: - return self.data[index] + (self.label_values[index], self.label_indices[index]) - elif self.label_indices is not None: - return self.data[index] + (self.label_indices[index],) - elif self.label_values is not None: - return self.data[index] + (self.label_values[index],) + if input_transform is None: + raise ValueError(f"Expect tokenizer if raw text given to XMCTextTensorizer") + + LOGGER.info( + f"Constructed XMCTextTensorizer, tokenized={self.is_tokenized}, len={len(self)}" + ) + + def get_shard(self, start, end): + if end <= start: + raise ValueError(f"end >= start: {end} <= {start}") + if self.is_tokenized: + text_shard = {k: self.text[k][start:end] for k in self.feature_keys if k in self.text} else: - return self.data[index] + text_shard = self.text[start:end] + return XMCTextTensorizer( + text_shard, + feature_keys=self.feature_keys, + input_transform=self.input_transform, + ) def __len__(self): - return self.nr_inst + if self.is_tokenized: + return self.text[self.feature_keys[0]].shape[0] + else: + return len(self.text) - def refresh_labels(self, label_values=None, label_indices=None): - """Refresh label-values and label-indices from given tensors""" - self.label_values = label_values - self.label_indices = label_indices + @property + def is_tokenized(self): + return isinstance(self.text, (dict, BatchEncoding)) + def __getitem__(self, i): + if self.is_tokenized: + ret = {k: self.text[k][i] for k in self.feature_keys if k in self.text} + return tuple(self.text[k][i] for k in self.feature_keys) + else: + ret = self.input_transform(self.text[i]) + ret["instance_number"] = torch.IntTensor([i]) + return tuple(ret[kk].squeeze(dim=0) for kk in self.feature_keys) -class XMCTextDataset(Dataset): - """Dataset to hold text and label/matching matrices for XMC training and prediction. - Conduct real-time tokenization of input text and label tensor generation to save memory. +class XMCLabelTensorizer(object): + """ Args: - text (list of str): input text, length = nr_inst - input_transform (function): the transform function to process/tokenize text - feature_keys (list of str): the feature keys in order for batch generation. 
Y (csr_matrix, optional): training labels, shape = (nr_inst, nr_labels) M (csr_matrix, optional): matching matrix, shape = (nr_inst, nr_codes) model will be trained only on its non-zero indices its values will not be used. - idx_padding (int, optional): the index used to pad all label_indices + label_padding_idx (int, optional): the index used to pad all label_indices to the same length. Default -1 max_labels (int, optional): max number of labels considered for each instance, will subsample from existing label indices if need to. Default None to ignore. - + pre_compute (bool, optional): whether to pre-generate label tensors for the dataset. + Default False Return values depend on the Y and M: 1. Both Y and M are not None (train on middle layer): - data[i] = (feature[0][i], feature[1][i], ..., label_values[i], label_indices[i]) + data[i] = (label_values[i], label_indices[i]) 2. Both Y and M are None (inference on top layer): - data[i] = (feature[0][i], feature[1][i], ...) + data[i] = (,) 2. Y is not None, M is None (train on top layer): - data[i] = (feature[0][i], feature[1][i], ..., label_values[i]) + data[i] = (label_values[i],) 3. Y is None, M is not None (inference on middle layer): - data[i] = (feature[0][i], feature[1][i], ..., label_indices[i]) + data[i] = (label_indices[i],) """ def __init__( self, - text, - input_transform, - feature_keys, Y=None, M=None, - idx_padding=-1, + label_padding_idx=-1, max_labels=None, + pre_compute=False, ): - self.text = text - self.input_transform = input_transform - self.feature_keys = feature_keys - self.idx_padding = idx_padding - self.lbl_mat = None + self.label_padding_idx = label_padding_idx self.has_label = Y is not None self.has_ns = M is not None + self.pre_compute = pre_compute + self.label_width = None self.offset = 0 + if pre_compute: + # pre-computed will use these + self.get_lbl_tensors(M, Y, max_labels=max_labels) + else: + # realtime compute will use these + self.get_lbl_mat(M, Y, max_labels=max_labels) + + LOGGER.debug( + f"Constructed XMCLabelTensorizer, pre_compute={self.pre_compute}, len={len(self)}, num_active_labels={self.num_active_labels}" + ) + + def get_shard(self, start, end): + if end <= start: + raise ValueError(f"end <= start: {end} <= {start}") + + ret = XMCLabelTensorizer( + label_padding_idx=self.label_padding_idx, + pre_compute=self.pre_compute, + ) + ret.has_label = self.has_label + ret.has_ns = self.has_ns + ret.label_width = self.label_width + ret.offset = self.offset + + if self.pre_compute: + ret.label_indices = ( + None if self.label_indices is None else self.label_indices[start:end, :] + ) + ret.label_values = ( + None if self.label_values is None else self.label_values[start:end, :] + ) + elif self.lbl_mat is not None: + ret.lbl_mat = self.lbl_mat[start:end, :] + else: + ret.lbl_mat = None + + return ret + + def __len__(self): + if not self.has_ns and not self.has_label: + return 0 + if self.pre_compute: + if self.has_ns: + return self.label_indices.shape[0] + else: + return self.label_values.shape[0] + else: + return self.lbl_mat.shape[0] + + @property + def num_active_labels(self): + return self.label_width + + def get_lbl_mat(self, M, Y, max_labels=None): + if M is None and Y is None: # 1.inference at top layer - self.label_width = None + self.label_width = 0 + self.lbl_mat = None elif M is not None and Y is None: # 2.inference at intermediate layer self.label_width = max(M.indptr[1:] - M.indptr[:-1]) @@ -215,46 +265,207 @@ def __init__( f"label-width ({self.label_width}) is not able to cover all positive 
labels ({label_lower_bound})!" ) - def __len__(self): - return len(self.text) - - @property - def num_active_labels(self): - return self.label_width - - def get_input_tensors(self, i): - ret = self.input_transform(self.text[i]) - ret["instance_number"] = torch.IntTensor([i]) - return tuple(ret[kk].squeeze(dim=0) for kk in self.feature_keys) + def get_lbl_tensors(self, M, Y, max_labels=None): + if M is None and Y is None: + self.label_indices = None + self.label_values = None + self.label_width = 0 + elif M is None and Y is not None: + # if M is None, taking all labels into account + self.label_indices = None + self.label_values = torch.FloatTensor(Y.toarray()) + self.label_width = Y.shape[1] + else: + if Y is not None: + if Y.shape != M.shape: + raise ValueError("Y and M shape mismatch: {} and {}".format(Y.shape, M.shape)) + label_lower_bound = max(Y.indptr[1:] - Y.indptr[:-1]) + # make sure all positive labels are included + M1 = smat_util.binarized(M) + smat_util.binarized(Y) + else: + M1 = M + label_lower_bound = 0 - def get_output_tensors(self, i): - if not self.has_ns: - if not self.has_label: - return tuple() + label_upper_bound = max(M1.indptr[1:] - M1.indptr[:-1]) + if max_labels is None: + max_labels = label_upper_bound + else: + max_labels = min(max_labels, label_upper_bound) + if max_labels < label_lower_bound: + max_labels = label_lower_bound + LOGGER.warning( + f"Increasing max_labels to {label_lower_bound} to accommodate all positive labels." + ) + + nr_inst = M1.shape[0] + label_indices = np.zeros((nr_inst, max_labels), dtype=np.int64) + self.label_padding_idx + if Y is not None: + label_values = np.zeros((nr_inst, max_labels), dtype=np.float32) + + for i in range(nr_inst): + offset = 0 + neg_samples = M1.indices[M1.indptr[i] : M1.indptr[i + 1]] + # fill with positive samples first + if Y is not None: + y_nnz = Y.indptr[i + 1] - Y.indptr[i] + rng = slice(Y.indptr[i], Y.indptr[i + 1]) + label_indices[i, :y_nnz] = Y.indices[rng] + label_values[i, :y_nnz] = Y.data[rng] + offset += y_nnz + neg_samples = neg_samples[np.invert(np.isin(neg_samples, Y.indices[rng]))] + # fill the rest slots with negative samples + if neg_samples.size > max_labels - offset: + # random sample negative labels + neg_samples = np.random.choice(neg_samples, max_labels - offset) + + label_indices[i, offset : offset + neg_samples.size] = neg_samples + + self.label_indices = torch.IntTensor(label_indices) + self.label_values = None if Y is None else torch.FloatTensor(label_values) + self.label_width = max_labels + + def __getitem__(self, i): + if self.pre_compute: + if self.label_values is not None and self.label_indices is not None: + return (self.label_values[i], self.label_indices[i]) + elif self.label_indices is not None: + return (self.label_indices[i],) + elif self.label_values is not None: + return (self.label_values[i],) else: - return (torch.FloatTensor(self.lbl_mat[i].toarray()).squeeze(dim=0),) + return tuple() + else: - nr_active = self.lbl_mat.indptr[i + 1] - self.lbl_mat.indptr[i] - rng = slice(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]) - - if nr_active > self.label_width: - # sub-sample to fit in self.label_width - nr_active = self.label_width - rng = np.random.choice( - np.arange(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]), - nr_active, - replace=False, + if not self.has_ns: + if not self.has_label: + return tuple() + else: + return (torch.FloatTensor(self.lbl_mat[i].toarray()).squeeze(dim=0),) + else: + nr_active = self.lbl_mat.indptr[i + 1] - self.lbl_mat.indptr[i] + rng = 
slice(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]) + + if nr_active > self.label_width: + # sub-sample to fit in self.label_width + nr_active = self.label_width + rng = np.random.choice( + np.arange(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]), + nr_active, + replace=False, + ) + + label_indices = ( + torch.zeros((self.label_width,), dtype=torch.int) + self.label_padding_idx ) + label_indices[:nr_active] = torch.from_numpy(self.lbl_mat.indices[rng]) - label_indices = torch.zeros((self.label_width,), dtype=torch.int) + self.idx_padding - label_indices[:nr_active] = torch.from_numpy(self.lbl_mat.indices[rng]) + if not self.has_label: + return (label_indices,) + else: + label_values = torch.zeros((self.label_width,), dtype=torch.float32) + label_values[:nr_active] = torch.from_numpy( + self.lbl_mat.data[rng] - self.offset + ) + return (label_values, label_indices) - if not self.has_label: - return (label_indices,) - else: - label_values = torch.zeros((self.label_width,), dtype=torch.float32) - label_values[:nr_active] = torch.from_numpy(self.lbl_mat.data[rng] - self.offset) - return (label_values, label_indices) - def __getitem__(self, index): - return self.get_input_tensors(index) + self.get_output_tensors(index) +class XMCTextDataset(Dataset): + """Dataset to hold text and label/matching matrices for XMC training and prediction. + Conduct real-time tokenization of input text and label tensor generation to save memory. + + Args: + text (list of str): input text, length = nr_inst + input_transform (function): the transform function to process/tokenize text + feature_keys (list of str): the feature keys in order for batch generation. + Y (csr_matrix, optional): training labels, shape = (nr_inst, nr_labels) + M (csr_matrix, optional): matching matrix, shape = (nr_inst, nr_codes) + model will be trained only on its non-zero indices + its values will not be used. + label_padding_idx (int, optional): the index used to pad all label_indices + to the same length. Default -1 + max_labels (int, optional): max number of labels considered for each + instance, will subsample from existing label indices if need to. + Default None to ignore. + + + Return values depend on the Y and M: + 1. Both Y and M are not None (train on middle layer): + data[i] = (feature[0][i], feature[1][i], ..., label_values[i], label_indices[i]) + 2. Both Y and M are None (inference on top layer): + data[i] = (feature[0][i], feature[1][i], ...) + 2. Y is not None, M is None (train on top layer): + data[i] = (feature[0][i], feature[1][i], ..., label_values[i]) + 3. 
Y is None, M is not None (inference on middle layer): + data[i] = (feature[0][i], feature[1][i], ..., label_indices[i]) + """ + + def __init__( + self, + input_tensorizer, + output_tensorizer=None, + ): + if output_tensorizer is None: + output_tensorizer = XMCLabelTensorizer() + + if len(output_tensorizer) > 0: + if len(input_tensorizer) != len(output_tensorizer): + raise ValueError( + f"Dimension 0 mismatch: {len(input_tensorizer)} != {len(output_tensorizer)}" + ) + + self.input_tensorizer = input_tensorizer + self.output_tensorizer = output_tensorizer + + def __len__(self): + return len(self.input_tensorizer) + + def get_shard(self, start, end): + return self.__class__( + self.input_tensorizer.get_shard(start, end), + self.output_tensorizer.get_shard(start, end), + ) + + def save(self, save_dir, num_shards=None, init_shard_idx=0): + if num_shards is None: + num_shards = 1 + + os.makedirs(save_dir, exist_ok=True) + param = { + "model": self.__class__.__name__, + "num_shards": num_shards, + "num_instances": len(self), + } + with open(f"{save_dir}/config.json", "w") as f: + f.write(json.dumps(param, indent=True)) + + chunk_size = (len(self) + num_shards - 1) // num_shards + for sid in range(init_shard_idx, init_shard_idx + num_shards): + cur_chunk_dir = f"{save_dir}/{sid}" + start = chunk_size * sid + end = min(chunk_size * (sid + 1), len(self)) + torch.save(self.get_shard(start, end), cur_chunk_dir, pickle_protocol=4) + LOGGER.info(f"Shard{sid} saved to {cur_chunk_dir}, len={end - start}") + + @classmethod + def get_data_stats(cls, load_dir): + with open(f"{load_dir}/config.json", "r") as f: + config = json.load(f) + return config + + @classmethod + def load(cls, load_dir, shard=0): + nr_shards = cls.get_data_stats(load_dir)["num_shards"] + if shard >= nr_shards: + raise ValueError(f"Loading shard#{shard} where there are only {nr_shards} available") + return torch.load(f"{load_dir}/{shard}") + + @property + def has_ns(self): + return self.output_tensorizer.has_ns + + @property + def num_active_labels(self): + return self.output_tensorizer.num_active_labels + + def __getitem__(self, i): + return self.input_tensorizer[i] + self.output_tensorizer[i] diff --git a/pecos/xmc/xtransformer/network.py b/pecos/xmc/xtransformer/network.py index 1989d45..c2df706 100644 --- a/pecos/xmc/xtransformer/network.py +++ b/pecos/xmc/xtransformer/network.py @@ -103,30 +103,38 @@ class TransformerLinearXMCHead(nn.Module): Containing label weight embeddings and label bias embeddings """ - def __init__(self, hidden_size, num_labels): + def __init__(self, hidden_size, num_labels, sparse=False): super().__init__() - self.label_pad = num_labels + padding_idx = num_labels self.num_labels = num_labels - self.W = nn.Embedding(num_labels + 1, hidden_size, padding_idx=self.label_pad) - self.b = nn.Embedding(num_labels + 1, 1, padding_idx=self.label_pad) + self.W = nn.Embedding(num_labels + 1, hidden_size, padding_idx=padding_idx) + self.b = nn.Embedding(num_labels + 1, 1, padding_idx=padding_idx) - self.random_init() + self.random_init(sparse=sparse) + + @property + def label_padding_idx(self): + return self.W.padding_idx + + @property + def is_sparse(self): + return self.W.sparse @property def device(self): return self.W.weight.device - def random_init(self): + def random_init(self, sparse=False): """Initialize the weight and bias embeddings Initialize label weight embedding with N(0, 0.02) while keeping PAD column to be 0. Initialize label bias embedding with 0. 
""" - mat = 0.02 * np.random.randn(self.label_pad, self.W.weight.shape[1]) + mat = 0.02 * np.random.randn(self.label_padding_idx, self.W.weight.shape[1]) mat = np.hstack([mat, np.zeros([mat.shape[0], 1])]) - self.init_from(mat) + self.init_from(mat, sparse=sparse) - def inherit(self, prev_head, C): + def inherit(self, prev_head, C, sparse=False): prev_W = prev_head.W.weight[:-1, :].detach().numpy() prev_b = prev_head.b.weight[:-1, :].detach().numpy() @@ -135,7 +143,7 @@ def inherit(self, prev_head, C): mat = np.hstack([cur_W, cur_b]) - self.init_from(mat) + self.init_from(mat, sparse=sparse) def bootstrap(self, prob, **kwargs): """Initialize head with weights learned from linear model using transformer embeddings @@ -154,9 +162,9 @@ def bootstrap(self, prob, **kwargs): threshold = kwargs.get("threshold", 0) mat = MLModel.train(prob, threshold=threshold, Cp=Cp, Cn=Cn) mat = mat.W.toarray().T - self.init_from(mat) + self.init_from(mat, sparse=kwargs.get("sparse", False)) - def init_from(self, mat): + def init_from(self, mat, sparse=False): """Initialize the weight and bias embeddings with given matrix Args: @@ -164,7 +172,7 @@ def init_from(self, mat): """ if not isinstance(mat, np.ndarray): raise ValueError("Expect ndarray to initialize label embedding") - if mat.shape[0] != self.label_pad: + if mat.shape[0] != self.label_padding_idx: raise ValueError("nr_labels mismatch!") # add padding index by appending an all-zero row @@ -173,14 +181,14 @@ def init_from(self, mat): self.W = nn.Embedding.from_pretrained( torch.FloatTensor(mat[:, :-1]), freeze=False, - sparse=True, - padding_idx=self.label_pad, + sparse=sparse, + padding_idx=self.label_padding_idx, ) self.b = nn.Embedding.from_pretrained( - torch.FloatTensor(mat[:, -1]).view((self.label_pad + 1, 1)), + torch.FloatTensor(mat[:, -1]).view((self.label_padding_idx + 1, 1)), freeze=False, - sparse=True, - padding_idx=self.label_pad, + sparse=sparse, + padding_idx=self.label_padding_idx, ) def forward(self, pooled_output=None, output_indices=None, num_device=1): diff --git a/pecos/xmc/xtransformer/train.py b/pecos/xmc/xtransformer/train.py index c22af9e..cf40ab0 100644 --- a/pecos/xmc/xtransformer/train.py +++ b/pecos/xmc/xtransformer/train.py @@ -513,8 +513,8 @@ def do_train(args): if os.path.exists(args.code_path): cluster_chain = ClusterChain.from_partial_chain( smat_util.load_matrix(args.code_path), - min_codes=args.min_codes, - nr_splits=args.nr_splits, + min_codes=train_params.preliminary_indexer_params.min_codes, + nr_splits=train_params.preliminary_indexer_params.nr_splits, ) LOGGER.info("Loaded from code-path: {}".format(args.code_path)) else: @@ -535,6 +535,10 @@ def do_train(args): del label_feat gc.collect() + if args.code_path: + smat_util.save_matrix(args.code_path, cluster_chain[-1]) + LOGGER.info(f"Saved clusters {cluster_chain[-1].shape} to {args.code_path}") + trn_prob = MLProblemWithText(trn_corpus, Y_trn, X_feat=X_trn) if all(v is not None for v in [tst_corpus, Y_tst]): val_prob = MLProblemWithText(tst_corpus, Y_tst, X_feat=X_tst) diff --git a/setup.py b/setup.py index 61e6148..d6f4943 100644 --- a/setup.py +++ b/setup.py @@ -119,7 +119,7 @@ def get_blas_lib_dir(cls): 'torch>=1.8.0', 'sentencepiece>=0.1.86,!=0.1.92', # 0.1.92 results in error for transformers 'transformers>=4.1.1; python_version<"3.9"', - 'transformers==4.4.2; python_version>="3.9"' # Python 3.9 only support transformer 4.4.2 + 'transformers>=4.4.2; python_version>="3.9"' ] # Fetch Numpy before building Numpy-dependent extension, if Numpy required version 
was not installed
diff --git a/test/pecos/distributed/xmc/test_dist_xmc.py b/test/pecos/distributed/xmc/test_dist_xmc.py
index 9b0b21f..8ce097d 100644
--- a/test/pecos/distributed/xmc/test_dist_xmc.py
+++ b/test/pecos/distributed/xmc/test_dist_xmc.py
@@ -12,7 +12,7 @@
 import scipy.sparse as smat
 import numpy as np
 from pecos.utils.cluster_util import ClusterChain
-from pecos.distributed.diagnostic_util.test_util import DummyComm
+from pecos.distributed.diagnostic_tools.test_util import DummyComm
 class GenerateClusterChain(object):
diff --git a/test/pecos/distributed/xmc/xlinear/test_dist_xlinear.py b/test/pecos/distributed/xmc/xlinear/test_dist_xlinear.py
index aa560cd..5ed2067 100644
--- a/test/pecos/distributed/xmc/xlinear/test_dist_xlinear.py
+++ b/test/pecos/distributed/xmc/xlinear/test_dist_xlinear.py
@@ -11,7 +11,7 @@
 import pytest # noqa: F401; pylint: disable=unused-variable
 import scipy.sparse as smat
 import numpy as np
-from pecos.distributed.diagnostic_util.test_util import DummyComm
+from pecos.distributed.diagnostic_tools.test_util import DummyComm
 from pecos.utils.cluster_util import ClusterChain
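
The module.py hunks above split the old dataset into two pieces: an XMCTextTensorizer that yields feature tensors (tokenizing raw text on the fly when an input_transform is given) and an XMCLabelTensorizer that yields label tensors from Y/M, with XMCTextDataset combining them and now supporting sharded save/load. A minimal sketch of how the pieces fit together, assuming a toy corpus, an identity label matrix, a Hugging Face bert-base-uncased tokenizer, and an ./xmc_text_data output path; none of these inputs come from the patch itself:

import numpy as np
import scipy.sparse as smat
from transformers import AutoTokenizer
from pecos.xmc.xtransformer.module import (
    XMCTextTensorizer,
    XMCLabelTensorizer,
    XMCTextDataset,
)

# Toy data for illustration only: 4 instances, 4 labels.
corpus = ["red apples", "green pears", "yellow bananas", "purple grapes"]
Y_csr = smat.csr_matrix(np.eye(4, dtype=np.float32))

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(text):
    # Return (1, seq_len) tensors; XMCTextTensorizer squeezes dim 0 per instance.
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )

feats = XMCTextTensorizer(corpus, input_transform=tokenize)       # raw text, tokenized on the fly
labels = XMCLabelTensorizer(Y=Y_csr, M=None, pre_compute=False)   # top-layer training targets
data = XMCTextDataset(feats, labels)

data.save("./xmc_text_data", num_shards=2)            # writes config.json plus one torch file per shard
shard0 = XMCTextDataset.load("./xmc_text_data", shard=0)
print(XMCTextDataset.get_data_stats("./xmc_text_data"))

Tokenizing inside input_transform keeps memory flat, since feature tensors are built per __getitem__ call, while pre_compute=True on the label side trades memory for speed by materializing label_values/label_indices up front.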
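
TransformerLinearXMCHead now defaults to dense label embeddings and only uses sparse gradients when constructed (or re-initialized) with sparse=True, exposed through the is_sparse property. A short sketch of the resulting optimizer choice, mirroring the is_sparse branch in fine_tune_encoder; the sizes and learning rate below are arbitrary examples, not values from the patch:

import torch
from pecos.xmc.xtransformer.network import TransformerLinearXMCHead

dense_head = TransformerLinearXMCHead(hidden_size=768, num_labels=10000)               # sparse=False by default
sparse_head = TransformerLinearXMCHead(hidden_size=768, num_labels=10000, sparse=True)

print(dense_head.is_sparse, dense_head.label_padding_idx)   # False, 10000 (last embedding row is the pad row)

# Mirror of the optimizer selection during fine-tuning: SparseAdam only
# makes sense when the embedding produces sparse gradients.
params = list(sparse_head.parameters())
optimizer = (
    torch.optim.SparseAdam(params, lr=1e-4)
    if sparse_head.is_sparse
    else torch.optim.AdamW(params, lr=1e-4)
)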
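
XTransformer.load now forwards extra keyword arguments to XLinearModel.load when loading the concat model, and predict/encode gained a device_id kwarg that is passed through to torch_util.setup_device. A hedged usage sketch; the model directory, texts, and X_feat are placeholders, and the positional X_text/X_feat arguments are assumed from the existing predict signature rather than shown in this patch:

from pecos.xmc.xtransformer.model import XTransformer

# Placeholders: a trained model directory, a list of input strings, and a csr feature matrix.
xtf = XTransformer.load("./save/xtransformer")   # extra kwargs here are forwarded to XLinearModel.load
P = xtf.predict(texts, X_feat=X_feat, use_gpu=True, device_id=0, batch_size=16)
emb = xtf.encode(texts, use_gpu=True, device_id=0)   # device_id=-1 (the default) keeps the use-all-GPUs behaviour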