diff --git a/README.md b/README.md index d434325ad9..3c69cea729 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,7 @@ Every model is written from scratch to maximize performance and remove layers of | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) | | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | diff --git a/litgpt/config.py b/litgpt/config.py index 70bd6079a7..cae2f5f388 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2460,6 +2460,147 @@ def norm_class(self) -> Type: configs.extend(qwq) +qwen_3 = [ + # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json + dict( + name="Qwen3-0.6B{}", + hf_config=dict(org="Qwen", name="Qwen3-0.6B{}"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=16, + n_embd=1024, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=3072, + norm_eps=1e-6, + rope_base=1000000, + head_size=128, + norm_qk=True, + ), + # https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json + dict( + name="Qwen3-1.7B{}", + hf_config=dict(org="Qwen", name="Qwen3-1.7B{}"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=16, + n_embd=2048, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=6144, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), + # https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json + dict( + name="Qwen3-4B{}", + hf_config=dict(org="Qwen", name="Qwen3-4B{}"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=36, + n_head=32, + n_embd=2560, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=9728, + norm_eps=1e-6, + rope_base=1000000, + head_size=128, + norm_qk=True, + ), + # https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json + dict( + name="Qwen3-8B{}", + hf_config=dict(org="Qwen", name="Qwen3-8B{}"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=36, + n_head=32, + n_embd=4096, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=12288, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), + # https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json + dict( + name="Qwen3-14B{}", + hf_config=dict(org="Qwen", name="Qwen3-14B{}"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=40, + n_head=40, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=17408, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), +] +for c in qwen_3: + for kind in ("", "-Base"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) +qwen_3_32b = [ + # https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json + dict( + name="Qwen3-32B", + hf_config=dict(org="Qwen", name="Qwen3-32B"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=64, + n_head=64, + n_embd=5120, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=25600, + norm_eps=1e-6, + rope_base=1000000, + head_size=128, + norm_qk=True, + ), +] +configs.extend(qwen_3_32b) + ############# # Salamandra diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 7297c6e468..67b5e0df86 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -345,7 +345,7 @@ def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) class ChatML(PromptStyle): - def __init__(self, system_message: str): + def __init__(self, system_message: Optional[str] = None): self.system_message = system_message def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str: @@ -372,6 +372,11 @@ def __init__(self): ) +class Qwen3(ChatML): + def __init__(self): + super().__init__() + + class SmolLM2(ChatML): def __init__(self): super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face") @@ -411,6 +416,7 @@ def __init__(self): "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, + "qwen3": Qwen3, "smollm2": SmolLM2, "salamandra": Salamandra, } @@ -463,6 +469,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() + if re.search(r"Qwen3-.*", model_name): + return Qwen3() if re.search(r"SmolLM2.*-Instruct", model_name): return SmolLM2() if re.search(r"salamandra-.*-instruct", model_name): diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 341e1a757d..47ebc41edf 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -533,6 +533,75 @@ def copy_weights_qwen_2_5( pbar.update(progress_per_file) +def copy_weights_qwen_3( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, + pbar: Optional[tqdm] = None, + progress_per_file: Optional[float] = None, + debug_mode: Optional[bool] = False, +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight", + "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight", + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", + "model.norm.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + + if progress_per_file is not None: + progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) + + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + if progress_per_file is not None: + pbar.update(progress_per_file) + + if "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # qkv is split across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + + if progress_per_file is not None: + pbar.update(progress_per_file) + + def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -624,6 +693,10 @@ def convert_hf_checkpoint( # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights) + elif model_name.lower().startswith("qwen3"): + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_qwen_3, config, qkv_weights) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index 01c4ca7785..232070b1fa 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -393,6 +393,56 @@ def copy_weights_qwen_2_5( state_dict[to_name] = param +def copy_weights_qwen_3( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight", + "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight", + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", + "transformer.ln_f.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", + } + + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: + continue + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + else: + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] @@ -437,6 +487,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: copy_fn = partial(copy_weights_phi, config) elif config.name.lower().startswith(("qwen2.5", "qwq")): copy_fn = partial(copy_weights_qwen_2_5, config) + elif config.name.lower().startswith("qwen3"): + copy_fn = partial(copy_weights_qwen_3, config) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) diff --git a/tests/test_model.py b/tests/test_model.py index 39d946fb2d..c48dbd9e83 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -30,6 +30,7 @@ from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM +from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM import litgpt.config as config_module from litgpt import GPT, Config @@ -42,6 +43,7 @@ copy_weights_hf_llama, copy_weights_phi, copy_weights_qwen_2_5, + copy_weights_qwen_3, ) from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from litgpt.utils import _RunIf @@ -1008,6 +1010,65 @@ def test_against_original_qwen_2_5(model_name, device, dtype): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["Qwen3-0.6B", "Qwen3-8B", "Qwen3-4B-Base", "Qwen3-14B-Base", "Qwen3-32B"]) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + _RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_qwen_3(model_name, device, dtype): + torch.set_default_dtype(dtype) + + T = 20 + ours_config = Config.from_name( + model_name, + block_size=T, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + ) + theirs_config = Qwen3Config( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=ours_config.block_size, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + tie_word_embeddings=False, + ) + + theirs_model = Qwen3ForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) @pytest.mark.parametrize( diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index a94f46d710..09e1d0d2e0 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -47,7 +47,7 @@ def test_tokenizer_against_hf(config, tmp_path): else: assert ours.vocab_size == config.vocab_size - if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ")): + if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ", "Qwen3")): # even though their config defines it, it's set as None in HF assert isinstance(ours.bos_id, int) assert theirs.bos_token_id is None diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 0a41110be4..0dfaee49bf 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -44,6 +44,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) | | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) | | R1 Distll Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |