diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 0481cae0eeba..db243ea3cd88 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -42,9 +42,16 @@ default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import add_prefix, cpu_has_amx_support, is_cpu, make_layers +from sglang.srt.utils import ( + add_prefix, + cpu_has_amx_support, + is_cpu, + is_npu, + make_layers, +) _is_cpu = is_cpu() +_is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -573,10 +580,17 @@ def __init__( local_theta = getattr(config, "rope_local_base_freq", 10000.0) global_config = copy.deepcopy(config) - global_config.rope_parameters = { - "rope_type": "default", - "rope_theta": global_theta, - } + if not _is_npu: + global_config.rope_parameters = { + "rope_type": "default", + "rope_theta": global_theta, + } + else: + global_config.rope_parameters = { + "rope_theta": global_theta, + "factor": 8, + "rope_type": "linear", + } self.rotary_emb = Gemma3RotaryEmbedding(config=global_config) self.gradient_checkpointing = False