Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,17 @@ def get_vllm_port() -> Optional[int]:
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

# If set, vllm will force use Triton compressed tensor MOE kernel;
# otherwise it will use the cutlass based one.
"VLLM_TRITON_COMPRESSED_TENSORS_MOE_KERNEL":
lambda: bool(
int(os.getenv("VLLM_TRITON_COMPRESSED_TENSORS_MOE_KERNEL", "0"))),

# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),

# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,10 @@ def __init__(
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.use_cutlass = ((quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) and
not envs.VLLM_TRITON_COMPRESSED_TENSORS_MOE_KERNEL)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False

def create_weights(self, layer: torch.nn.Module, num_experts: int,
Expand Down