
[fix][wip] GlmMoeDsa: try implement DSA#43912

Open
JaredforReal wants to merge 9 commits into huggingface:main from zRzRzRzRzRzRzR:glm-dsa

Conversation

@JaredforReal (Contributor) commented Feb 11, 2026

What does this PR do?

  • fix k_norm to be a LayerNorm
  • add index_head_dim to the config
  • rewrite GlmMoeDsaConfig to inherit from PreTrainedConfig
  • rewrite the indexer as an nn.Module class
  • fix the MLP layers mismatch
  • implement Attention.forward()

Current state:

GlmMoeDsaIndexer is implemented, and GlmMoeDsaAttention with both MLA and DSA is implemented. Nothing crashes outright at runtime, but something in the implementation is wrong and the generated output is nonsense.
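For orientation, here is a rough skeleton of the indexer module, inferred from the parameter names that appear in the load report below (wq_b, wk, k_norm, weights_proj); the dimensions and signature are illustrative assumptions, not the PR's actual code:

```python
import torch.nn as nn

class GlmMoeDsaIndexer(nn.Module):
    """Sketch only: layer names inferred from checkpoint keys, dims assumed."""

    def __init__(self, hidden_size, q_lora_rank, index_n_heads, index_head_dim):
        super().__init__()
        # low-rank query residual -> per-head indexer queries
        self.wq_b = nn.Linear(q_lora_rank, index_n_heads * index_head_dim, bias=False)
        # hidden states -> shared indexer keys
        self.wk = nn.Linear(hidden_size, index_head_dim, bias=False)
        # LayerNorm rather than RMSNorm, hence the `k_norm.bias` checkpoint key
        self.k_norm = nn.LayerNorm(index_head_dim)
        # hidden states -> per-head mixing weights for collapsing scores
        self.weights_proj = nn.Linear(hidden_size, index_n_heads, bias=False)
```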

help wanted

Copilot AI review requested due to automatic review settings February 11, 2026 12:37
@JaredforReal (Contributor, Author) commented:

Known bug, as follows:

(transformers) root@develop-20260210102829-0bprk:/vllm-workspace/transformers# python test.py
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2559/2559 [02:22<00:00, 17.97it/s, Materializing param=model.norm.weight]
GlmMoeDsaForCausalLM LOAD REPORT from: /workspace/glm5-0210-fp8/
Key                                                           | Status     | Details
--------------------------------------------------------------+------------+--------
model.layers.78.mlp.shared_experts.up_proj.weight_scale_inv   | UNEXPECTED |
model.layers.78.self_attn.indexer.wk.weight                   | UNEXPECTED |
model.layers.78.self_attn.indexer.wk.weight_scale_inv         | UNEXPECTED |
model.layers.78.hnorm.weight                                  | UNEXPECTED |
model.layers.78.self_attn.indexer.k_norm.bias                 | UNEXPECTED |
model.layers.78.input_layernorm.weight                        | UNEXPECTED |
model.layers.78.self_attn.kv_b_proj.weight                    | UNEXPECTED |
model.layers.78.self_attn.q_b_proj.weight                     | UNEXPECTED |
model.layers.78.self_attn.q_a_layernorm.weight                | UNEXPECTED |
model.layers.78.mlp.shared_experts.up_proj.weight             | UNEXPECTED |
model.layers.78.eh_proj.weight                                | UNEXPECTED |
model.layers.78.mlp.shared_experts.gate_proj.weight_scale_inv | UNEXPECTED |
model.layers.78.self_attn.o_proj.weight                       | UNEXPECTED |
model.layers.78.self_attn.kv_a_proj_with_mqa.weight_scale_inv | UNEXPECTED |
model.layers.78.mlp.experts.down_proj_scale_inv               | UNEXPECTED |
model.layers.78.self_attn.kv_a_layernorm.weight               | UNEXPECTED |
model.layers.78.mlp.gate.e_score_correction_bias              | UNEXPECTED |
model.layers.78.mlp.gate.weight                               | UNEXPECTED |
model.layers.78.self_attn.kv_b_proj.weight_scale_inv          | UNEXPECTED |
model.layers.78.self_attn.indexer.wq_b.weight                 | UNEXPECTED |
model.layers.78.mlp.shared_experts.down_proj.weight_scale_inv | UNEXPECTED |
model.layers.78.self_attn.indexer.weights_proj.weight         | UNEXPECTED |
model.layers.78.self_attn.indexer.k_norm.weight               | UNEXPECTED |
model.layers.78.mlp.shared_experts.gate_proj.weight           | UNEXPECTED |
model.layers.78.self_attn.q_b_proj.weight_scale_inv           | UNEXPECTED |
model.layers.78.self_attn.q_a_proj.weight                     | UNEXPECTED |
model.layers.78.self_attn.q_a_proj.weight_scale_inv           | UNEXPECTED |
model.layers.78.self_attn.kv_a_proj_with_mqa.weight           | UNEXPECTED |
model.layers.78.self_attn.o_proj.weight_scale_inv             | UNEXPECTED |
model.layers.78.mlp.experts.gate_up_proj                      | UNEXPECTED |
model.layers.78.self_attn.indexer.wq_b.weight_scale_inv       | UNEXPECTED |
model.layers.78.enorm.weight                                  | UNEXPECTED |
model.layers.78.mlp.experts.down_proj                         | UNEXPECTED |
model.layers.78.mlp.shared_experts.down_proj.weight           | UNEXPECTED |
model.layers.78.shared_head.norm.weight                       | UNEXPECTED |
model.layers.78.post_attention_layernorm.weight               | UNEXPECTED |
model.layers.78.mlp.experts.gate_up_proj_scale_inv            | UNEXPECTED |

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
The following generation flags are not valid and may be ignored: ['top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Hi, introduce yourself

? true EM...
 EB EM external? real and all
: new a an the to re ",?
 local?
ernal a //?...
 or all
 false an or // or  new
 real in
 an ".
 local dark
 in re or
 //
 the lost: else the or or
 the.
 a //
 true // you else a as...
 new...
 or an.
://, E
 an   all
 lost in //, self saidernal the.
 then
 to
 and
? ":// the...
  you // else
 a reernal
 or
(transformers) root@develop-20260210102829-0bprk:/vllm-workspace/transformers#

Copilot AI left a comment:

Pull request overview

This PR updates the GLM-MoE-DSA implementation to add a DSA indexer path and a rewritten MLA attention forward pass, alongside a config refactor to support new indexer parameters.

Changes:

  • Add a single-tensor RoPE helper and introduce a standalone GlmMoeDsaIndexer module used by attention.
  • Rewrite GlmMoeDsaAttention.forward() to apply top-k masking (DSA) and adjust caching/FP8-related behavior.
  • Refactor GlmMoeDsaConfig to inherit from PreTrainedConfig and add indexer-related config fields.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 11 comments.

| File | Description |
| --- | --- |
| src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py | Implements the new RoPE helper and the indexer module; rewrites the attention forward pass and FP8-related init behavior. |
| src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py | Generated file reflecting the modular changes. |
| src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py | Refactors the config and adds indexer parameters (index_head_dim, index_n_heads, etc.). |

Comment on lines 527 to 530:

```diff
 if self.q_lora_rank is None:
-    q_states = self.q_proj(hidden_states)
+    query_states = self.q_proj(hidden_states)
     q_resid = None
 else:
```
Copilot AI commented Feb 11, 2026:

When q_lora_rank is None, this sets q_resid = None, but the indexer path later expects a real [B, S, q_lora_rank] tensor (it feeds it through wq_b). Either compute an appropriate indexer query input for the non-LoRA case or explicitly disallow q_lora_rank=None for this model.
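If the non-LoRA case is to be disallowed, a minimal sketch of the guard (not code from the PR):

```python
# Sketch: fail fast when q_lora_rank is unset, since the DSA indexer consumes
# the low-rank query residual through `wq_b` and would have no input otherwise.
if self.q_lora_rank is None:
    raise ValueError(
        "GlmMoeDsaAttention's DSA indexer requires `q_lora_rank` to be set; got None."
    )
```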

Comment on lines +219 to +223:

```python
# Indexer (DSA) parameters
self.index_topk = index_topk
self.index_head_dim = index_head_dim
self.index_n_heads = index_n_heads
self.indexer_rope_interleave = indexer_rope_interleave
```

Copilot AI commented Feb 11, 2026:

index_head_dim/qk_rope_head_dim are used together in the indexer via torch.split(..., [qk_rope_head_dim, index_head_dim - qk_rope_head_dim]); if index_head_dim < qk_rope_head_dim this will fail at runtime. Consider validating these constraints (and index_n_heads > 0, index_topk > 0) in the config init so misconfigurations fail early with a clear error.
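A minimal sketch of that validation, assuming it lives in GlmMoeDsaConfig.__init__ and using the PR's field names:

```python
# Sketch: validate indexer settings up front so misconfigurations raise a
# clear error instead of a runtime torch.split failure.
if index_head_dim < qk_rope_head_dim:
    raise ValueError(
        f"`index_head_dim` ({index_head_dim}) must be >= `qk_rope_head_dim` "
        f"({qk_rope_head_dim}); the indexer splits keys into "
        "[qk_rope_head_dim, index_head_dim - qk_rope_head_dim]."
    )
if index_n_heads <= 0 or index_topk <= 0:
    raise ValueError("`index_n_heads` and `index_topk` must be positive integers.")
```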

Comment on the `index_topk` docstring:

```diff
-            Number of top tokens selected by the indexer for retrieval/attention in each step.
+            Number of top tokens selected by the indexer for sparse attention.
         index_head_dim (`int`, *optional*, defaults to 128):
             Head dimension for the indexer projections (DSA).
```
Copilot AI commented Feb 11, 2026:

The config docstring lists index_topk, index_head_dim, and indexer_rope_interleave, but it doesn't document the newly added index_n_heads. Please add a short description of what index_n_heads controls and any constraints.

Suggested change:

```diff
         Head dimension for the indexer projections (DSA).
+    index_n_heads (`int`, *optional*, defaults to 8):
+        Number of attention heads used by the indexer projections. Should be a positive integer and is
+        typically less than or equal to `num_attention_heads`.
```

@Rocketknight1 (Member) commented:

Hi @JaredforReal, you didn't attach the test.py script so we can't actually see what the bug is here!

@JaredforReal (Contributor, Author) commented Feb 11, 2026:

@Rocketknight1 @ArthurZucker
Here is test.py:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "/workspace/glm5-0210-fp8/"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)

prompt = "Hi, introduce yourself"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

@ArthurZucker (Collaborator) left a comment:

Just a few small nits. Also, if you have an integration test (maybe the one you are using locally), it would be good to add it.

The TP plan probably needs an update as well, given the q/k norms. If you don't want to figure that out right now, I recommend skipping TP for the attention part and only doing it on the experts.
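For illustration, a hedged sketch of what "TP on experts only" could look like via a `base_model_tp_plan`; the pattern keys below are assumptions, not taken from the PR:

```python
# Sketch: shard only the expert MLPs and leave attention unsharded until the
# q/k norm handling is settled. Module patterns are illustrative.
base_model_tp_plan = {
    "layers.*.mlp.experts.*.gate_proj": "colwise",
    "layers.*.mlp.experts.*.up_proj": "colwise",
    "layers.*.mlp.experts.*.down_proj": "rowwise",
    "layers.*.mlp.shared_experts.gate_proj": "colwise",
    "layers.*.mlp.shared_experts.up_proj": "colwise",
    "layers.*.mlp.shared_experts.down_proj": "rowwise",
    # attention entries intentionally omitted: no TP on attention for now
}
```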

Comment on lines +422 to +426:

```python
# q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T]
scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale

# Weight per head and sum across heads → [B, S, T]
index_scores = torch.einsum("bsht,bsh->bst", scores, weights)
```
@ArthurZucker (Collaborator) commented:

We avoid einsums when we can, but it's not a big deal.
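For reference, an einsum-free equivalent is a drop-in two-line change (a sketch, assuming the shapes annotated in the snippet above: q is [B, S, H, D], k_cached is [B, T, D], weights is [B, S, H]):

```python
# Same computation as the two einsums, using broadcasted matmuls.
# [B, S, H, D] @ [B, 1, D, T] -> [B, S, H, T]
scores = torch.matmul(
    q.float(), k_cached.float().transpose(1, 2).unsqueeze(1)
) * self.softmax_scale

# [B, S, 1, H] @ [B, S, H, T] -> [B, S, 1, T] -> squeeze -> [B, S, T]
index_scores = torch.matmul(weights.unsqueeze(2), scores).squeeze(2)
```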

Comment on lines +401 to +405:

```python
if self._cached_keys is not None:
    k_cached = torch.cat([self._cached_keys, k], dim=1)  # [B, T, D]
else:
    k_cached = k
self._cached_keys = k_cached
```
@ArthurZucker (Collaborator) commented:

Mmm, yeah, it works; the main issue is that if you generate again with a new prompt, this is going to break, no? (different k shape)
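The smallest patch for that would be to clear the ad-hoc cache between generations (a sketch; `reset_cache` is a hypothetical helper, and wiring the indexer into the Transformers Cache API would be the more idiomatic fix):

```python
# Sketch: clear the indexer's ad-hoc key cache so a new prompt starts fresh.
def reset_cache(self):
    self._cached_keys = None
```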

@JaredforReal (Contributor, Author) commented:

@ArthurZucker PTAL. The current output and logs:

[screenshot: current output and logs]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Thanks for iterating

@github-actions commented:

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm_moe_dsa
