diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e9e7014425..0855e3856f 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -431,10 +431,16 @@ def dcp_load( self.sd_adapter is not None ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) + hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id) + begin_load = time.monotonic() + logger.info(f"Starting dcp.load with {hf_storage_reader}") dcp.load( hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + storage_reader=hf_storage_reader, + ) + logger.info( + f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds" ) state_dict = self.sd_adapter.from_hf(hf_state_dict) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 227c2ca211..b36aec855c 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -103,7 +103,8 @@ def _custom_policy(ctx, func, *args, **kwargs): mm_count_key = f"{mode}_mm_count" if func == torch.ops.aten.mm.default: if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE + # return CheckpointPolicy.PREFER_RECOMPUTE + return CheckpointPolicy.MUST_SAVE # TODO(jianiw): testing meta[mm_count_key] += 1 # Saves output of all compute ops, except every second mm to_save = func in op_sac_save_list and not ( diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 12512bfac0..4761f331d8 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -272,6 +272,7 @@ def wrapper( num_ep_ranks, padded_max_len, TOKEN_GROUP_ALIGN_SIZE_M, + use_cpu=True ) x = torch.vstack((x, x.new_zeros((x.shape[-1])))) diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py index 760cc662be..600d9b511e 100644 --- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -15,6 +15,8 @@ import re from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import Qwen3ModelArgs @@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None): "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_hf_map = {v: k for k, v in self.from_hf_map.items()} diff --git a/torchtitan/models/deepseek-v3/model/model.py b/torchtitan/models/deepseek-v3/model/model.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a290ea7e66..b13490e0a3 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -78,9 +78,9 @@ attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( - vocab_size=102400, + vocab_size=163840, dim=2048, - inter_dim=10944, + inter_dim=11264, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, @@ -92,6 +92,7 @@ score_func="softmax", route_norm=False, score_before_experts=False, + 
use_grouped_mm=False, ), q_lora_rank=0, kv_lora_rank=512, @@ -99,8 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -155,6 +156,7 @@ v_head_dim=128, use_flex_attn=True, attn_mask_type="block_causal", + hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index d6afedfa34..b27b7a9d50 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_slow: int = 1 mscale: float = 1.0 + # HF checkpoint args + hf_weight_quantized: bool = False + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len if seq_len > self.max_seq_len: diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dc612fafb7..14ea53fe5e 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -11,10 +11,11 @@ from torch import nn from torchtitan.models.attention import build_attention -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import FeedForward, MoE, create_tensor_hook from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs +from torchtitan.tools.logging import logger # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -284,6 +285,52 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id + # Register backward hook to monitor gradients + # self.register_full_backward_hook(self._layer_gradient_hook) + # logger.info(f"[HOOK REGISTRATION] Layer {self.layer_id} TransformerBlock gradient hook registered") + + def _layer_gradient_hook(self, module, grad_input, grad_output): + """Backward hook to monitor gradients of all parameters in this layer.""" + logger.info(f"[LAYER GRAD HOOK] Layer {self.layer_id} TransformerBlock backward pass") + + total_params = 0 + + # Check gradients for all named parameters in this layer + for name, param in self.named_parameters(): + total_params += 1 + if param.grad is not None: + if param.grad.dtype.is_floating_point or param.grad.dtype.is_complex: + # Check for NaN and Inf elements first + has_nan = torch.isnan(param.grad).any().item() + has_inf = torch.isinf(param.grad).any().item() + + if has_nan: + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm=NaN, max=NaN, mean=NaN, shape={param.grad.shape}") + elif has_inf: + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm=Inf, max=Inf, mean=Inf, shape={param.grad.shape}") + else: + # Calculate gradient statistics safely + try: + grad_norm = param.grad.norm().item() + grad_max = param.grad.abs().max().item() + grad_mean = param.grad.mean().item() + + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"norm={grad_norm:.6e}, max={grad_max:.6e}, mean={grad_mean:.6e}, " + f"shape={param.grad.shape}") + except: + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"failed to compute stats, shape={param.grad.shape}") + else: + logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: " + f"dtype={param.grad.dtype} (non-float grad)") + else: + 
logger.info(f"[PARAM GRAD] Layer {self.layer_id} - {name}: grad is None") + + logger.info(f"[LAYER GRAD SUMMARY] Layer {self.layer_id}: total_params={total_params}") + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): """ Forward pass for the Transformer block. @@ -295,11 +342,22 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis) + t1 = self.attention_norm(x) + # t1.register_hook(create_tensor_hook("t_after_attention_norm")) + x = x + self.attention(t1, freqs_cis) + # x.register_hook(create_tensor_hook("t_after_attn")) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + t = self.ffn_norm(x) + # t.register_hook(create_tensor_hook("t_after_ffn_norm")) + x = x + self.moe(t) + # x.register_hook(create_tensor_hook("x_after_moe")) + else: - x = x + self.feed_forward(self.ffn_norm(x)) + t = self.ffn_norm(x) + # t.register_hook(create_tensor_hook("t_after_ffn_norm")) + x = x + self.feed_forward(t) + # x.register_hook(create_tensor_hook("x_after_feedforward")) + return x def init_weights(self, buffer_device: torch.device): diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py deleted file mode 100644 index a8ac6003a2..0000000000 --- a/torchtitan/models/deepseek_v3/model/quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtitan.tools.logging import logger - -# Fixed block size of 128x128 as specified in the algorithm -BLOCK_SIZE = 128 - - -def calculate_scale_shape( - weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE -) -> torch.Size: - # Calculate the scale tensor shape - orig_shape = weight.shape - - # Calculate number of blocks needed - block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE - block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = torch.Size((block_rows, block_cols)) - - return expected_scale_shape - - -def dequantize_from_fp8( - weight: torch.Tensor, - scale_inv: torch.Tensor, - dtype=torch.bfloat16, - BLOCK_SIZE: int = BLOCK_SIZE, -) -> torch.Tensor: - # Convert to float32 for computation - float_weight = weight.to(torch.float32) - # Get original dimensions - orig_shape = weight.shape - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) - block_rows, block_cols = expected_scale_shape - if scale_inv.shape != expected_scale_shape: - logger.warning( - f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" - ) - - # NOTE: When processing large models on-the-fly, misalignment between block boundaries - # and DTensor local shape partitioning can lead to silent numerical inaccuracies. 
- dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index e947d70695..d22faadf1d 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -5,17 +5,20 @@ # LICENSE file in the root directory of this source tree. +import logging import re +import time from typing import Any import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.tools.logging import logger from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 class DeepSeekV3StateDictAdapter(StateDictAdapter): @@ -78,6 +81,24 @@ def __init__( self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + if self.model_args.hf_weight_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. 
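+            # It is expected to handle the FP8 block-wise dequantization to float32 that quantization.py used to do manually, using the checkpoint's `*_scale_inv` scale tensors and 128x128 blocks.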
+ # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + def _calculate_strided_shard_shard_indices( self, strided_shard_dim_degree: int, @@ -220,6 +241,10 @@ def _get_local_experts_weights( Returns: Dictionary mapping individual expert keys to their DTensor weights """ + start_time = time.time() + logger.info( + f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}" + ) device_mesh = grouped_expert_weight.device_mesh dtensor_placements = grouped_expert_weight.placements @@ -285,6 +310,11 @@ def _get_local_experts_weights( local_expert_tensors[expert_key] = expert_dtensor + end_time = time.time() + duration = end_time - start_time + logger.info( + f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return local_expert_tensors def _concatenate_expert_weights_dtensor( @@ -312,6 +342,10 @@ def _concatenate_expert_weights_dtensor( Returns: Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None """ + start_time = time.time() + logger.info( + f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}" + ) # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer_num][abstract_key] expected_n_experts = ( @@ -341,6 +375,11 @@ def _concatenate_expert_weights_dtensor( if not expert_weights_by_layer[layer_num]: del expert_weights_by_layer[layer_num] + end_time = time.time() + duration = end_time - start_time + logger.info( + f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return stacked_dtensor def _split_experts_weights( @@ -398,61 +437,14 @@ def _concatenate_expert_weights( return stacked_tensor - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: - """ - Dequantize the weights from float8 to float32. - """ - - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict - def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. 
Convert between the HF shape and the torchtitan shape. 2. Split the GroupedExperts' weight into separate expert's wegiht. """ + start_time = time.time() + logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys") + to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -480,6 +472,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict.update(local_expert_fqn) else: + logger.info( + f"Using the old torch.split for value {new_abstract_key} " + ) # keep this path for offline conversion split_values = self._split_experts_weights( value, self.model_args.moe_args.num_experts @@ -500,11 +495,12 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict + end_time = time.time() + duration = end_time - start_time + logger.info( + f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s" ) - return hf_state_dict_with_scale_inv + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -512,12 +508,12 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 2. Convert between the HF shape and the torchtitan shape. 3. Concate separate expert's wegiht into GroupedExperts' weight. """ + start_time = time.time() + logger.info( + f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys" + ) - # dequantize the tensor in state_dict and remove the scale_inv tensor - - hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): @@ -544,6 +540,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: value.device_mesh, ) else: # keep this path to be compatibile with offline conversion + logger.info( + f"Using the old torch.split for value {titan_abstract_key} " + ) stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, @@ -565,4 +564,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: new_key = self.from_hf_map[key] state_dict[new_key] = value + end_time = time.time() + duration = end_time - start_time + logger.info( + f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s" + ) return state_dict diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index d0cd250583..e9d4f21550 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -35,11 +35,12 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 8 +local_batch_size = 1 seq_len = 4096 max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 3 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) +deterministic = true [parallelism] data_parallel_replicate_degree = 1 @@ -49,23 +50,25 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 8 +expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" -interval = 10 +interval = 100 last_save_model_only = true export_dtype = 
"float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_path = "/data/users/jianiw/model/Moonlight-16B-A3B" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable=false components = ["loss"] # ["model", "loss"] [quantize.dense.float8] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 5e933a4772..4a63f599a5 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -25,7 +25,7 @@ hf_assets_path = "./assets/hf/DeepSeek-V3.1-Base" [optimizer] name = "AdamW" -lr = 2.2e-4 +lr = 2.2e-6 eps = 1e-8 [lr_scheduler] @@ -35,11 +35,10 @@ decay_type = "cosine" min_lr_factor = 0.1 [training] -local_batch_size = 4 +local_batch_size = 1 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10_000 -compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -54,19 +53,21 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_path = "/home/jianiw/tmp/mffuse/deepseek-v3/DeepSeek-V3.1-Base" [activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] +mode = "full" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [quantize.dense.float8] diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2c386ece0d..1475ba2055 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -10,6 +10,7 @@ logger = logging.getLogger() +from torch.distributed.checkpoint import HuggingFaceStorageReader from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs @@ -41,6 +42,9 @@ def __init__( "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + # HuggingFace permutation function (exact copy from their conversion script) def _permute(self, w, n_heads_arg, dim1=None, dim2=None): if dim1 is None: diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 9f519dc04e..cab150cc9b 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools from dataclasses import dataclass from typing import Literal @@ -12,6 +13,77 @@ from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel +from torchtitan.tools.logging import logger + + +def _tensor_gradient_hook_fn(tensor_name: str, grad): + """ + Utility function for tensor gradient hooks that prints gradient statistics. + This function signature matches the format expected by tensor.register_hook(). 
+ + Args: + tensor_name (str): Name identifier for the tensor + grad (torch.Tensor): Gradient tensor + """ + if grad is None: + logger.info(f"[TENSOR GRAD] {tensor_name}: grad is None") + return + + if not (grad.dtype.is_floating_point or grad.dtype.is_complex): + logger.info(f"[TENSOR GRAD] {tensor_name}: dtype={grad.dtype} (non-float grad)") + return + + # Check for NaN and Inf elements + has_nan = torch.isnan(grad).any().item() + has_inf = torch.isinf(grad).any().item() + + if has_nan: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean=NaN, max=NaN, min=NaN, has_nan=True, has_inf={has_inf}, " + f"shape={grad.shape}" + ) + elif has_inf: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean=Inf, max=Inf, min=Inf, has_nan=False, has_inf=True, " + f"shape={grad.shape}" + ) + else: + # Calculate gradient statistics safely + try: + grad_mean = grad.mean().item() + grad_max = grad.max().item() + grad_min = grad.min().item() + + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"mean={grad_mean:.6e}, max={grad_max:.6e}, min={grad_min:.6e}, " + f"has_nan=False, has_inf=False, shape={grad.shape}" + ) + except Exception as e: + logger.info( + f"[TENSOR GRAD] {tensor_name}: " + f"failed to compute stats: {e}, shape={grad.shape}" + ) + + +def create_tensor_hook(tensor_name: str): + """ + Utility function to create a tensor gradient hook using functools.partial. + This follows the pattern shown in the user's example. + + Args: + tensor_name (str): Name identifier for the tensor being hooked + + Returns: + Callable: Hook function that can be registered on a tensor using tensor.register_hook() + + Example usage: + hook_fn = create_tensor_hook("my_tensor") + tensor.register_hook(hook_fn) + """ + return functools.partial(_tensor_gradient_hook_fn, tensor_name) @dataclass @@ -370,14 +442,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) + # Register hook on input tensor + # x.register_hook(create_tensor_hook("moe_input")) + + # Router forward pass - compute expert scores and routing ( top_scores, selected_experts_indices, num_tokens_per_expert, ) = self.router(x, self.expert_bias) + # Register hooks on router outputs for gradient monitoring + # top_scores.register_hook(create_tensor_hook("router_top_scores")) + # tokens_per_expert will be used to update the expert bias for load balancing. 
# and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- @@ -407,21 +484,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # shape (bs*slen*top_k, dim) routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + # routed_input.register_hook(create_tensor_hook("routed_input")) if self.score_before_experts: routed_input = ( routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) + # routed_input.register_hook(create_tensor_hook("routed_input_after_scoring")) # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + # routed_output.register_hook(create_tensor_hook("experts_output")) # shared expert # Note: we execute the shared expert before scoring the output of the routed expert # to "implicitly" overlap the shared expert compute with token combine communication if self.shared_experts is not None: out = self.shared_experts(x) + # out.register_hook(create_tensor_hook("shared_experts_output")) else: out = torch.zeros_like(x) @@ -430,10 +511,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) + # routed_output.register_hook(create_tensor_hook("routed_output_after_scoring")) out = out.scatter_add( dim=0, index=token_indices_experts_sorted, src=routed_output ) + # out.register_hook(create_tensor_hook("moe_final_output")) + out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9bbf..a6fedd0af4 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -11,6 +11,9 @@ from abc import ABC, abstractmethod from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + + logger = logging.getLogger() from .model import BaseModelArgs @@ -58,6 +61,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ pass + @abstractmethod + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + """Returns hf storage reader to read HF checkpoint + + Args: + path: the path to read HF checkpoint + + Returns: + THe HuggingFace storage reader to read rom HF checkpoint + + """ + pass + class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" diff --git a/torchtitan/train.py b/torchtitan/train.py index c343279dda..6685948ce9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -409,6 +409,9 @@ def batch_generator( def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: + + # torch.autograd.set_detect_anomaly(True) + model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -473,7 +476,34 @@ def forward_backward_step( loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred + # with torch.autograd.detect_anomaly(): loss.backward() + + # Check for NaN/Inf gradients after backward pass + print(f"[GRAD CHECK] Checking gradients after backward pass...") + nan_params = [] + inf_params = [] + total_params = 0 + total_with_grad = 0 + + for name, param in model_parts[0].named_parameters(): + total_params += 1 + if param.grad is not None: + total_with_grad += 1 + if torch.isnan(param.grad).any(): + nan_params.append(name) + print(f"[GRAD NaN] {name}: 
shape={param.grad.shape}, norm={param.grad.norm():.6f}")
+                elif torch.isinf(param.grad).any():
+                    inf_params.append(name)
+                    print(f"[GRAD INF] {name}: shape={param.grad.shape}, norm={param.grad.norm():.6f}")
+        print(f"[GRAD SUMMARY] Total params: {total_params}, with grad: {total_with_grad}")
+        print(f"[GRAD SUMMARY] NaN gradients: {len(nan_params)}, Inf gradients: {len(inf_params)}")
+        if nan_params:
+            print(f"[GRAD ERROR] Parameters with NaN gradients: {nan_params[:10]}...")  # Show first 10
+        if inf_params:
+            print(f"[GRAD ERROR] Parameters with Inf gradients: {inf_params[:10]}...")  # Show first 10
         return loss
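For reference, a minimal sketch (not part of the patch) of how the new `get_hf_storage_reader()` adapter hook is consumed when loading a HF checkpoint, mirroring the `dcp_load` change in torchtitan/components/checkpoint.py above. The helper name `load_hf_checkpoint` and its arguments are placeholders for illustration, not existing torchtitan APIs.

import time

import torch.distributed.checkpoint as dcp


def load_hf_checkpoint(sd_adapter, state_dict, checkpoint_id):
    # Convert torchtitan FQNs/shapes to the HF layout expected on disk.
    hf_state_dict = sd_adapter.to_hf(state_dict)

    # The adapter picks the reader: a plain HuggingFaceStorageReader, or a
    # QuantizedHuggingFaceStorageReader when hf_weight_quantized=True
    # (e.g. the DeepSeek-V3 671B FP8 checkpoint).
    hf_storage_reader = sd_adapter.get_hf_storage_reader(checkpoint_id)

    begin_load = time.monotonic()
    dcp.load(hf_state_dict, storage_reader=hf_storage_reader)
    print(f"dcp.load completed in {time.monotonic() - begin_load:.2f} seconds")

    # Map the loaded HF tensors back to torchtitan names/shapes.
    return sd_adapter.from_hf(hf_state_dict)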