diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index ddb77f37a1e5..0fa13e5d32e0 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -334,7 +334,7 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
 - * `Glm4ForCausalLM`
   * GLM-4-0414
-  * `THUDM/GLM-4-32B-Chat-0414`, etc.
+  * `THUDM/GLM-4-32B-0414`, etc.
   * ✅︎
   * ✅︎
 - * `GPT2LMHeadModel`
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 1599b1da07ca..f8dd933d2fca 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -147,7 +147,7 @@ def check_available_online(
                                          min_transformers_version="4.50"),
     "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo(
-        "THUDM/GLM-4-32B-Chat-0414",
+        "THUDM/GLM-4-32B-0414",
         is_available_online=False,
         min_transformers_version="4.52.dev0"
     ),
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index cba093cbfef7..28cebfbd7baa 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -82,7 +82,7 @@ def __init__(self,
         partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
-        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.rotary_dim = self.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -110,6 +110,7 @@ def __init__(self,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
             partial_rotary_factor=partial_rotary_factor,
+            is_neox_style=False,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -197,13 +198,12 @@ def forward(
         )
         hidden_states = self.post_self_attn_layernorm(hidden_states)
-        hidden_states = residual + hidden_states

         # Fully Connected
-        hidden_states = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_mlp_layernorm(hidden_states)
-        hidden_states = residual + hidden_states

        return hidden_states, residual
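
A minimal sketch of the fused add-and-norm calling convention that the updated `forward()` hunk relies on: when the layer norm is called with a residual argument, it returns both the normalized activations and the updated residual, so the manual `residual + hidden_states` additions are dropped. The `FusedAddRMSNorm` class below is a simplified stand-in, not vLLM's actual `RMSNorm` implementation; names and shapes are illustrative only.

```python
import torch
from torch import nn


class FusedAddRMSNorm(nn.Module):
    """Simplified stand-in for an RMSNorm with an optional fused residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is None:
            return self._norm(x)
        x = x + residual          # fused residual add
        return self._norm(x), x   # normalized output and the new residual


if __name__ == "__main__":
    norm = FusedAddRMSNorm(8)
    hidden = torch.randn(2, 8)
    residual = torch.randn(2, 8)

    # Old pattern: add the residual by hand, then normalize.
    manual = norm(hidden + residual)

    # New pattern: the norm returns both the output and the updated residual,
    # mirroring `hidden_states, residual = self.post_attention_layernorm(...)`.
    fused_out, new_residual = norm(hidden, residual)

    assert torch.allclose(manual, fused_out)
    assert torch.allclose(new_residual, hidden + residual)
```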