diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py index 0013f6b33387..c4cef4d4ab55 100644 --- a/src/transformers/models/llama4/configuration_llama4.py +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -228,7 +228,9 @@ class Llama4TextConfig(PretrainedConfig): no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO attention_chunk_size (`int`, *optional*, defaults to 8192): - attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO + attn_temperature_tuning (`bool`, *optional*, defaults to `True`): + Whether to dynamically scale the attention temperature for each query token based on sequence length. + Recommended for long sequences (e.g., >32k tokens) to maintain stable output results. floor_scale (`int`, *optional*, defaults to 8192): TODO attn_scale (`int`, *optional*, defaults to 0.1): TODO cache_implementation (``, *optional*, defaults to `"hybrid"`): @@ -291,7 +293,7 @@ def __init__( no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, - attn_temperature_tuning=4, + attn_temperature_tuning=True, floor_scale=8192, attn_scale=0.1, cache_implementation="hybrid",