Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
| QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388) |
| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
Expand Down
141 changes: 141 additions & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,147 @@ def norm_class(self) -> Type:

configs.extend(qwq)

# Qwen3 dense models. All sizes share the same architecture knobs except for
# depth/width; every size uses a fixed head dimension of 128 and QK-RMSNorm
# (`norm_qk=True`). `head_size=128` is spelled out in every entry — for some
# sizes (1.7B, 8B, 14B) it equals the `n_embd // n_head` default, but keeping
# it explicit makes the configs uniform and robust to future width changes.
qwen_3 = [
    # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
    dict(
        name="Qwen3-0.6B{}",
        hf_config=dict(org="Qwen", name="Qwen3-0.6B{}"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=28,
        n_head=16,
        n_embd=1024,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=3072,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # required: n_embd // n_head would give 64
        norm_qk=True,
    ),
    # https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json
    dict(
        name="Qwen3-1.7B{}",
        hf_config=dict(org="Qwen", name="Qwen3-1.7B{}"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=28,
        n_head=16,
        n_embd=2048,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=6144,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # equals n_embd // n_head; explicit for consistency
        norm_qk=True,
    ),
    # https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json
    dict(
        name="Qwen3-4B{}",
        hf_config=dict(org="Qwen", name="Qwen3-4B{}"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=36,
        n_head=32,
        n_embd=2560,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=9728,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # required: n_embd // n_head would give 80
        norm_qk=True,
    ),
    # https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json
    dict(
        name="Qwen3-8B{}",
        hf_config=dict(org="Qwen", name="Qwen3-8B{}"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=36,
        n_head=32,
        n_embd=4096,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=12288,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # equals n_embd // n_head; explicit for consistency
        norm_qk=True,
    ),
    # https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json
    dict(
        name="Qwen3-14B{}",
        hf_config=dict(org="Qwen", name="Qwen3-14B{}"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=17408,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # equals n_embd // n_head; explicit for consistency
        norm_qk=True,
    ),
]
# Each size above ships both as a post-trained model ("Qwen3-4B") and a
# pretrained base model ("Qwen3-4B-Base"); register a config for each.
for c in qwen_3:
    for kind in ("", "-Base"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)
# Qwen3-32B is listed separately: no "-Base" variant was released for it,
# so it must not go through the suffix-expansion loop above.
qwen_3_32b = [
    # https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json
    dict(
        name="Qwen3-32B",
        hf_config=dict(org="Qwen", name="Qwen3-32B"),
        block_size=40960,
        vocab_size=151643,
        padded_vocab_size=151936,
        n_layer=64,
        n_head=64,
        n_embd=5120,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=25600,
        norm_eps=1e-6,
        rope_base=1000000,
        head_size=128,  # required: n_embd // n_head would give 80
        norm_qk=True,
    ),
]
configs.extend(qwen_3_32b)


#############
# Salamandra
Expand Down
10 changes: 9 additions & 1 deletion litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str)


class ChatML(PromptStyle):
def __init__(self, system_message: str):
def __init__(self, system_message: Optional[str] = None):
self.system_message = system_message

def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
Expand All @@ -372,6 +372,11 @@ def __init__(self):
)


class Qwen3(ChatML):
    # Qwen3 uses the plain ChatML template with no default system message
    # (super().__init__() is called without arguments, so system_message
    # stays None unless the caller supplies one via ``apply``).
    def __init__(self):
        super().__init__()

class SmolLM2(ChatML):
def __init__(self):
super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face")
Expand Down Expand Up @@ -411,6 +416,7 @@ def __init__(self):
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"qwen3": Qwen3,
"smollm2": SmolLM2,
"salamandra": Salamandra,
}
Expand Down Expand Up @@ -463,6 +469,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Qwen2_5()
if re.search(r"QwQ-.*", model_name):
return QwQ()
if re.search(r"Qwen3-.*", model_name):
return Qwen3()
if re.search(r"SmolLM2.*-Instruct", model_name):
return SmolLM2()
if re.search(r"salamandra-.*-instruct", model_name):
Expand Down
73 changes: 73 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,75 @@ def copy_weights_qwen_2_5(
pbar.update(progress_per_file)


def copy_weights_qwen_3(
    config: Config,
    qkv_weights: Dict[int, Dict[str, Dict[str, Union[torch.Tensor, NotYetLoadedTensor]]]],
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
    pbar: Optional[tqdm] = None,
    progress_per_file: Optional[float] = None,
    debug_mode: Optional[bool] = False,
) -> None:
    """Copy one shard of a Hugging Face Qwen3 checkpoint into ``state_dict`` under litgpt names.

    May be called once per checkpoint shard; ``qkv_weights`` is caller-owned
    state that accumulates the separate q/k/v projection tensors across calls
    (layer index -> "weight" -> {"q_proj"|"k_proj"|"v_proj": tensor}) so they
    can be concatenated into litgpt's fused ``attn.qkv`` parameter once all
    three pieces have been seen.

    Args:
        config: Model config (unused here beyond signature parity with the
            other ``copy_weights_*`` converters).
        qkv_weights: Cross-call accumulator for split q/k/v tensors; entries
            are deleted once fused.
        state_dict: Destination litgpt state dict, mutated in place.
        hf_weights: Tensors of the current shard, keyed by HF parameter name.
        saver: Optional incremental saver; when given, tensors are routed
            through ``saver.store_early``.
        dtype: Optional dtype to cast parameters to while loading.
        pbar/progress_per_file: Optional progress reporting; the file's share
            is spread evenly over all tensors handled here. NOTE(review):
            ``pbar`` is assumed non-None whenever ``progress_per_file`` is set.
        debug_mode: Verbose parameter loading.
    """
    # HF name (template) -> litgpt name (template). q/k/v map to None because
    # they are not copied directly; they are fused into attn.qkv below.
    weight_map = {
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "model.layers.{}.self_attn.q_proj.weight": None,
        "model.layers.{}.self_attn.k_proj.weight": None,
        "model.layers.{}.self_attn.v_proj.weight": None,
        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
        "model.norm.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }

    if progress_per_file is not None:
        # Spread this file's progress share across every tensor processed here.
        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

    for from_name, param in hf_weights.items():
        name_template, *ids = layer_template(from_name, num_matches=2)
        to_name = weight_map[name_template]
        param = load_param(param, from_name, dtype, verbose=debug_mode)
        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
            # Stash the piece for later fusion; ids[0] is the layer index.
            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
            weight_name, weight_type = from_name.split(".")[-2:]
            qkv[weight_type][weight_name] = param
        if to_name is None:
            continue
        to_name = to_name.format(*ids)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

        if progress_per_file is not None:
            pbar.update(progress_per_file)

    # Checkpoints without an explicit lm_head reuse the token embedding
    # (weight tying).
    if "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    # Fuse any q/k/v triplets that are now complete. Iterate over copies of
    # the key views because entries are deleted inside the loop.
    for i in list(qkv_weights):
        for weight_type in list(qkv_weights[i]):
            qkv = qkv_weights[i][weight_type]
            if len(qkv) != 3:
                # qkv is split across different .bin files
                continue
            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
            # litgpt fused layout: all Q rows, then all K rows, then all V rows.
            qkv = torch.cat((q, k, v))
            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
            del qkv_weights[i][weight_type]

    if progress_per_file is not None:
        pbar.update(progress_per_file)


def qkv_reassemble(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -624,6 +693,10 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
elif model_name.lower().startswith("qwen3"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
Expand Down
52 changes: 52 additions & 0 deletions litgpt/scripts/convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,56 @@ def copy_weights_qwen_2_5(
state_dict[to_name] = param


def copy_weights_qwen_3(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    untie_weights: bool = False,
    saver: Optional[incremental_save] = None,
) -> None:
    """Copy litgpt Qwen3 weights into ``state_dict`` under Hugging Face names.

    The inverse of the HF->litgpt Qwen3 converter: plain tensors are renamed
    via ``weight_map``, and the fused ``attn.qkv`` matrix is split back into
    separate q/k/v projections.

    Args:
        config: Model config; provides ``n_head``, ``n_query_groups`` and
            ``head_size`` for the qkv split.
        state_dict: Destination HF-style state dict, mutated in place.
        lit_weights: Source litgpt tensors, keyed by litgpt parameter name.
        untie_weights: When True, skip ``lm_head.weight`` (it is tied to the
            token embedding and must not be exported separately).
        saver: Optional incremental saver; tensors are routed through
            ``saver.store_early`` when given.
    """
    weight_map = {
        "transformer.wte.weight": "model.embed_tokens.weight",
        "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
        "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
        "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
        "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
        "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
        "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
        "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
        "transformer.ln_f.weight": "model.norm.weight",
        "lm_head.weight": "lm_head.weight",
    }

    for from_name, param in lit_weights.items():
        if from_name == "lm_head.weight" and untie_weights:
            continue
        name_template, *ids = layer_template(from_name, num_matches=2)
        param = load_param(param, from_name, None)
        if from_name.endswith(".attn.qkv.weight"):
            # Always "weight" here: the endswith() guard above only matches
            # weights, and Qwen3 attention carries no bias (bias=False).
            weight_type = from_name.split(".")[-1]  # weight or bias
            to_names = (
                "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
                "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
                "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
            )
            # Undo the fused [Q; K; V] concatenation: Q has n_head rows of
            # head_size, K and V have n_query_groups rows each (GQA).
            params = param.split(
                (
                    config.n_head * config.head_size,
                    config.n_query_groups * config.head_size,
                    config.n_query_groups * config.head_size,
                )
            )
        else:
            to_names = (weight_map[name_template].format(*ids),)
            params = (param,)

        for to_name, param in zip(to_names, params):
            if saver is not None:
                param = saver.store_early(param)
            state_dict[to_name] = param


def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
Expand Down Expand Up @@ -437,6 +487,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
copy_fn = partial(copy_weights_phi, config)
elif config.name.lower().startswith(("qwen2.5", "qwq")):
copy_fn = partial(copy_weights_qwen_2_5, config)
elif config.name.lower().startswith("qwen3"):
copy_fn = partial(copy_weights_qwen_3, config)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
untie_weights = "Gemma" in config.name
copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
Expand Down
Loading