
Conversation

LuJunru commented Jan 8, 2026

What does this PR do?

This PR adds the implementation for the released Youtu-LLM model. The model has the following features:

  • Type: Autoregressive causal language model with dense MLA (Multi-head Latent Attention)
  • Release versions: Base and Instruct

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

LuJunru mentioned this pull request Jan 8, 2026
Author

LuJunru commented Jan 8, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43166&sha=5dab39

Hi @ArthurZucker @Cyrilvallez

May I ask if it is possible to restrict the tests to Youtu-LLM (the new model) only? The summary here seems to report errors raised by other models.
[Screenshot: CircleCI test summary, 2026-01-08 19:01]


LuJunru and others added 7 commits January 8, 2026 19:14
…ition_embedding in DiT (huggingface#43068)

* qwen2_5_omni: make max_mel_frames an inference-time knob

* not fail with raising ValueError, instead make it continue to run by choosing a target_duration that's capped and aligned

* added unit tests for Token2WavShape shape mismatch

Signed-off-by: Dong Wang <dongw2019@gmail.com>

* make fixup

* remove unit test which takes too much GPU memory

Signed-off-by: Dong Wang <dongw2019@gmail.com>

* reduce gpu memory usage from the unit test

* addressed comments

Signed-off-by: Dong Wang <dongw2019@gmail.com>

---------

Signed-off-by: Dong Wang <dongw2019@gmail.com>
Author

LuJunru commented Jan 9, 2026

Hi @ArthurZucker @Cyrilvallez

It seems the Youtu-LLM-related code has passed the automated review. The remaining check failures come from other models.
[Screenshot: CI check summary, 2026-01-09 09:49]

molbap self-assigned this Jan 12, 2026
molbap self-requested a review January 12, 2026 14:13
Contributor

molbap commented Jan 12, 2026

run-slow: youtu_llm

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/youtu_llm"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

molbap (Contributor) left a comment

Seems clean, good modular file with simply Llama + MLA, beautiful. Asked a few questions, let me know and I'll re-review!

Contributor

is the official name YoutuLLM or Youtu as in the prefixes here?

Author

We chose to use Youtu as the prefix for the modules, as it is more suitable for extension (e.g., we plan to introduce YoutuVL in the near future). Youtu-LLM is more of a brand name.

Contributor

Then everything that has to do with the model name (youtu) should be named as such, like the model directory.


model_sdpa = YoutuForCausalLM.from_pretrained(
"tencent/Youtu-LLM-2B-Base",
dtype=torch.float16,
Contributor

let's make sdpa explicit here
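For illustration, a minimal sketch of what the suggested change could look like (attn_implementation is the standard Transformers kwarg for pinning the attention backend; repo id and dtype are taken from the snippet above):

model_sdpa = YoutuForCausalLM.from_pretrained(
    "tencent/Youtu-LLM-2B-Base",
    dtype=torch.float16,
    attn_implementation="sdpa",  # make the SDPA backend explicit instead of relying on the default
)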



class YoutuModel(LlamaModel):
_keys_to_ignore_on_load_unexpected = [""]
Contributor

is this to remove the Llama attribute? if so, ok

Author

For the current version of the model (Youtu-LLM-2B family), this line of code could be removed.

@require_torch_accelerator
@pytest.mark.torch_compile_test
@require_read_token
def test_compile_static_cache(self):
Contributor

Thanks for adding an integration test! However, naming-wise, it seems to measure both the dynamic and the static cache, no? By the way, could we have a simple no-compile integration test that works in the simplest setting, just to avoid regressions?

Author

We have provided inference tests below based on a no-compile dynamic cache and a no-compile static cache. Basically, I implemented this test function by referencing the test function of DeepSeek V3.

Contributor

Sure, can we update the name though to make it clearer, and separate it into two tests? That way, if it breaks at some point, it's easier to debug.

Author

Sure, are there any official examples that I can follow?
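For illustration only, a rough sketch of how the split could look inside the existing integration test class (test names, prompt, and expected strings are placeholders, assuming the usual slow/integration-test imports):

@slow
def test_model_2b_generation_dynamic_cache(self):
    # placeholder expectation; the real string comes from running the released checkpoint
    EXPECTED_TEXT = "..."
    tokenizer = AutoTokenizer.from_pretrained("tencent/Youtu-LLM-2B-Base")
    model = YoutuForCausalLM.from_pretrained(
        "tencent/Youtu-LLM-2B-Base", dtype=torch.float16, device_map="auto"
    )
    inputs = tokenizer("Hello, I am", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)

@slow
def test_model_2b_generation_static_cache(self):
    EXPECTED_TEXT = "..."
    tokenizer = AutoTokenizer.from_pretrained("tencent/Youtu-LLM-2B-Base")
    model = YoutuForCausalLM.from_pretrained(
        "tencent/Youtu-LLM-2B-Base", dtype=torch.float16, device_map="auto"
    )
    inputs = tokenizer("Hello, I am", return_tensors="pt").to(model.device)
    # same greedy generation, but forcing the static cache implementation
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")
    self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_TEXT)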

Comment on lines +238 to +265
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass

@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass

@unittest.skip("Youtu-LLM uses MLA so it is not compatible with the standard cache format")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass

@unittest.skip("Youtu-LLM uses MLA so it is not compatible with the standard cache format")
def test_greedy_generate_dict_outputs_use_cache(self):
pass

@unittest.skip(reason="SDPA can't dispatch on flash due to unsupported head dims")
def test_sdpa_can_dispatch_on_flash(self):
pass

@unittest.skip(reason="Youtu-LLM is not suitable for testing with extreme small vocabulary")
def test_resize_tokens_embeddings(self):
pass
Contributor

are all these tests indeed not working?

Author

Yes, exactly.

Contributor

Let's check if we can fix the majority by moving the tests under the CausalLM wrapper classes
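For reference, a minimal sketch of what moving onto the shared causal-LM tester could look like (assuming the shared helpers in tests/causal_lm_tester.py; exact attribute names may differ):

import unittest

from transformers import YoutuConfig, YoutuForCausalLM, YoutuModel

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class YoutuModelTester(CausalLMModelTester):
    config_class = YoutuConfig
    base_model_class = YoutuModel
    causal_lm_class = YoutuForCausalLM


class YoutuModelTest(CausalLMModelTest, unittest.TestCase):
    # the mixin provides the generic generation / cache / resize tests,
    # so most per-model skips can be dropped or handled centrally
    model_tester_class = YoutuModelTester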

Author

LuJunru commented Jan 13, 2026

Hi @molbap

I've updated the code according to the discussion above. Can you help start a new standalone test run for Youtu-LLM (run-slow: youtu_llm)?

Contributor

molbap commented Jan 14, 2026

run-slow: youtu_llm

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/youtu_llm"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

molbap (Contributor) left a comment

Sounds good to me, approving to move to core maintainer review - I might iterate one more time depending on feedback. Clean work!

Contributor

If the name is youtu, let's rename the directory as well. Any name is fine as long as we keep it consistent!

Author

ok, updated

Contributor

Then everything that has to do with the model name (youtu) should be named as such, like the model directory.

molbap requested a review from Cyrilvallez January 15, 2026 14:18
>>> configuration = model.config
```"""

model_type = "youtu_llm"
xenova (Contributor) commented Jan 16, 2026

This needs to also be changed to "youtu", right? @molbap

Contributor

yes! indeed

Author

@xenova @molbap

Understood, but since we already used youtu_llm in the open-sourced models, can we keep this unchanged?

xenova (Contributor) commented Jan 20, 2026

If you'd like to use the same repo as before, then indeed you would need to update model_type to youtu for it to match this PR. In that update, you would also remove the auto_map/custom code section:

"auto_map": {
  "AutoConfig": "configuration_youtu.YoutuConfig",
  "AutoModel": "modeling_youtu.YoutuModel",
  "AutoModelForCausalLM": "modeling_youtu.YoutuForCausalLM"
},

so any user of the new models will see these changes.

Alternatively, if you'd prefer not to update the existing model repos (in case this causes issues with users that have not fixed the revision used), you may want to upload separate checkpoints tencent/Youtu-LLM-2B-hf for example. Perhaps @molbap can provide some more guidance here :)
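For illustration only, after such an update the relevant fragment of the hub config.json would presumably look something like this (all other fields unchanged, the auto_map block removed entirely):

{
  "model_type": "youtu",
  "architectures": [
    "YoutuForCausalLM"
  ]
}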

Contributor

That's a good use case actually: we want to be able to load from the hub even when model keys don't match perfectly, and that's the role of conversion_mapping.py. For the model_type, as Joshua says, there's the option of changing the key in your model repos.

Based on that I think the most convenient naming would be simply youtu_llm for everything, but if you want youtu without changing hub repos... then I would suggest doing something like touching the CONFIG_MAPPING_NAMES to have non-unique mapping keys to YoutuConfig, so that we know youtu_llm is a valid key. It's not standard, but loading directly from the hub wasn't standard until recently, hence "good use case".

cc @hmellor as it is similar to model/config mappings with vLLM x transformers modelling backend, and cc core reviewer @vasqu too (for review and opinion on this)

Author

Yes, I guess many new models, once open-sourced on the hub, will have a pre-defined model_type. Anyway, I've updated a version that explicitly sets the model_type to youtu and uses a different testing repo from the hub. Once the PR is merged, I can update the official repository (tencent/Youtu-LLM-2B, etc.) and add additional notes about the model_type modification.

Contributor

Ideally we should match the hub model_type with what we have integrated here, i.e. youtu. So this sounds like the best solution re updating later on - it is slightly breaking for users, but I'm unsure how many users really rely on the model_type.

We could add onto CONFIG_MAPPING_NAMES but this feels like a dirty solution, wouldn't recommend tbh.

molbap requested a review from vasqu January 21, 2026 11:23
@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, youtu

vasqu (Contributor) left a comment

Some initial comments from my side. I think we can simplify a few more things especially on the testing side with our causal lm tester. Looks overall solid tho 🤗

@@ -0,0 +1,138 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Contributor

Suggested change
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

happy new year :D

Contributor

Seeing this elsewhere too, but I don't want to clutter the review.

rendered properly in your Markdown viewer.
-->
*This model was released on 2025-12-31 and added to Hugging Face Transformers on 2026-01-21.*
Contributor

To keep in mind: CI will complain about the addition date, so it needs to be updated at the last moment - just for visibility.

model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
trust_remote_code=True
Contributor

Let's remove trust_remote_code; it would be nice if we could rely on a PR on the hub instead (revision="...").
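For example (a sketch only; the ref below is a placeholder, hub PRs are addressable as refs/pr/<number>):

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    revision="refs/pr/1",  # placeholder: the hub PR carrying the transformers-native config/weights
)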

pass


class YoutuMLP(LlamaMLP):
Contributor

Can be inherited from qwen3 or gemma
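In modular terms that would presumably reduce to something like this (a sketch; GemmaMLP would work the same way):

from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP


class YoutuMLP(Qwen3MLP):
    # inheriting unchanged lets the modular converter re-emit the MLP under the Youtu prefix
    pass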

Comment on lines +44 to +52
nn.Module.__init__(self)
self.hidden_size = config.hidden_size

self.self_attn = YoutuMLAttention(config=config, layer_idx=layer_idx)

self.mlp = YoutuMLP(config)

self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Contributor

I guess it clashes because of YoutuMLAttention; I think we can simply rename it to YoutuAttention -> then we have the same prefix everywhere.

Otherwise, even if we have a difference, we can still use super().__init__(config, layer_idx). You can remove attributes via del attr or overwrite them directly with self.attr = SpecialClass(...).
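Roughly, the suggested pattern would look like this (a sketch assuming only the attention module differs from the Llama layer, with YoutuAttention being the renamed YoutuMLAttention):

class YoutuDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: YoutuConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # keep everything from the Llama layer and only swap in the MLA attention
        self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)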

Youtu_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class YoutuConfig(PreTrainedConfig):
Contributor

Can we inherit from deepseekv3 or similar in the modular file? I had a comment about how to delete / overwrite attributes - this would massively simplify the config.
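As a sketch of that suggestion (assuming DeepseekV3Config is close enough; the deleted attributes are illustrative):

from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config


class YoutuConfig(DeepseekV3Config):
    model_type = "youtu"

    def __init__(self, embedding_initializer_range: float | None = None, **super_kwargs):
        super().__init__(**super_kwargs)
        # drop MoE-specific attributes that a dense-MLA model does not use (illustrative)
        del self.n_routed_experts
        del self.moe_intermediate_size
        self.embedding_initializer_range = embedding_initializer_range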

Comment on lines +122 to +123
initializer_range: float | None = None,
embedding_initializer_range: float | None = None,
Contributor

Let's default to values directly, I don't see a reason why not

Comment on lines +136 to +157
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.head_dim = qk_rope_head_dim
self.rope_interleave = rope_interleave

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.mlp_bias = False
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
Contributor

A lot of this for example would be handled by modular, e.g.

class SolarOpenConfig(Glm4MoeConfig):

Comment on lines +158 to +168
# if initializer_range is None, set it to 2.0 / (5.0 * self.hidden_size) ** 0.5
if self.hidden_size != 0:
self.initializer_range = (
(2.0 / (5.0 * self.hidden_size)) ** 0.5 if initializer_range is None else initializer_range
)
else:
self.initializer_range = 0.02
# if embedding_initializer_range is None, set it to 2.0 * self.initializer_range
self.embedding_initializer_range = (
self.initializer_range * 2.0 if embedding_initializer_range is None else embedding_initializer_range
)
Contributor

embedding_initializer_range should just directly be defaulted in the kwargs

Same for initializer_range, though I can see it's at least a bit more special, so it does make more sense - although we don't need the if/else: it would be weird if the hidden size were 0.
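In other words, a sketch of the simplification (a fragment of __init__, assuming hidden_size has already been set; the fixed default is illustrative only):

# keep the derived default for initializer_range, without the hidden_size == 0 branch
self.initializer_range = (
    (2.0 / (5.0 * self.hidden_size)) ** 0.5 if initializer_range is None else initializer_range
)
# embedding_initializer_range defaulted directly in the signature instead of being derived,
# e.g. `embedding_initializer_range: float = 0.04` (illustrative value)
self.embedding_initializer_range = embedding_initializer_range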

Comment on lines +183 to +197
def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs):
rope_scaling = kwargs.pop("rope_scaling", None)
self.rope_parameters = rope_scaling or self.rope_parameters
self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {}

# Standardize and validate the correctness of rotary position embeddings parameters
self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
self.standardize_rope_params()
self.validate_rope(ignore_keys=ignore_keys_at_rope_validation)

# Convert to float because RoPE fn expect a float. Models on the hub were saved as int
for key in ["beta_fast", "beta_slow", "factor"]:
if key in self.rope_parameters:
self.rope_parameters[key] = float(self.rope_parameters[key])
return kwargs
Contributor

This shouldn't be needed and should be fixed directly on the hub instead! This overwrite was mostly meant for a few special models that need it for BC, but we have the option to provide the correct rope_parameters properly from the get-go.

Contributor

xenova commented Jan 24, 2026

I noticed that the Youtu-LLM-2B model on the HF hub includes duplicated embed_tokens and lm_head weights, even though, according to the config, these should be tied.

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("tencent/Youtu-LLM-2B")
Loading weights: 100%|███████████████| 387/387 [00:01<00:00, 283.68it/s, Materializing param=model.norm.weight]
The tied weights mapping and config for this model specifies to tie model.embed_tokens.weight to lm_head.weight, but both are present in the checkpoints, so we will NOT tie them. You should update the config with `tie_word_embeddings=False` to silence this warning
>>> model.num_parameters()
2224228352
>>> model.tie_weights()
>>> model.num_parameters()
1961560064

https://huggingface.co/tencent/Youtu-LLM-2B?show_file_info=model.safetensors

[Screenshot: model.safetensors file info showing both embed_tokens and lm_head tensors present]

I've opened a PR for the config and weight updates in https://huggingface.co/tencent/Youtu-LLM-2B/discussions/17.

Saves ~500MB in the model weights. 💪
