diff --git a/python/sglang/srt/layers/attention/linear/utils.py b/python/sglang/srt/layers/attention/linear/utils.py index 63e7c06cbcc1..13d2b271b218 100644 --- a/python/sglang/srt/layers/attention/linear/utils.py +++ b/python/sglang/srt/layers/attention/linear/utils.py @@ -16,6 +16,11 @@ class LinearAttnKernelBackend(Enum): TRITON = "triton" CUTEDSL = "cutedsl" FLASHINFER = "flashinfer" + CUSTOM = "custom" + + @classmethod + def _missing_(cls, value): + return cls.CUSTOM def is_triton(self): return self == LinearAttnKernelBackend.TRITON @@ -26,6 +31,9 @@ def is_cutedsl(self): def is_flashinfer(self): return self == LinearAttnKernelBackend.FLASHINFER + def is_custom(self): + return self == LinearAttnKernelBackend.CUSTOM + LINEAR_ATTN_DECODE_BACKEND: Optional[LinearAttnKernelBackend] = None LINEAR_ATTN_PREFILL_BACKEND: Optional[LinearAttnKernelBackend] = None @@ -41,11 +49,8 @@ def initialize_linear_attn_config(server_args: ServerArgs): LINEAR_ATTN_DECODE_BACKEND = LinearAttnKernelBackend(decode) LINEAR_ATTN_PREFILL_BACKEND = LinearAttnKernelBackend(prefill) - rank0_log( - f"Linear attention kernel backend: " - f"decode={LINEAR_ATTN_DECODE_BACKEND.value}, " - f"prefill={LINEAR_ATTN_PREFILL_BACKEND.value}" - ) + + rank0_log(f"Linear attention kernel backend: decode={decode}, prefill={prefill}") def get_linear_attn_decode_backend() -> LinearAttnKernelBackend: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f8ec870bd510..f608518f872e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -305,6 +305,10 @@ def add_rl_on_policy_target_choices(choices): RL_ON_POLICY_TARGET_CHOICES.extend(choices) +def add_linear_attn_kernel_backend_choices(choices): + LINEAR_ATTN_KERNEL_BACKEND_CHOICES.extend(choices) + + def _resolve_speculative_algorithm_alias( speculative_algorithm: Optional[str], speculative_draft_model_path: Optional[str],