diff --git a/python/sglang/srt/models/minimax_m2.py b/python/sglang/srt/models/minimax_m2.py index 11929f740771..64580f633007 100644 --- a/python/sglang/srt/models/minimax_m2.py +++ b/python/sglang/srt/models/minimax_m2.py @@ -73,6 +73,7 @@ is_non_idle_and_non_empty, make_layers, ) +from sglang.srt.utils.hf_transformers_utils import get_rope_config logger = logging.getLogger(__name__) @@ -570,7 +571,7 @@ def __init__( # RoPE settings - support partial RoPE # FIXME: minimax_m2 config use external config that not compatible with transformers v5 - self.rope_theta = config.rope_theta + self.rope_theta, self.rope_scaling = get_rope_config(config) self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_dim = getattr( config, "rotary_dim", self.head_dim @@ -600,13 +601,12 @@ def __init__( ) # Setup RoPE with partial rotary dimension - rope_scaling = getattr(config, "rope_scaling", None) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_dim, # Use partial rotary dimension max_position=self.max_position_embeddings, base=self.rope_theta, - rope_scaling=rope_scaling, + rope_scaling=self.rope_scaling, ) # QK Normalization layers