diff --git a/run_llama_train.sh b/run_llama_train.sh
index 13b66aeac..05ebde763 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -17,8 +17,8 @@
 NGPU=${NGPU:-"8"}
 LOG_RANK=${LOG_RANK:-0}
 
-CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama_7b.toml"}
 
-torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
+torchrun --nproc_per_node=${NGPU} --rdzv-backend=c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE}
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d11fac9f2..4e6104109 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -8,6 +8,7 @@
 from collections import defaultdict
 
 import torch
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import (
     distribute_module,
     distribute_tensor,
@@ -20,13 +21,6 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.fsdp import (
-    BackwardPrefetch,
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import enable_wrap, wrap
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -35,7 +29,6 @@
 )
 from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
-from torchtrain.meta_init import meta_to_real_init_fn
 
 logger = logging.getLogger(__name__)
 
@@ -76,6 +69,7 @@ def partition_fn(name, module, device_mesh):
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
 }
 
 
+# Uses PTD FSDP AC wrapper
 def checkpoint_wrapper(module, enable_selective_ac):
     if enable_selective_ac:
@@ -153,6 +147,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 ),
             },
         )
+        distribute_rmsnorm(model.norm, tp_mesh)
 
         # apply sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
@@ -194,40 +189,26 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-
-        fsdp_config = {
-            "mixed_precision": MixedPrecision(
-                param_dtype=torch.bfloat16,
-                # TODO: see whether we should expose a option to user
-                reduce_dtype=torch.float32,
-            ),
-            "sharding_strategy": ShardingStrategy.FULL_SHARD,
-            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
-            # When torch.compile is active, it requires us to set use_orig_params=True
-            "use_orig_params": True,
-            "device_mesh": dp_mesh,
-            "param_init_fn": meta_to_real_init_fn,
-        }
-
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
-            for layer_id, transformer_block in enumerate(model.layers):
-
-                # apply selective AC
-                transformer_block = checkpoint_wrapper(
-                    transformer_block, job_config.training.enable_selective_ac
-                )
-
-                # Wraps each layer with FSDP
-                model.layers[layer_id] = wrap(transformer_block)
-
-            # wrap the rest layers with FSDP
-            model = wrap(model)
-
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+        )
+        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+        for layer_id, transformer_block in enumerate(model.layers):
+            transformer_block = checkpoint_wrapper(
+                transformer_block, job_config.training.enable_selective_ac
+            )
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = layer_id < len(model.layers) - 1
+            fully_shard(
+                transformer_block,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+            model.layers[layer_id] = transformer_block
+        model = fully_shard(model, **fsdp_config)
         rank0_log("Applied FSDP to the model...")
     else:
         model.cuda()
-        # we have now moved from meta to device,
-        # reset parameters for proper initialization
-        model.reset_parameters()
 
     return model
diff --git a/train.py b/train.py
index 9c8e2f7b8..558abdc5b 100644
--- a/train.py
+++ b/train.py
@@ -12,7 +12,6 @@
 # torch imports
 import torch
 import torch.nn.functional as F
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
@@ -62,9 +61,9 @@ def build_optimizer(model, job_config: JobConfig):
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     if name == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr, foreach=True)
     elif name == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, foreach=True)
     else:
         raise NotImplementedError(f"optimizer {name} not added")
 
@@ -73,13 +72,14 @@ def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
-        enable_grad_scaling = True
-        rank0_log("Enabling gradient scaling for mixed precision training.")
-    else:
-        enable_grad_scaling = False
-        rank0_log("Gradient scaling not enabled.")
-
+    # TODO: We do not expose the mixed precision attribute. This is low
+    # priority since we do not use fp16.
+    # if model.mixed_precision.param_dtype == torch.float16:
+    #     enable_grad_scaling = True
+    #     rank0_log("Enabling gradient scaling for mixed precision training.")
+    # else:
+    enable_grad_scaling = False
+    rank0_log("Gradient scaling not enabled.")
     return ShardedGradScaler(enabled=enable_grad_scaling)
 
@@ -121,8 +121,8 @@ def main(job_config: JobConfig):
     model_config.vocab_size = tokenizer.n_words
 
     # build model using meta init
-    with meta_model_init():
-        model = model_cls.from_model_args(model_config)
+    # with meta_model_init():
+    model = model_cls.from_model_args(model_config)
 
     # log model size
     model_param_count = get_num_params(model)
@@ -145,9 +145,6 @@ def main(job_config: JobConfig):
         model, world_mesh, parallel_dims, job_config
     )
 
-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
     # build optimizer after apply parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
@@ -163,9 +160,7 @@ def main(job_config: JobConfig):
             True
        )
         rank0_log(f"Compiling model {model_name} with torch.compile...")
-        model = torch.compile(
-            model,
-        )
+        model = torch.compile(model)
 
     train_state = TrainState()
 
@@ -224,7 +219,12 @@ def main(job_config: JobConfig):
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
-            model.clip_grad_norm_(job_config.training.max_norm)
+            # TODO: Disable `clip_grad_norm_()` until it is supported:
+            # https://github.com/pytorch/pytorch/pull/120238
+            # torch.nn.utils.clip_grad_norm_(
+            #     model.parameters(), job_config.training.max_norm, foreach=True
+            # )
+            # model.clip_grad_norm_(job_config.training.max_norm)
 
             # optimizer step
             # If gradients don't contain infs/NaNs, optimizer.step() is then called;
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 2f6f3051e..295943fdc 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -4,7 +4,7 @@ dump_folder = "./outputs"
 description = "debug training"
 
 [profiling]
-run_profiler = true
+run_profiler = false
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
 profile_every_x_iter = 10
diff --git a/train_configs/llama_7b.toml b/train_configs/llama_7b.toml
index 2b8b5015f..85d5098dc 100644
--- a/train_configs/llama_7b.toml
+++ b/train_configs/llama_7b.toml
@@ -7,7 +7,7 @@ description = "llama 7b training"
 run_profiler = true
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
-profile_every_x_iter = 100
+profile_every_x_iter = 10
 
 [metrics]
 enable_tensorboard = true
@@ -29,12 +29,13 @@ batch_size = 8
 seq_len = 2048
 warmup_steps = 200 # lr scheduler warm up
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 10
 # only dp would be sufficient for 7B
 data_parallel_degree = -1
 sequence_parallel_degree = 1
 pipeline_parallel_degree = 1
 compile = false
+enable_selective_ac = false
 checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
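Note: below is a minimal, self-contained sketch of the per-parameter `fully_shard` pattern this diff switches to, applied to a toy module rather than the llama Transformer. `ToyBlock`, `ToyModel`, `apply_fsdp`, and the sizes are hypothetical stand-ins, and the sketch assumes a process group launched via torchrun; it uses only APIs that appear in the diff (`fully_shard`, `MixedPrecisionPolicy`) plus `torch.distributed.device_mesh.init_device_mesh`.

import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a llama TransformerBlock."""

    def __init__(self, dim: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.wo(torch.relu(self.wq(x)))


class ToyModel(nn.Module):
    """Hypothetical stand-in for the llama Transformer (layers + final norm)."""

    def __init__(self, dim: int = 256, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


def apply_fsdp(model: ToyModel, dp_degree: int) -> ToyModel:
    # 1D data-parallel mesh; in the diff this is world_mesh["dp"].
    dp_mesh = init_device_mesh("cuda", (dp_degree,), mesh_dim_names=("dp",))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # compute in bf16
        reduce_dtype=torch.float32,   # reduce-scatter gradients in fp32
    )
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    for layer_id, block in enumerate(model.layers):
        # Same optimization as the diff: keep the last block unsharded after
        # forward, since backward would re-gather it immediately anyway.
        reshard_after_forward = layer_id < len(model.layers) - 1
        fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
    # Root call groups the remaining parameters (here, the final norm).
    fully_shard(model, **fsdp_config)
    return model


if __name__ == "__main__":
    # e.g. torchrun --nproc_per_node=2 fsdp2_sketch.py
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = apply_fsdp(ToyModel().cuda(), dp_degree=int(os.environ["WORLD_SIZE"]))
    out = model(torch.randn(4, 16, 256, device="cuda"))
    out.sum().backward()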