Skip to content

Commit

Permalink
distributed XR-Transformer fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang authored and jiong-zhang committed Jul 7, 2022
1 parent 6723f61 commit a09375c
Show file tree
Hide file tree
Showing 15 changed files with 1,691 additions and 240 deletions.
89 changes: 89 additions & 0 deletions pecos/distributed/diagnostic_util/deepspeed_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env python3 -u
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
# and limitations under the License.

import argparse
import deepspeed
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import socket
from pecos.utils import logging_util
from pecos.distributed.xmc.xtransformer.module import DeepSpeedUtils

LOGGER = logging.getLogger(__name__)


class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)

def forward(self, x):
return F.relu(self.conv1(x))


def parse_arguments():
"""Parse evaluation arguments"""

parser = argparse.ArgumentParser()

parser.add_argument(
"--verbose-level",
type=int,
choices=logging_util.log_levels.keys(),
default=2,
metavar="INT",
help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}. Default 2",
)

parser.add_argument(
"--local_rank",
type=int,
default=0,
metavar="RANK",
help="distributed rank",
)
return parser


def dist_check(args):
local_rank = args.local_rank

model = DummyModel()
ds_config = DeepSpeedUtils.get_config()
_ = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=model.parameters(),
)
torch_rank = torch.distributed.get_rank()
ip = socket.gethostbyname(socket.gethostname())

world_size = int(os.getenv("WORLD_SIZE", "1"))
LOGGER.info(
f"Report from {ip}: local_rank={local_rank}, torch_rank={torch_rank}, world_size={world_size}"
)


if __name__ == "__main__":
"""
Sanity check for deepspeed distributed
Usage:
deepspeed --hostfile [PATH_TO_HOSTFILE] --module pecos.distributed.diagnostic_util.deepspeed_test
"""
parser = parse_arguments()
args = parser.parse_args()
logging_util.setup_logging_config(level=args.verbose_level)
dist_check(args)
10 changes: 5 additions & 5 deletions pecos/distributed/xmc/xlinear/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def parse_arguments():
parser.add_argument(
"--nr-splits",
type=int,
default=2,
default=16,
metavar="INT",
help="number of splits used to construct hierarchy (a power of 2 is recommended)",
)
Expand All @@ -86,15 +86,15 @@ def parse_arguments():
"-k",
"--only-topk",
type=int,
default=20,
default=None,
metavar="INT",
help="the default number of top labels used in the prediction",
)
parser.add_argument(
"-b",
"--beam-size",
type=int,
default=10,
default=None,
metavar="INT",
help="the default size of beam search used in the prediction",
)
Expand All @@ -110,7 +110,7 @@ def parse_arguments():
"--post-processor",
type=str,
choices=PostProcessor.valid_list(),
default="l3-hinge",
default=None,
metavar="STR",
help="the default post processor used in the prediction",
)
Expand Down Expand Up @@ -170,7 +170,7 @@ def do_train(args):
mpi_comm = MPIComm()

# Parse args
args_dict = vars(args)
args_dict = {k: v for k, v in vars(args).items() if v is not None}
train_params = DistributedCPUXLinearModel.TrainParams.from_dict(args_dict, recursive=True)
cluster_params = DistClustering.ClusterParams(
indexer_params=HierarchicalKMeans.TrainParams.from_dict(args_dict),
Expand Down
Empty file.
160 changes: 160 additions & 0 deletions pecos/distributed/xmc/xtransformer/encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python3 -u
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
# and limitations under the License.

import argparse
import deepspeed
import os
import torch
import logging
import itertools

from pecos.utils import smat_util, logging_util
from pecos.xmc.xtransformer.model import XTransformer

LOGGER = logging.getLogger(__name__)


def parse_arguments():
"""Parse evaluation arguments"""

parser = argparse.ArgumentParser()

parser.add_argument(
"-t",
"--text-path",
type=str,
required=True,
metavar="PATH",
help="Path to the instance text file.",
)
parser.add_argument(
"-m",
"--model-folder",
type=str,
required=True,
metavar="PATH",
help="Path to load x-transformer model.",
)
parser.add_argument(
"-o",
"--save-emb-folder",
type=str,
required=True,
metavar="PATH",
help="The folder in which the embeddings will be written (in WORLD_SIZE shards)",
)

# ======= Other parameters ========
parser.add_argument(
"--batch-size",
default=32,
type=int,
metavar="INT",
help="Batch size per GPU.",
)
parser.add_argument(
"--batch-gen-workers",
type=int,
metavar="INT",
default=4,
help="number of CPUs to use for batch generation",
)
parser.add_argument(
"--truncate-length",
default=None,
type=int,
metavar="INT",
help="max number of tokens to encode",
)
parser.add_argument(
"--max-pred-chunk",
default=10**7,
metavar="INT",
type=int,
help="Max number of instances to predict on at once, set to avoid OOM. Set to None to predict on all instances at once. Default 10^7",
)
parser.add_argument(
"--verbose-level",
type=int,
choices=logging_util.log_levels.keys(),
default=1,
metavar="INT",
help=f"the verbose level, {', '.join([str(k) + ' for ' + logging.getLevelName(v) for k, v in logging_util.log_levels.items()])}, default 1",
)

parser.add_argument(
"--local_rank",
type=int,
default=0,
metavar="RANK",
help="distributed rank",
)
return parser


def dist_encode(args):
"""Encode with XTransformer model in distributed fashion.
Each worker will encode an exclusive chunk and save the result to args.save_emb_folder/X.emb.[WORKER_RANK].npy
Args:
args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()`
"""
os.makedirs(args.save_emb_folder, exist_ok=True)

local_rank = args.local_rank
world_size = int(os.getenv("WORLD_SIZE", "1"))
deepspeed.init_distributed(dist_backend="nccl")

global_rank = torch.distributed.get_rank()
LOGGER.info(
f"Initialized device for rank={local_rank}, global_rank={global_rank}, world_size={world_size}"
)
xtf = XTransformer.load(args.model_folder)

# get number of lines without reading all in memory
nr_inst = sum(1 for line in open(args.text_path, "r"))
chunk_size = (nr_inst + world_size - 1) // world_size
start = global_rank * chunk_size
end = min(nr_inst, start + chunk_size)

with open(args.text_path, "r") as fin:
X = []
for line in itertools.islice(fin, start, end):
X.append(line.strip())
LOGGER.info(f"Rank{global_rank}/{world_size} will encode {start} to {end}")

pred_params = xtf.get_pred_params()
for i in range(len(pred_params.matcher_params_chain)):
if args.truncate_length:
pred_params.matcher_params_chain[i].truncate_length = args.truncate_length

X_emb = xtf.encode(
X,
batch_size=args.batch_size,
batch_gen_workers=args.batch_gen_workers,
device_id=local_rank,
pred_params=pred_params,
max_pred_chunk=args.max_pred_chunk,
)

local_tgt = os.path.join(args.save_emb_folder, f"X.emb.{global_rank}.npy")
smat_util.save_matrix(local_tgt, X_emb)
LOGGER.info(f"Rank{global_rank}/{world_size} saved embedding {X_emb.shape} to {local_tgt}")

torch.distributed.destroy_process_group()


if __name__ == "__main__":
parser = parse_arguments()
args = parser.parse_args()
logging_util.setup_logging_config(level=args.verbose_level)
dist_encode(args)
Loading

0 comments on commit a09375c

Please sign in to comment.