model: Added support for Maincoder-1B model #18534

Merged
CISC merged 6 commits into ggml-org:master from MaincodeHQ:maincoder on Jan 2, 2026

@maincode-prabod
Contributor

Added support for Maincoder-1B

fixes #18346

Contributor

@ngxson ngxson left a comment

The model seems like a normal llama arch to me. Probably enough to just add one single line to convert_hf_to_gguf.py

@CISC
Member

CISC commented Jan 2, 2026

> The model seems like a normal llama arch to me. Probably enough to just add one single line to convert_hf_to_gguf.py

Agreed, this is another arch rebrand, though I'd say Qwen given the tokenizer and chat template.

@maincode-prabod
Contributor Author

maincode-prabod commented Jan 2, 2026

Thanks for the review! I understand the concern about architecture similarity with LLAMA and Qwen.

Let me clarify the key difference:
MAINCODER applies QK normalization AFTER RoPE, not before:

# From modelling_maincoder.py (lines 200-206)
# RoPE first
query_states, key_states = apply_rotary_emb(query_states, key_states, ...)

# Then QK norm
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)

This differs from:

  • Qwen3: Applies QK norm → then RoPE
  • Llama (with use_kq_norm): Uses unweighted RMS norm, no learned weights

The order matters mathematically: RoPE modifies the query/key vectors, so normalizing before vs. after produces different results.

Additionally, the model has learned QK norm weights (q_norm.weight, k_norm.weight per layer), which are not present in Qwen2 and use a different application order than Qwen3.

If there's an existing architecture that matches this exact pattern (RoPE → learned QK norm), I'm happy to use that instead. I checked llama.cpp, qwen2.cpp, qwen3.cpp, and others, but none match this specific order.
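To make the ordering argument concrete, here is a minimal pure-Python sketch. It is not taken from the PR or from modelling_maincoder.py; the helper names (`rope`, `rms_norm`, `max_abs_diff`) and the sample values are assumptions for illustration. It suggests that an unweighted RMS norm commutes with the RoPE rotation, because rotating pairs leaves the vector length unchanged, while a norm with learned per-dimension weights does not, which is why RoPE → weighted norm and weighted norm → RoPE give different outputs.

```python
import math

def rope(x, pos, theta=10000.0):
    """Rotate adjacent pairs (x[0], x[1]), (x[2], x[3]), ... by position-dependent angles."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        angle = pos * theta ** (-i / d)
        c, s = math.cos(angle), math.sin(angle)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

def rms_norm(x, weight=None, eps=1e-6):
    """RMS-normalize x; `weight` stands in for learned per-dimension scales (hypothetical values)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    w = weight or [1.0] * len(x)
    return [v / rms * wi for v, wi in zip(x, w)]

def max_abs_diff(a, b):
    return max(abs(u - v) for u, v in zip(a, b))

q = [0.5, -1.0, 2.0, 0.25]      # toy query vector for one head
weight = [1.5, 0.5, 1.0, 2.0]   # made-up "learned" norm weights
pos = 3

# Unweighted norm: both orders agree up to floating-point rounding.
unweighted_gap = max_abs_diff(rms_norm(rope(q, pos)), rope(rms_norm(q), pos))

# Weighted norm: the per-dimension scaling does not commute with the rotation.
weighted_gap = max_abs_diff(
    rms_norm(rope(q, pos), weight),   # RoPE, then weighted norm (the order described above)
    rope(rms_norm(q, weight), pos),   # weighted norm, then RoPE
)

print(unweighted_gap, weighted_gap)
```

In this toy setup the first gap is at rounding-error level and the second is clearly non-zero, so the application order only becomes significant once the norm has learned weights.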

Comment thread convert_hf_to_gguf.py Outdated
Comment thread convert_hf_to_gguf.py Outdated
Comment thread convert_hf_to_gguf.py Outdated
Comment thread src/models/maincoder.cpp Outdated
Comment thread src/models/maincoder.cpp Outdated
Comment thread convert_hf_to_gguf.py Outdated
Comment thread convert_hf_to_gguf.py
@ngxson
Contributor

ngxson commented Jan 2, 2026

A quick search shows that hunyuan-dense.cpp also has a weighted norm after RoPE. But I'm OK with adding a new one, as this case is rare.

In any case, it's always better to leave comments in the code to avoid confusion for future contributors.

Just an extra question: what's the reason behind applying the norm after RoPE rather than before? If I understand correctly, most models apply the norm before RoPE to avoid distorting the embedded positional information.

@maincode-prabod
Contributor Author

Thanks @ngxson. Yeah, it is similar. The catch is that Hunyuan uses ROPE_TYPE_NEOX (interleaved dimension pairs) while Maincoder uses ROPE_TYPE_NORM (sequential dimension pairs). The dimension ordering differs, so we can't directly reuse it. Let me know if I'm missing something that would allow us to reuse it, though!
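A small sketch of the two pairing schemes may help here. It assumes the usual ggml convention, which is not spelled out in this thread: the NORM type rotates adjacent elements `(x[2k], x[2k+1])`, while the NEOX type rotates element `k` together with element `k + d/2`. The function names (`rope_pairs`, `norm_pairs`, `neox_pairs`) are hypothetical and this is pure Python, not llama.cpp code.

```python
import math

def norm_pairs(d):
    """Assumed NORM-style pairing: adjacent elements (0,1), (2,3), ..."""
    return [(2 * k, 2 * k + 1) for k in range(d // 2)]

def neox_pairs(d):
    """Assumed NEOX-style pairing: first half with second half (0,d/2), (1,d/2+1), ..."""
    return [(k, k + d // 2) for k in range(d // 2)]

def rope_pairs(x, pos, pairs, theta=10000.0):
    """Apply the same per-pair rotation, but to whichever index pairs are given."""
    d = len(x)
    out = list(x)
    for k, (a, b) in enumerate(pairs):
        angle = pos * theta ** (-2 * k / d)
        c, s = math.cos(angle), math.sin(angle)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

x = [1.0, 2.0, 3.0, 4.0]
rope_norm = rope_pairs(x, pos=1, pairs=norm_pairs(len(x)))
rope_neox = rope_pairs(x, pos=1, pairs=neox_pairs(len(x)))

print(norm_pairs(4), neox_pairs(4))
print(rope_norm)
print(rope_neox)
```

The rotation angles are identical in both cases; only the choice of which two elements form a pair differs, so weights trained for one layout give different results under the other unless the dimensions are reordered.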

On norm-after-RoPE:
Good question. Normalizing after RoPE means we're scaling the combined semantic+positional representation together, which keeps attention scores more stable across different positions. We found this helped during RL fine-tuning where attention distributions are sensitive. Most models do norm-before-RoPE to preserve positional encoding exactly, but post-RoPE norm worked better for our specific use case.

@pwilkin
Member

pwilkin commented Jan 2, 2026

Also, can we please appreciate the new formulation of RoPE that I at least haven't seen in Transformers before?

import torch

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to query and key tensors."""
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast freqs_cis
    freqs_cis = freqs_cis[:, :, None, :]

    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

It's amazing in how many ways you can express the same thing in Python!
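For readers comparing this with the more common cos/sin formulation, the following stdlib-only sketch (no torch; the names `rope_complex` and `rope_real` and the sample angles are assumptions) checks that multiplying each adjacent pair, viewed as a complex number, by `exp(i * angle)` matches the explicit 2D rotation.

```python
import cmath
import math

def rope_complex(x, angles):
    """Treat (x[2k], x[2k+1]) as a complex number and multiply by exp(i * angle_k)."""
    out = []
    for k, angle in enumerate(angles):
        z = complex(x[2 * k], x[2 * k + 1]) * cmath.exp(1j * angle)
        out += [z.real, z.imag]
    return out

def rope_real(x, angles):
    """Explicit rotation of each adjacent pair using cos and sin."""
    out = []
    for k, angle in enumerate(angles):
        a, b = x[2 * k], x[2 * k + 1]
        c, s = math.cos(angle), math.sin(angle)
        out += [a * c - b * s, a * s + b * c]
    return out

x = [0.3, -1.2, 2.0, 0.7]
angles = [1.0, 0.01]  # one angle per pair; arbitrary sample values

via_complex = rope_complex(x, angles)
via_real = rope_real(x, angles)
max_diff = max(abs(u - v) for u, v in zip(via_complex, via_real))
print(max_diff)
```

The two results agree to floating-point precision, since `(a + bi)(cos t + i sin t)` expands to exactly the rotation written out in `rope_real`.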

@CISC CISC merged commit 5755e52 into ggml-org:master Jan 2, 2026
75 checks passed
@maincode-prabod
Contributor Author

Thanks @CISC and @ngxson for the review and for helping get this over the line! We really appreciate the time. Expect more PRs from our team soon!

@pwilkin I wish I could take credit for that RoPE implementation! I had the exact same reaction when I saw it. I actually adapted it from the Llama 4 implementation here

blime4 referenced this pull request in blime4/llama.cpp Feb 5, 2026
* Add Maincoder model support

* Removed SPM model vocabulary setting and MOE related GGUF parameters
Removed trailing spaces from maincoder.cpp

* removed set_vocab

* added new line

* Fix formatting

* Add a new line for PEP8
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
ljubomirj pushed a commit to ljubomirj/llama.cpp that referenced this pull request May 6, 2026
my-other-github-account pushed a commit to my-other-github-account/llama.cpp that referenced this pull request May 15, 2026
phibya pushed a commit to ziee-ai/llama.cpp that referenced this pull request May 29, 2026
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
Labels

model (Model specific), python (python script changes)

Development

Successfully merging this pull request may close these issues.

Feature Request: Support Maincoder Architecture

4 participants