[train_engine] support fsdp
Mddct committed Mar 15, 2024

1 parent c8084ef commit 55f955b
Showing 3 changed files with 148 additions and 12 deletions.
17 changes: 8 additions & 9 deletions wenet/bin/train.py
@@ -30,11 +30,11 @@
from wenet.utils.init_model import init_model
from wenet.utils.init_tokenizer import init_tokenizer
from wenet.utils.train_utils import (
add_model_args, add_dataset_args, add_ddp_args, add_deepspeed_args,
add_trace_args, init_distributed, init_dataset_and_dataloader,
check_modify_and_save_config, init_optimizer_and_scheduler,
trace_and_print_model, wrap_cuda_model, init_summarywriter, save_model,
log_per_epoch)
add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args,
add_deepspeed_args, add_trace_args, init_distributed,
init_dataset_and_dataloader, check_modify_and_save_config,
init_optimizer_and_scheduler, init_scaler, trace_and_print_model,
wrap_cuda_model, init_summarywriter, save_model, log_per_epoch)


def get_args():
@@ -47,6 +47,7 @@ def get_args():
parser = add_dataset_args(parser)
parser = add_ddp_args(parser)
parser = add_deepspeed_args(parser)
parser = add_fsdp_args(parser)
parser = add_trace_args(parser)
args = parser.parse_args()
if args.train_engine == "deepspeed":
@@ -96,7 +97,7 @@ def main():
writer = init_summarywriter(args)

# Dispatch model from cpu to gpu
model, device = wrap_cuda_model(args, model)
model, device = wrap_cuda_model(args, model, configs)

# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(
@@ -118,9 +119,7 @@ def main():
executor.step = configs["init_infos"].get('step', -1) + int("step_" in tag)

# Init scaler, used for pytorch amp mixed precision training
scaler = None
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
scaler = init_scaler(args)

# Start training loop
start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag)
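
The inline scaler setup above becomes a single init_scaler call. A hedged sketch of what it returns per engine (assuming wenet and its deepspeed dependency are importable; the SimpleNamespace stands in for the real parsed args):

import torch
from types import SimpleNamespace

from wenet.utils.train_utils import init_scaler

# --use_amp on torch_ddp -> a regular CUDA GradScaler.
args = SimpleNamespace(use_amp=True, train_engine='torch_ddp')
assert isinstance(init_scaler(args), torch.cuda.amp.GradScaler)

# torch_fsdp without --use_amp -> a ShardedGradScaler, which can unscale
# gradients that are sharded across ranks.
args = SimpleNamespace(use_amp=False, train_engine='torch_fsdp')
print(type(init_scaler(args)).__name__)  # ShardedGradScaler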
57 changes: 57 additions & 0 deletions wenet/utils/fsdp_utils.py
@@ -0,0 +1,57 @@
from functools import partial
import os

import torch
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    FullStateDictConfig, StateDictType)
from torch.distributed.fsdp.wrap import (_or_policy, lambda_auto_wrap_policy,
                                         transformer_auto_wrap_policy)

from wenet.utils.init_model import (WENET_DECODER_CLASSES,
                                    WENET_ENCODER_CLASSES)


def wenet_fsdp_wrap_policy():
    # Each registered encoder/decoder class becomes its own FSDP unit.
    to_wrap_class = set(WENET_ENCODER_CLASSES.values())
    to_wrap_class.update(WENET_DECODER_CLASSES.values())
    # TODO(Mddct):
    #   1. wrap transducer's predictor and joint
    #   2. wrap paraformer's cif

    wrap_policy = partial(transformer_auto_wrap_policy,
                          transformer_layer_cls=to_wrap_class)

    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
    # Mirrors llama-recipes' lambda policy: additionally wrap any leaf
    # module that has trainable weights as its own FSDP unit.
    def no_grad_fn(module):
        if (len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad):
            return True
        return False

    no_grad_policy = partial(lambda_auto_wrap_policy, lambda_fn=no_grad_fn)

    auto_wrap_policy = partial(_or_policy,
                               policies=[no_grad_policy, wrap_policy])
    return auto_wrap_policy


fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
    # TODO(Mddct): when the model is large, saving the full state dict takes a
    # long time; keeping the shards and saving asynchronously would be better.
    # This is good enough for now and can be revisited once LLM-scale models
    # are supported.
    rank = int(os.environ.get('RANK', 0))
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        # state_dict() is a collective call, so every rank must make it; with
        # rank0_only=True only rank 0 receives the gathered, CPU-offloaded
        # copy (other ranks get an empty dict), so only rank 0 writes.
        state_dict = model.state_dict()
    if rank == 0:
        # info_dict is persisted separately as yaml by the caller, save_model.
        torch.save(state_dict, save_model_path)
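
Since FSDP auto-wrap policies are plain callables with the signature (module, recurse, nonwrapped_numel), the policy returned above can be sanity-checked without building an FSDP model or a process group. A minimal sketch (the Linear and Sequential modules here are hypothetical stand-ins, not WeNet classes):

import torch

from wenet.utils.fsdp_utils import wenet_fsdp_wrap_policy

policy = wenet_fsdp_wrap_policy()

# A leaf module with trainable weights: matched by the lambda branch.
print(policy(module=torch.nn.Linear(4, 4), recurse=False,
             nonwrapped_numel=0))  # True

# A container with children and no direct weight: only wrapped if it is one
# of the registered WeNet encoder/decoder classes.
container = torch.nn.Sequential(torch.nn.Linear(4, 4))
print(policy(module=container, recurse=False, nonwrapped_numel=0))  # False

Note that fsdp_save_model must be entered by every rank, since gathering the full state dict is a collective operation; only rank 0 actually writes.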
86 changes: 83 additions & 3 deletions wenet/utils/train_utils.py
@@ -27,6 +27,9 @@
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
CPUOffload, MixedPrecision,
sharded_grad_scaler, ShardingStrategy)
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live)
from deepspeed.runtime.zero.stage3 import (
@@ -35,6 +38,7 @@
convert_zero_checkpoint_to_fp32_state_dict)
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.fsdp_utils import fsdp_save_model, wenet_fsdp_wrap_policy
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id

@@ -135,18 +139,48 @@ def add_deepspeed_args(parser):
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
    parser.add_argument('--dtype',
                        default='fp32',
                        choices=['fp32', 'fp16', 'bf16'],
                        help='model compute dtype')

    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config' to the parser
parser = deepspeed.add_config_arguments(parser)
return parser


def add_fsdp_args(parser):
    parser.add_argument(
        '--fsdp_cpu_offload',
        default=False,
        type=bool,
        help='whether to offload parameters to CPU',
    )
    parser.add_argument(
        '--fsdp_sync_module_states',
        type=bool,
        default=True,
        help='each FSDP module will broadcast module parameters and buffers '
        'from rank 0 to ensure that they are replicated across ranks',
    )
    parser.add_argument(
        '--fsdp_sharding_strategy',
        default='full',
        choices=['full', 'grad_op', 'no_shard'],
        help='sharding strategy: full -> FULL_SHARD (zero3-like), '
        'grad_op -> SHARD_GRAD_OP (zero2-like), no_shard -> NO_SHARD '
        '(plain DDP); see the FSDP API docs',
    )
    return parser
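
A hedged aside on these flags: type=bool is a classic argparse pitfall, since bool() of any non-empty string (including 'False') is True, so only the defaults behave intuitively and any explicitly passed value enables the flag. A quick check (assuming wenet is importable):

import argparse

from wenet.utils.train_utils import add_fsdp_args

parser = add_fsdp_args(argparse.ArgumentParser())
args = parser.parse_args(['--fsdp_sharding_strategy', 'grad_op'])
print(args.fsdp_cpu_offload,         # False (default)
      args.fsdp_sync_module_states,  # True (default)
      args.fsdp_sharding_strategy)   # grad_op

# Pitfall: bool('False') is True, so this actually *enables* CPU offload.
args = parser.parse_args(['--fsdp_cpu_offload', 'False'])
print(args.fsdp_cpu_offload)  # True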


def init_distributed(args):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine == "torch_ddp":
if args.train_engine == "torch_ddp" or args.train_engine == "torch_fsdp":
torch.cuda.set_device(local_rank)
dist.init_process_group(args.dist_backend)
elif args.train_engine == "deepspeed":
@@ -162,6 +196,8 @@ def check_modify_and_save_config(args, configs, symbol_table):
configs["dtype"] = "fp16"
else:
configs["dtype"] = "fp32"
elif args.train_engine == "torch_fsdp":
configs["dtype"] = args.dtype
elif args.train_engine == "deepspeed":
# NOTE(xcsong): DeepSpeed does not support uneven data. When using custom
# dataset, we need to manually ensure that the data is evenly distributed
@@ -275,7 +311,7 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def wrap_cuda_model(args, model):
def wrap_cuda_model(args, model, configs=None):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if hasattr(model, 'encoder'):
@@ -310,6 +346,38 @@ def wrap_cuda_model(args, model):
num_nodes=world_size // local_world_size)
device = None # Init device later
pass # Init DeepSpeed later
elif args.train_engine == 'torch_fsdp':
assert configs is not None
mixed_precision_dtype = {
'fp32': torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}[configs['dtype']]

sharding_strategy = {
'full': ShardingStrategy.FULL_SHARD,
'grad_op': ShardingStrategy.SHARD_GRAD_OP,
            'no_shard': ShardingStrategy.NO_SHARD,
}[args.fsdp_sharding_strategy]
model = FSDP(
model,
            auto_wrap_policy=wenet_fsdp_wrap_policy(),
cpu_offload=CPUOffload(offload_params=True)
if args.fsdp_cpu_offload is True else None,
mixed_precision=MixedPrecision(
param_dtype=mixed_precision_dtype,
reduce_dtype=mixed_precision_dtype,
buffer_dtype=mixed_precision_dtype,
),
sharding_strategy=sharding_strategy,
limit_all_gathers=True,
use_orig_params=True,
            sync_module_states=args.fsdp_sync_module_states,
            # init_distributed has already called torch.cuda.set_device,
            # so pass device_id explicitly; see the FSDP API docs.
device_id=torch.cuda.current_device(),
)
device = torch.device("cuda")
else:
logging.error("not supported engine: {}".format(args.train_engine))

@@ -388,6 +456,15 @@ def init_summarywriter(args):
return writer


def init_scaler(args):
scaler = None
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
elif args.train_engine == 'torch_fsdp':
scaler = sharded_grad_scaler.ShardedGradScaler()
return scaler


def save_model(model, info_dict):
rank = int(os.environ.get('RANK', 0))
tag = info_dict["tag"]
@@ -410,7 +487,10 @@ def save_model(model, info_dict):
    elif info_dict['train_engine'] == "torch_fsdp":
        # Gathering the FSDP full state dict is a collective operation, so
        # every rank must call fsdp_save_model; only rank 0 writes inside it.
        save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
        save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
        save_checkpoint(model, save_model_path, info_dict)
# save yaml
if rank == 0:
with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
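
Putting the pieces together, a hedged end-to-end sketch of the new torch_fsdp path (launch with torchrun so RANK/WORLD_SIZE are set; assumes a CUDA device, bf16-capable for this dtype; the SimpleNamespace and toy model are stand-ins for the real parsed args and a WeNet model):

import torch
from types import SimpleNamespace

from wenet.utils.train_utils import (init_distributed, init_scaler,
                                     wrap_cuda_model)

args = SimpleNamespace(
    train_engine='torch_fsdp',
    dist_backend='nccl',
    fsdp_cpu_offload=False,
    fsdp_sync_module_states=True,
    fsdp_sharding_strategy='full',  # FULL_SHARD, roughly zero3
    use_amp=False,
)
configs = {'dtype': 'bf16'}  # consumed by the MixedPrecision settings

init_distributed(args)  # sets the CUDA device, joins the process group
model = torch.nn.Sequential(torch.nn.Linear(80, 256), torch.nn.Linear(256, 80))
model, device = wrap_cuda_model(args, model, configs)  # FSDP-wrapped
scaler = init_scaler(args)  # ShardedGradScaler under torch_fsdp
print(type(model).__name__, device, type(scaler).__name__)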
