From 3984015705d1aae53143bb32dc66f69a49a992e1 Mon Sep 17 00:00:00 2001
From: KaelanDt
Date: Wed, 7 May 2025 11:26:50 +0000
Subject: [PATCH 1/6] add qwen3-32B base code

---
 litgpt/config.py                        | 53 +++++++++++++++++++
 litgpt/model.py                         |  7 +--
 litgpt/scripts/convert_hf_checkpoint.py | 69 +++++++++++++++++++++++++
 tests/test_model.py                     | 67 ++++++++++++++++++++++++
 4 files changed, 193 insertions(+), 3 deletions(-)

diff --git a/litgpt/config.py b/litgpt/config.py
index acbea699b3..956158fe25 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2449,6 +2449,59 @@ def norm_class(self) -> Type:
 configs.extend(qwq)
 
 
+#############
+# Qwen3
+#############
+qwen3 = [
+    dict(
+        name="Qwen3-8B",
+        hf_config=dict(org="Qwen", name="Qwen3-8B"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=36,
+        n_head=32,
+        n_embd=4096,
+        n_query_groups=8,
+        head_size=128,
+        parallel_residual=False,
+        rotary_percentage=1.0,
+        bias=False,
+        attn_bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=12288,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+    ),
+    dict(
+        name="Qwen3-235B-A22B",
+        hf_config=dict(org="Qwen", name="Qwen3-235B-A22B"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=94,
+        n_head=64,
+        n_embd=4096,
+        n_query_groups=4,
+        head_size=128,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        attn_bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMoE",
+        intermediate_size=12288,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+        n_expert=128,
+        n_expert_per_token=8,
+    )
+]
+
+configs.extend(qwen3)
 
 #############
 # Salamandra
diff --git a/litgpt/model.py b/litgpt/model.py
index 5fcb04d4b9..09e2c372ad 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -325,6 +325,8 @@ def forward(
         else:
             x = attention_output + x
         x_normed = self.norm_2(x)
+        print("mlp output", self.mlp(x_normed))
+        print("attention output ", self.post_mlp_norm(self.mlp(x_normed)) + x)
         return self.post_mlp_norm(self.mlp(x_normed)) + x
 
 
@@ -399,14 +401,14 @@ def forward(
         q = q.transpose(1, 2)  # (B, nh_q, T, hs)
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)
-
+        
         if self.config.norm_qk:
             q = self.norm_q(q)
             k = self.norm_k(k)
-
         # Unlike standard positional embeddings rotary embeddings must be applied at every layer.
         q_roped = apply_rope(q[..., :rope_n_elem], cos, sin)
         k_roped = apply_rope(k[..., :rope_n_elem], cos, sin)
+
         q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)  # (B, nh_q, T, hs)
         k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)  # (B, nh_k, T, hs)
 
@@ -453,7 +455,6 @@ def forward(
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
         # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
         y = self.scaled_dot_product_attention(q, k, v, mask)
-
         # Re-assemble all head outputs side by side.
         y = y.reshape(B, T, head_size * n_head)
 
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 341e1a757d..6083ed9586 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -533,6 +533,75 @@ def copy_weights_qwen_2_5(
         pbar.update(progress_per_file)
 
 
+def copy_weights_qwen_3(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
+        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight",
+    }
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
+    if "lm_head.weight" not in state_dict:
+        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
+
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
 def qkv_reassemble(
     param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
diff --git a/tests/test_model.py b/tests/test_model.py
index 8860a26614..1495da5006 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -30,6 +30,7 @@
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
+from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
 
 import litgpt.config as config_module
 from litgpt import GPT, Config
@@ -42,6 +43,7 @@
     copy_weights_hf_llama,
     copy_weights_phi,
     copy_weights_qwen_2_5,
+    copy_weights_qwen_3
 )
 from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
 from litgpt.utils import _RunIf
@@ -1008,6 +1010,71 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize(
+    "model_name", ["Qwen3-8B"]
+)
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_qwen_3(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        n_layer=1,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86
+    )
+
+    theirs_config = Qwen3Config(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.attn_bias
+
+    )
+    print(ours_config)
+    print(theirs_config)
+
+    theirs_model = Qwen3ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+
+    state_dict = {}
+    copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(

From 9c8c15fe2775de8d49e2c210d63cb85ac072aa30 Mon Sep 17 00:00:00 2001
From: KaelanDt
Date: Wed, 7 May 2025 11:29:16 +0000
Subject: [PATCH 2/6] remove stray lines and prints

---
 litgpt/model.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/litgpt/model.py b/litgpt/model.py
index 09e2c372ad..a5c3c563f1 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -84,7 +84,7 @@ def forward(
         self,
         idx: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
         lm_head_chunk_size: int = 0,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -291,7 +291,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Non-parallel residual       Parallel residual
@@ -325,8 +325,6 @@ def forward(
         else:
             x = attention_output + x
         x_normed = self.norm_2(x)
-        print("mlp output", self.mlp(x_normed))
-        print("attention output ", self.post_mlp_norm(self.mlp(x_normed)) + x)
         return self.post_mlp_norm(self.mlp(x_normed)) + x
 
 
@@ -361,7 +361,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         # Notation:
         # - B | batch size
@@ -399,14 +399,14 @@ def forward(
         q = q.transpose(1, 2)  # (B, nh_q, T, hs)
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)
-        
+
         if self.config.norm_qk:
             q = self.norm_q(q)
             k = self.norm_k(k)
+
         # Unlike standard positional embeddings rotary embeddings must be applied at every layer.
         q_roped = apply_rope(q[..., :rope_n_elem], cos, sin)
         k_roped = apply_rope(k[..., :rope_n_elem], cos, sin)
-
         q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)  # (B, nh_q, T, hs)
         k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)  # (B, nh_k, T, hs)
 
@@ -455,6 +453,7 @@ def forward(
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
         # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
         y = self.scaled_dot_product_attention(q, k, v, mask)
+
         # Re-assemble all head outputs side by side.
         y = y.reshape(B, T, head_size * n_head)
 
@@ -836,4 +835,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (x_normed * weight.float()).to(dtype=dtype)
 
     def reset_parameters(self) -> None:
-        torch.nn.init.ones_(self.weight)
+        torch.nn.init.ones_(self.weight)
\ No newline at end of file

From a4e739a403f979aa359f8e44d1f1ac1c8456e626 Mon Sep 17 00:00:00 2001
From: KaelanDt
Date: Wed, 7 May 2025 11:30:27 +0000
Subject: [PATCH 3/6] code style fixes

---
 litgpt/config.py                        |  2 +-
 litgpt/model.py                         |  2 +-
 litgpt/scripts/convert_hf_checkpoint.py |  2 +-
 tests/test_model.py                     | 21 ++++++---------------
 4 files changed, 9 insertions(+), 18 deletions(-)

diff --git a/litgpt/config.py b/litgpt/config.py
index 956158fe25..b1ed2bbec6 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2498,7 +2498,7 @@ def norm_class(self) -> Type:
         norm_qk=True,
         n_expert=128,
         n_expert_per_token=8,
-    )
+    ),
 ]
 
 configs.extend(qwen3)
diff --git a/litgpt/model.py b/litgpt/model.py
index a5c3c563f1..db6aebe790 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -835,4 +835,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (x_normed * weight.float()).to(dtype=dtype)
 
     def reset_parameters(self) -> None:
-        torch.nn.init.ones_(self.weight)
\ No newline at end of file
+        torch.nn.init.ones_(self.weight)
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 6083ed9586..754b6bf704 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -600,7 +600,7 @@ def copy_weights_qwen_3(
             if progress_per_file is not None:
                 pbar.update(progress_per_file)
-    
+
 
 def qkv_reassemble(
     param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
diff --git a/tests/test_model.py b/tests/test_model.py
index 1495da5006..dfca840bcf 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -43,7 +43,7 @@
     copy_weights_hf_llama,
     copy_weights_phi,
     copy_weights_qwen_2_5,
-    copy_weights_qwen_3
+    copy_weights_qwen_3,
 )
 from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved
 from litgpt.utils import _RunIf
@@ -1011,9 +1011,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
 
 
 @torch.inference_mode()
-@pytest.mark.parametrize(
-    "model_name", ["Qwen3-8B"]
-)
+@pytest.mark.parametrize("model_name", ["Qwen3-8B"])
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
@@ -1032,15 +1030,8 @@ def test_against_original_qwen_3(model_name, device, dtype):
     torch.set_default_dtype(dtype)
 
     T = 20
-    ours_config = Config.from_name(
-        model_name,
-        block_size=T,
-        n_layer=1,
-        n_head=16,
-        n_embd=32,
-        intermediate_size=86
-    )
-
+    ours_config = Config.from_name(model_name, block_size=T, n_layer=1, n_head=16, n_embd=32, intermediate_size=86)
+
     theirs_config = Qwen3Config(
         vocab_size=ours_config.padded_vocab_size,
         hidden_size=ours_config.n_embd,
@@ -1054,8 +1045,7 @@ def test_against_original_qwen_3(model_name, device, dtype):
         rms_norm_eps=ours_config.norm_eps,
         num_key_value_heads=ours_config.n_query_groups,
         rope_theta=ours_config.rope_base,
-        attention_bias=ours_config.attn_bias
-
+        attention_bias=ours_config.attn_bias,
     )
     print(ours_config)
     print(theirs_config)
@@ -1075,6 +1065,7 @@ def test_against_original_qwen_3(model_name, device, dtype):
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
 
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(

From 0569bb68e3f7edd8fbd612bd99dd4b216aec8ef9 Mon Sep 17 00:00:00 2001
From: KaelanDt
Date: Wed, 7 May 2025 11:31:22 +0000
Subject: [PATCH 4/6] reformat test and update args

---
 tests/test_model.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index dfca840bcf..a95fa4d642 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1032,8 +1032,14 @@ def test_against_original_qwen_3(model_name, device, dtype):
     torch.set_default_dtype(dtype)
 
     T = 20
-    ours_config = Config.from_name(model_name, block_size=T, n_layer=1, n_head=16, n_embd=32, intermediate_size=86)
-
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+    )
     theirs_config = Qwen3Config(
         vocab_size=ours_config.padded_vocab_size,
         hidden_size=ours_config.n_embd,
@@ -1054,8 +1053,6 @@ def test_against_original_qwen_3(model_name, device, dtype):
         rope_theta=ours_config.rope_base,
         attention_bias=ours_config.attn_bias,
     )
-    print(ours_config)
-    print(theirs_config)
 
     theirs_model = Qwen3ForCausalLM(theirs_config).to(device)
     theirs_state_dict = theirs_model.state_dict()

From bfac2a3453d4de7f62ca62a0c42cf50820d91591 Mon Sep 17 00:00:00 2001
From: KaelanDt
Date: Wed, 7 May 2025 11:50:43 +0000
Subject: [PATCH 5/6] add other dense configs

---
 litgpt/config.py | 106 ++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 96 insertions(+), 10 deletions(-)

diff --git a/litgpt/config.py b/litgpt/config.py
index b1ed2bbec6..e2c1152063 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2453,6 +2453,72 @@ def norm_class(self) -> Type:
 # Qwen3
 #############
 qwen3 = [
+    dict(
+        name="Qwen3-0.6B",
+        hf_config=dict(org="Qwen", name="Qwen3-0.6B"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=28,
+        n_head=16,
+        n_embd=1024,
+        n_query_groups=8,
+        head_size=128,
+        parallel_residual=False,
+        rotary_percentage=1.0,
+        bias=False,
+        attn_bias=False,
+        norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP", + intermediate_size=3072, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), + dict( + name="Qwen3-1.7B", + hf_config=dict(org="Qwen", name="Qwen3-1.7B"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=16, + n_embd=2048, + n_query_groups=8, + head_size=128, + parallel_residual=False, + rotary_percentage=1.0, + bias=False, + attn_bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=6144, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), + dict( + name="Qwen3-4B", + hf_config=dict(org="Qwen", name="Qwen3-4B"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=36, + n_head=32, + n_embd=2560, + n_query_groups=8, + head_size=128, + parallel_residual=False, + rotary_percentage=1.0, + bias=False, + attn_bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=9728, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), dict( name="Qwen3-8B", hf_config=dict(org="Qwen", name="Qwen3-8B"), @@ -2476,28 +2542,48 @@ def norm_class(self) -> Type: norm_qk=True, ), dict( - name="Qwen3-235B-A22B", - hf_config=dict(org="Qwen", name="Qwen3-235B-A22B"), + name="Qwen3-14B", + hf_config=dict(org="Qwen", name="Qwen3-14B"), block_size=40960, vocab_size=151643, padded_vocab_size=151936, - n_layer=94, - n_head=64, - n_embd=4096, - n_query_groups=4, + n_layer=40, + n_head=40, + n_embd=5120, + n_query_groups=8, head_size=128, + parallel_residual=False, rotary_percentage=1.0, + bias=False, + attn_bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=17408, + norm_eps=1e-6, + rope_base=1000000, + norm_qk=True, + ), + dict( + name="Qwen3-32B", + hf_config=dict(org="Qwen", name="Qwen3-8B"), + block_size=40960, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=64, + n_head=64, + n_embd=5120, + n_query_groups=8, + head_size=128, parallel_residual=False, + rotary_percentage=1.0, bias=False, attn_bias=False, norm_class_name="RMSNorm", - mlp_class_name="LLaMAMoE", - intermediate_size=12288, + mlp_class_name="LLaMAMLP", + intermediate_size=25600, norm_eps=1e-6, rope_base=1000000, norm_qk=True, - n_expert=128, - n_expert_per_token=8, ), ] From 5a689ff56d84bb732744ad5183c9f44fa45fe68d Mon Sep 17 00:00:00 2001 From: KaelanDt Date: Wed, 7 May 2025 11:57:21 +0000 Subject: [PATCH 6/6] add all dense models to tests --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index a95fa4d642..fb130c0dc5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1011,7 +1011,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype): @torch.inference_mode() -@pytest.mark.parametrize("model_name", ["Qwen3-8B"]) +@pytest.mark.parametrize("model_name", ["Qwen3-0.6B", "Qwen3-1.7B", "Qwen3-4B", "Qwen3-8B", "Qwen3-14B", "Qwen3-32B"]) @pytest.mark.parametrize( ("device", "dtype"), [