Merged
Changes from 7 commits
4 changes: 2 additions & 2 deletions torchtitan/components/checkpoint.py
@@ -20,7 +20,6 @@
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
@@ -432,10 +431,11 @@ def dcp_load(
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id)

dcp.load(
hf_state_dict,
storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
storage_reader=hf_storage_reader,
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
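For context, here is a minimal sketch of the load path this hunk establishes. The wrapper function name is hypothetical; the real logic lives inside `CheckpointManager.dcp_load` as shown above.

```python
import torch.distributed.checkpoint as dcp


def load_hf_checkpoint(sd_adapter, state_dict, checkpoint_id: str):
    # Translate torchtitan FQNs/shapes into the HF layout before reading.
    hf_state_dict = sd_adapter.to_hf(state_dict)

    # The adapter now chooses the storage reader (plain or quantized),
    # instead of checkpoint.py hard-coding HuggingFaceStorageReader.
    hf_storage_reader = sd_adapter.get_hf_storage_reader(checkpoint_id)

    # dcp.load fills hf_state_dict in place from the checkpoint files.
    dcp.load(hf_state_dict, storage_reader=hf_storage_reader)

    # Translate the loaded tensors back into the torchtitan layout.
    return sd_adapter.from_hf(hf_state_dict)
```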
5 changes: 5 additions & 0 deletions torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -15,6 +15,8 @@
import re
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader

from torch.distributed.tensor import DTensor
from torchtitan.models.utils import MoEStateDictAdapter

@@ -50,6 +52,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
"lm_head.weight": "output.weight",
}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
return HuggingFaceStorageReader(path)

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. Convert between the HF shape and the torchtitan shape.
1 change: 1 addition & 0 deletions torchtitan/models/deepseek_v3/__init__.py
Reviewer:
Was this intended to stay here? It looks like a debugging change that's been left in this PR by mistake. The correct number of layers looks like 61 to me from here.

Contributor Author:
My bad, yes you are right; let me fix this configuration. Thanks for pointing it out.

@@ -155,6 +155,7 @@
v_head_dim=128,
use_flex_attn=True,
attn_mask_type="block_causal",
hf_weight_quantized=True,
),
}

3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -87,6 +87,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
beta_slow: int = 1
mscale: float = 1.0

# HF checkpoint args
hf_weight_quantized: bool = False

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
if seq_len > self.max_seq_len:
73 changes: 0 additions & 73 deletions torchtitan/models/deepseek_v3/model/quantization.py

This file was deleted.
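The deleted module held the manual dequantization helpers (`calculate_scale_shape`, `dequantize_from_fp8`) that the removed adapter code below imported; that work is now done while reading by `QuantizedHuggingFaceStorageReader`. For reference, a rough sketch of the block-wise scheme those helpers implemented, assuming 128x128 blocks and a per-block inverse-scale tensor (this is a reconstruction, not the deleted code):

```python
import torch

BLOCK_SIZE = 128  # assumed block size, matching the reader configuration below


def calculate_scale_shape(weight: torch.Tensor, block_size: int = BLOCK_SIZE) -> torch.Size:
    # One inverse scale per (block_size x block_size) tile, rounded up at the edges.
    rows, cols = weight.shape
    return torch.Size(
        ((rows + block_size - 1) // block_size, (cols + block_size - 1) // block_size)
    )


def dequantize_from_fp8(
    weight: torch.Tensor,
    scale_inv: torch.Tensor,
    dtype: torch.dtype = torch.float32,
    block_size: int = BLOCK_SIZE,
) -> torch.Tensor:
    # Upcast the FP8 tensor, then multiply each tile by its inverse scale.
    rows, cols = weight.shape
    out = weight.to(dtype)
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            r0, r1 = i * block_size, min((i + 1) * block_size, rows)
            c0, c1 = j * block_size, min((j + 1) * block_size, cols)
            out[r0:r1, c0:c1] *= scale_inv[i, j]
    return out
```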

85 changes: 23 additions & 62 deletions torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -8,13 +8,14 @@
import re
from typing import Any

import torch
from torch.distributed.checkpoint import HuggingFaceStorageReader

from torch.distributed.tensor import DTensor
from torchtitan.models.utils import MoEStateDictAdapter

from .args import DeepSeekV3ModelArgs

from .quantization import calculate_scale_shape, dequantize_from_fp8


class DeepSeekV3StateDictAdapter(MoEStateDictAdapter):
"""
@@ -70,60 +71,28 @@ def __init__(
}
)

def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
Dequantize the weights from float8 to float32.
"""

scale_inv_keys = []
for key, weight in state_dict.items():
if key.endswith(".weight") and key + "_scale_inv" in state_dict:
scale_inv = state_dict[key + "_scale_inv"]
dequantized_weight = dequantize_from_fp8(
weight, scale_inv, dtype=torch.float32
)
# update the weight and remove the scale_inv tensor
state_dict[key] = dequantized_weight
scale_inv_keys.append(key + "_scale_inv")

for key in scale_inv_keys:
state_dict.pop(key)

return state_dict
def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
if self.model_args.hf_weight_quantized:
from torch.distributed.checkpoint.quantized_hf_storage import (
QuantizedHuggingFaceStorageReader,
)

def _add_quantization_scale_inv_tensors(
self, state_dict: dict[str, Any]
) -> dict[str, Any]:
"""
Add quantization scale tensors the state_dict.
"""
non_quantized_keys = [
"input_layernorm.weight",
"post_attention_layernorm.weight",
"norm.weight",
"lm_head.weight",
"embed_tokens.weight",
"mlp.gate.weight",
]

weight_scale_inv_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".weight") and not any(
non_quantized_key in key for non_quantized_key in non_quantized_keys
):
expected_scale_shape = calculate_scale_shape(value)
# add weight_scale_inv to the state_dict
weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones(
expected_scale_shape, dtype=torch.float32
)

state_dict.update(weight_scale_inv_state_dict)
return state_dict
# NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model.
# If loading checkpoints without quantization, use HuggingFaceStorageReader instead
BLOCK_SIZE = 128
return QuantizedHuggingFaceStorageReader(
path=path,
target_dtype=torch.float32,
block_size=BLOCK_SIZE,
Comment on lines +90 to +91

Contributor:
Should these two be configurable? If not, we can remove these two lines to use the defaults.

Contributor Author:
Do you mean block_size and target_dtype? The PyTorch default values are

thread_count: int = 1,
target_dtype: torch.dtype = torch.float32,
block_size: int = 128,

I explicitly leave block_size here to make the dequantize algorithm less mysterious: the user can easily see that it is block-wise dequantized with block size 128.

thread_count=4,
)
else:
return HuggingFaceStorageReader(path)

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. Convert between the HF shape and the torchtitan shape.
2. Split the GroupedExperts' weight into separate expert's wegiht.
2. Split the GroupedExperts' weight into separate expert's weight.
"""
to_hf_map = {v: k for k, v in self.from_hf_map.items()}

@@ -172,24 +141,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = to_hf_map[key]
hf_state_dict[new_key] = value

# Prepare for dequantization
hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
hf_state_dict
)
return hf_state_dict_with_scale_inv
return hf_state_dict

def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. When loading from HF checkpoint, dequantize the weights from float8 to float32.
2. Convert between the HF shape and the torchtitan shape.
3. Concate separate expert's wegiht into GroupedExperts' weight.
3. Concat separate expert's weight into GroupedExperts' weight.
"""

# dequantize the tensor in state_dict and remove the scale_inv tensor

hf_state_dict = self._dequantize(hf_state_dict)
state_dict = {}

expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}}

for key, value in hf_state_dict.items():
@@ -215,7 +176,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
layer_num,
value.device_mesh,
)
else: # keep this path to be compatibile with offline conversion
else: # keep this path to be compatible with offline conversion
stacked_value = self._concatenate_expert_weights(
expert_weights_by_layer,
titan_abstract_key,
@@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=true
enable = true
components = ["loss"] # ["model", "loss"]

[quantize.linear.float8]
4 changes: 4 additions & 0 deletions torchtitan/models/llama3/model/state_dict_adapter.py
@@ -10,6 +10,7 @@

logger = logging.getLogger()

from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import TransformerModelArgs
@@ -41,6 +42,9 @@ def __init__(
"lm_head.weight": "output.weight",
}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
return HuggingFaceStorageReader(path)

# HuggingFace permutation function (exact copy from their conversion script)
def _permute(self, w, n_heads_arg, dim1=None, dim2=None):
if dim1 is None:
15 changes: 15 additions & 0 deletions torchtitan/protocols/state_dict_adapter.py
@@ -11,6 +11,8 @@
from abc import ABC, abstractmethod
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader

logger = logging.getLogger()

from .model import BaseModelArgs
@@ -58,6 +60,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
pass

@abstractmethod
def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
"""Returns hf storage reader to read HF checkpoint

Args:
path: the path to read HF checkpoint

Returns:
The HuggingFace storage reader to read from HF checkpoint

"""
pass


class StateDictAdapter(BaseStateDictAdapter):
"""State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
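To illustrate the new contract, here is a minimal subclass sketch mirroring the Llama 3 and Qwen3 adapters above; the class name and the elided method bodies are placeholders, not part of this PR.

```python
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader

from torchtitan.protocols.state_dict_adapter import StateDictAdapter


class MyModelStateDictAdapter(StateDictAdapter):
    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
        # Non-quantized checkpoints only need the stock reader; adapters for
        # quantized checkpoints can return QuantizedHuggingFaceStorageReader here.
        return HuggingFaceStorageReader(path)

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        ...  # model-specific key/shape conversion

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        ...  # inverse conversion
```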