diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 1b25aa3f54..e2e643db6b 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -19,10 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint import (
-    HuggingFaceStorageReader,
-    HuggingFaceStorageWriter,
-)
+from torch.distributed.checkpoint import HuggingFaceStorageWriter
 from torch.distributed.checkpoint._consolidate_hf_safetensors import (
     consolidate_safetensors_files_on_every_rank,
 )
@@ -249,6 +246,9 @@ def load_state_dict(state_dict):
         self.initial_load_model_only = checkpoint_config.initial_load_model_only
         self.initial_load_in_hf = checkpoint_config.initial_load_in_hf
         self.initial_load_path = checkpoint_config.initial_load_path
+        self.initial_load_in_hf_quantized = (
+            checkpoint_config.initial_load_in_hf_quantized
+        )
         self.last_save_model_only = checkpoint_config.last_save_model_only
         self.last_save_in_hf = checkpoint_config.last_save_in_hf
         if self.last_save_in_hf:
@@ -418,6 +418,7 @@ def dcp_load(
         state_dict: dict[str, Any],
         checkpoint_id: str,
         from_hf: bool,
+        from_quantized: bool,
     ) -> None:
         """Load the checkpoint with dcp.
         Args:
@@ -432,10 +433,13 @@
                 self.sd_adapter is not None
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
             hf_state_dict = self.sd_adapter.to_hf(state_dict)
+            hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
+                checkpoint_id, from_quantized
+            )

             dcp.load(
                 hf_state_dict,
-                storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
+                storage_reader=hf_storage_reader,
             )

             state_dict = self.sd_adapter.from_hf(hf_state_dict)
@@ -544,13 +548,21 @@ def load(self, step: int = -1) -> bool:

         model_only = False
         from_hf = False
+        from_quantized = False
         if not os.path.exists(self.folder):
             model_only = self.initial_load_model_only
             from_hf = self.initial_load_in_hf
+            from_quantized = self.initial_load_in_hf_quantized
             if from_hf:
                 assert (
                     model_only
                 ), "Only model can be loaded when loading from HF's safetensors checkpoint."
+
+            if from_quantized:
+                assert (
+                    from_hf
+                ), "Quantized checkpoint can only be loaded from HuggingFace format."
+
             if self.initial_load_path:
                 checkpoint_id = self.initial_load_path
                 if not os.path.isdir(checkpoint_id):
@@ -602,6 +614,7 @@
             states,
             checkpoint_id=checkpoint_id,
             from_hf=from_hf,
+            from_quantized=from_quantized,
         )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
@@ -679,6 +692,7 @@ def _ft_load(self) -> None:
             checkpoint_id=checkpoint_id,
             # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
             from_hf=False,
+            from_quantized=False,
         )
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index eb477941ca..2f200ef181 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -453,6 +453,14 @@ class Checkpoint:
     non-tensors. The default value is False.
     """

+    initial_load_in_hf_quantized: bool = False
+    """
+    Enable loading a HuggingFace safetensors checkpoint whose state dict contains quantized weights.
+    This option is only used when `initial_load_path` and `initial_load_in_hf` are specified. It loads
+    checkpoints in HF's model definition and dequantizes the model weights if necessary.
+    To support this option, the model needs to define a proper HuggingFaceStorageReader to perform the dequantization.
+    """
+
     last_save_model_only: bool = True
     """
     When last_save_model_only=True, only the model will be saved at the end of training,
diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py
deleted file mode 100644
index a8ac6003a2..0000000000
--- a/torchtitan/models/deepseek_v3/model/quantization.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from torchtitan.tools.logging import logger
-
-# Fixed block size of 128x128 as specified in the algorithm
-BLOCK_SIZE = 128
-
-
-def calculate_scale_shape(
-    weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE
-) -> torch.Size:
-    # Calculate the scale tensor shape
-    orig_shape = weight.shape
-
-    # Calculate number of blocks needed
-    block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE
-    block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    # Verify scale_inv shape matches expected block dimensions
-    expected_scale_shape = torch.Size((block_rows, block_cols))
-
-    return expected_scale_shape
-
-
-def dequantize_from_fp8(
-    weight: torch.Tensor,
-    scale_inv: torch.Tensor,
-    dtype=torch.bfloat16,
-    BLOCK_SIZE: int = BLOCK_SIZE,
-) -> torch.Tensor:
-    # Convert to float32 for computation
-    float_weight = weight.to(torch.float32)
-    # Get original dimensions
-    orig_shape = weight.shape
-
-    # Verify scale_inv shape matches expected block dimensions
-    expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE)
-    block_rows, block_cols = expected_scale_shape
-    if scale_inv.shape != expected_scale_shape:
-        logger.warning(
-            f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}"
-        )
-
-    # NOTE: When processing large models on-the-fly, misalignment between block boundaries
-    # and DTensor local shape partitioning can lead to silent numerical inaccuracies.
- dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index b366910f16..11d54ffb58 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -8,13 +8,14 @@ import re from typing import Any +import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torch.distributed.tensor import DTensor from torchtitan.models.utils import MoEStateDictAdapter from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 - class DeepSeekV3StateDictAdapter(MoEStateDictAdapter): """ @@ -70,60 +71,33 @@ def __init__( } ) - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: """ - Dequantize the weights from float8 to float32. + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. 
+ # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. - 2. Split the GroupedExperts' weight into separate expert's wegiht. + 2. Split the GroupedExperts' weight into separate expert's weight. """ to_hf_map = {v: k for k, v in self.from_hf_map.items()} @@ -172,24 +146,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict - ) - return hf_state_dict_with_scale_inv + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. 2. Convert between the HF shape and the torchtitan shape. - 3. Concate separate expert's wegiht into GroupedExperts' weight. + 3. Concat separate expert's weight into GroupedExperts' weight. """ - # dequantize the tensor in state_dict and remove the scale_inv tensor - - hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): @@ -215,7 +181,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: layer_num, value.device_mesh, ) - else: # keep this path to be compatibile with offline conversion + else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index c6dee8170d..9d8625a28a 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [quantize.linear.float8] diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9bbf..e22692bd52 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. 
 import json
-import logging
 import os
 import re
 from abc import ABC, abstractmethod
 from typing import Any

-logger = logging.getLogger()
+from torch.distributed.checkpoint import HuggingFaceStorageReader
+
+from torchtitan.tools.logging import logger

 from .model import BaseModelArgs

@@ -58,6 +59,21 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         pass

+    @abstractmethod
+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
+        """Return the HuggingFace storage reader used to read an HF checkpoint.
+
+        Args:
+            path: the path of the HF checkpoint to read from.
+            from_quantized: whether the checkpoint stores quantized weights to dequantize while loading.
+
+        Returns:
+            The HuggingFace storage reader used to load the checkpoint.
+        """
+        pass
+

 class StateDictAdapter(BaseStateDictAdapter):
     """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
@@ -86,3 +102,12 @@ def __init__(
                     self.fqn_to_index_mapping[hf_key] = int(indx)
                 else:
                     self.fqn_to_index_mapping = None
+
+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
+        if from_quantized:
+            logger.warning(
+                "Loading from quantized checkpoint format is not supported for this model."
+            )
+        return HuggingFaceStorageReader(path)
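For reference, the quantized load path is driven entirely from the job config. A minimal [checkpoint] section that exercises it might look like the sketch below; the checkpoint path is a placeholder, `enable` is assumed from the existing Checkpoint config, and the remaining keys are the fields referenced in this patch.

[checkpoint]
enable = true
initial_load_path = "/path/to/DeepSeek-V3-safetensors"  # placeholder: directory containing the HF safetensors files
initial_load_model_only = true
initial_load_in_hf = true
initial_load_in_hf_quantized = true  # dequantize FP8 block-quantized weights via the model's HF storage reader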