litgpt/config.py (139 additions, 0 deletions)
@@ -2449,6 +2449,145 @@ def norm_class(self) -> Type:

configs.extend(qwq)

#############
# Qwen3
#############
qwen3 = [
dict(
name="Qwen3-0.6B",
hf_config=dict(org="Qwen", name="Qwen3-0.6B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=16,
n_embd=1024,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=3072,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
dict(
name="Qwen3-1.7B",
hf_config=dict(org="Qwen", name="Qwen3-1.7B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=28,
n_head=16,
n_embd=2048,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=6144,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
dict(
name="Qwen3-4B",
hf_config=dict(org="Qwen", name="Qwen3-4B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=32,
n_embd=2560,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=9728,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
dict(
name="Qwen3-8B",
hf_config=dict(org="Qwen", name="Qwen3-8B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=36,
n_head=32,
n_embd=4096,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=12288,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
dict(
name="Qwen3-14B",
hf_config=dict(org="Qwen", name="Qwen3-14B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=40,
n_head=40,
n_embd=5120,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=17408,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
dict(
name="Qwen3-32B",
hf_config=dict(org="Qwen", name="Qwen3-32B"),
block_size=40960,
vocab_size=151643,
padded_vocab_size=151936,
n_layer=64,
n_head=64,
n_embd=5120,
n_query_groups=8,
head_size=128,
parallel_residual=False,
rotary_percentage=1.0,
bias=False,
attn_bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=25600,
norm_eps=1e-6,
rope_base=1000000,
norm_qk=True,
),
]

configs.extend(qwen3)

#############
# Salamandra
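The six new entries register the dense Qwen3 checkpoints from 0.6B to 32B; all of them share head_size=128, 8 key/value groups, and the norm_qk flag that enables query/key normalization. As a quick sanity check, here is a minimal sketch of looking one of them up, assuming the existing litgpt Config.from_name entry point (the same one used in the new test below); the printed values come from the dict added in this diff:

```python
# Minimal sketch: look up one of the newly registered Qwen3 configs by name.
# Config.from_name is the existing litgpt API; the attribute values printed
# here are the ones added for Qwen3-0.6B in this diff.
from litgpt import Config

cfg = Config.from_name("Qwen3-0.6B")
print(cfg.n_layer, cfg.n_head, cfg.n_query_groups, cfg.head_size)  # 28 16 8 128
print(cfg.norm_qk)  # True: Qwen3's q_norm/k_norm weights map to attn.norm_q/norm_k
```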
litgpt/scripts/convert_hf_checkpoint.py (69 additions, 0 deletions)
@@ -533,6 +533,75 @@ def copy_weights_qwen_2_5(
pbar.update(progress_per_file)


def copy_weights_qwen_3(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}

if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if progress_per_file is not None:
pbar.update(progress_per_file)

if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
pbar.update(progress_per_file)


def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
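copy_weights_qwen_3 mirrors the Qwen2.5 converter but additionally maps the per-head q_norm/k_norm weights; the separate q_proj/k_proj/v_proj matrices are buffered until all three have been seen (they can be split across checkpoint shards) and are then fused into a single qkv tensor. An illustrative sketch of the fusion shapes, not part of the diff, using the Qwen3-0.6B values from config.py:

```python
# Illustrative only: the shapes produced by the q/k/v fusion above, assuming
# Qwen3-0.6B-sized projections (n_embd=1024, n_head=16, n_query_groups=8,
# head_size=128). The converter concatenates the loaded HF tensors the same way.
import torch

n_embd, n_head, n_query_groups, head_size = 1024, 16, 8, 128
q = torch.zeros(n_head * head_size, n_embd)          # q_proj: (2048, 1024)
k = torch.zeros(n_query_groups * head_size, n_embd)  # k_proj: (1024, 1024)
v = torch.zeros(n_query_groups * head_size, n_embd)  # v_proj: (1024, 1024)

qkv = torch.cat((q, k, v))  # (4096, 1024) -> transformer.h.{i}.attn.qkv.weight
print(qkv.shape)  # torch.Size([4096, 1024])
```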
tests/test_model.py (62 additions, 0 deletions)
@@ -30,6 +30,7 @@
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM

import litgpt.config as config_module
from litgpt import GPT, Config
@@ -42,6 +43,7 @@
copy_weights_hf_llama,
copy_weights_phi,
copy_weights_qwen_2_5,
copy_weights_qwen_3,
)
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
from litgpt.utils import _RunIf
@@ -1008,6 +1010,66 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["Qwen3-0.6B", "Qwen3-1.7B", "Qwen3-4B", "Qwen3-8B", "Qwen3-14B", "Qwen3-32B"])
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
_RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_qwen_3(model_name, device, dtype):
torch.set_default_dtype(dtype)

T = 20
ours_config = Config.from_name(
model_name,
block_size=T,
n_layer=2,
n_head=16,
n_embd=32,
intermediate_size=86,
)
theirs_config = Qwen3Config(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=ours_config.block_size,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.attn_bias,
)

theirs_model = Qwen3ForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()

state_dict = {}
copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
@pytest.mark.parametrize(
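The new parity test follows the existing Qwen2.5 test: it builds a tiny two-layer variant of each Qwen3 config, converts the Hugging Face weights with copy_weights_qwen_3, and asserts that the logits match. To run just this test locally, a sketch assuming pytest is installed in the environment:

```python
# Hypothetical invocation: run only the new Qwen3 parity test from the repo root.
# Equivalent to `pytest tests/test_model.py -k test_against_original_qwen_3 -q`
# on the command line; requires the transformers Qwen3 classes imported above.
import pytest

pytest.main(["tests/test_model.py", "-k", "test_against_original_qwen_3", "-q"])
```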