Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ class EngineArgs:
)

fail_on_environ_validation: bool = False
gdn_decode_backend: Literal["triton", "cutedsl"] | None = None

def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
Expand Down Expand Up @@ -1308,6 +1309,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
"--gdn-decode-backend",
dest="gdn_decode_backend",
choices=["triton", "cutedsl"],
default=None,
help="Select GDN decode backend for Qwen3Next.",
)
return parser

@classmethod
Expand Down Expand Up @@ -1893,6 +1901,9 @@ def create_engine_config(
),
)

if self.gdn_decode_backend is not None:
self.additional_config["gdn_decode_backend"] = self.gdn_decode_backend

config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
VLLM_PORT: int | None = None
VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False
VLLM_GDN_DECODE_BACKEND: Literal["triton", "cutedsl"] = "triton"
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_NCCL_SO_PATH: str | None = None
LD_LIBRARY_PATH: str | None = None
Expand Down Expand Up @@ -559,6 +560,13 @@ def _get_or_set_default() -> str:
"VLLM_USE_MODELSCOPE", "False"
).lower()
== "true",
# Selects decode backend for Qwen3Next GDN.
"VLLM_GDN_DECODE_BACKEND": env_with_choices(
"VLLM_GDN_DECODE_BACKEND",
"triton",
["triton", "cutedsl"],
case_sensitive=False,
),
# Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int(
os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")
Expand Down
Loading
Loading