@@ -19,11 +19,14 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
 )
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.import_utils import has_triton_kernels
 
 logger = init_logger(__name__)
 
+use_legacy_triton_kernels = False
+
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
@@ -38,10 +41,20 @@
         from triton_kernels.tensor import (
             BIT,
             Bitmatrix,
-            SparseMatrix,
-            make_ragged_tensor_metadata,
         )
         from triton_kernels.topk import topk
 
+        try:
+            from triton_kernels.tensor import (
+                SparseMatrix,
+                make_ragged_tensor_metadata,
+            )
+        except ImportError:
+            if current_platform.is_rocm():
+                logger.warning_once("Using legacy triton_kernels on ROCm")
+                use_legacy_triton_kernels = True
+            else:
+                raise
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
     Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
     Creates routing data from a bitmatrix representation.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing_from_bitmatrix
+
+        return routing_from_bitmatrix(
+            bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
+        )
     sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
     dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
     combine_indx = sparse_logits.mask_metadata.col_sorted_indx
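The non-legacy path derives its routing metadata from sorted views of the sparse token-to-expert assignment. The underlying idea, sketched below in plain PyTorch with illustrative names (not the library's): a stable argsort of the flattened expert indices groups token slots expert-by-expert for the grouped GEMM, and its inverse permutation restores token order when combining results.

```python
# Illustrative only: how the expert-sorted and inverse permutations relate.
import torch

expt_indx = torch.tensor([[2, 0], [1, 0], [2, 1]])  # 3 tokens, top-2 experts
flat = expt_indx.flatten()
dispatch = torch.argsort(flat, stable=True)  # slots grouped by expert id
combine = torch.argsort(dispatch)            # inverse: back to token order
assert torch.equal(flat[dispatch][combine], flat)
print(flat[dispatch])  # tensor([0, 0, 1, 1, 2, 2])
```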
@@ -130,6 +149,10 @@ def legacy_routing(
     Replacement for the removed triton_kernels.routing.routing function.
     Computes routing data from gating logits.
     """
+    if use_legacy_triton_kernels:
+        from triton_kernels.routing import routing
+
+        return routing(logits, n_expts_act, sm_first=sm_first)
     if sm_first:
         logits = torch.softmax(logits, dim=-1)
     sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
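The `sm_first` flag decides where the softmax lands: over all expert logits before selection, or over only the selected experts afterwards (via `apply_softmax` in `topk`). The chosen experts are the same either way; only the weight normalization differs. A small PyTorch sketch of the two orderings, assuming `apply_softmax=True` renormalizes over just the selected logits:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# sm_first=True: softmax over all experts, then pick the top-k probs.
w1, idx1 = torch.softmax(logits, dim=-1).topk(2, dim=-1)

# sm_first=False: pick the top-k logits, softmax over just those.
l2, idx2 = logits.topk(2, dim=-1)
w2 = torch.softmax(l2, dim=-1)

assert torch.equal(idx1, idx2)  # same experts either way
print(w1, w2)                   # differently normalized weights
```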
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
     )
     output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
 
-    act = FusedActivation(
-        FnSpecs(
-            "swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
-        ),
-        (swiglu_alpha, swiglu_limit),
+    act = (
+        FusedActivation(
+            FnSpecs(
+                "swiglu",
+                triton_kernels.swiglu.swiglu_fn,
+                ("alpha", "limit"),
+                reduction_n=2,
+            ),
+            (swiglu_alpha, swiglu_limit),
+        )
+        if not use_legacy_triton_kernels
+        else FusedActivation(
+            FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+            (swiglu_alpha, swiglu_limit),
+            2,
+        )
     )
     gammas = routing_data.gate_scal if routing_data else None
 
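The split reflects where the two API generations put the output-width reduction factor: the newer `FnSpecs` takes `reduction_n=2` as a keyword, while the legacy `FusedActivation` took it as a third positional argument. The factor itself is the SwiGLU column pairing: the epilogue consumes two intermediate columns (gate and linear) per output column, halving N. A hedged PyTorch reference of that math, assuming a gpt-oss-style interleaved layout and clamping; the real kernel may differ in detail:

```python
import torch

def swiglu_ref(y: torch.Tensor, alpha: float, limit: float) -> torch.Tensor:
    # Two input columns per output column, hence reduction_n=2.
    glu, lin = y[..., ::2], y[..., 1::2]   # interleaving is an assumption
    glu = glu.clamp(max=limit)
    lin = lin.clamp(min=-limit, max=limit)
    return glu * torch.sigmoid(alpha * glu) * (lin + 1)

out = swiglu_ref(torch.randn(4, 8), alpha=1.702, limit=7.0)
print(out.shape)  # torch.Size([4, 4]) -- N halved
```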
@@ -296,8 +330,17 @@ def make_routing_data(
 
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
-    bitmatrix = Bitmatrix(
-        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+    bitmatrix = (
+        Bitmatrix(
+            bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
+        )
+        if not use_legacy_triton_kernels
+        else Bitmatrix(
+            bitmatrix,
+            shape=bitmatrix_shape,
+            shape_max=bitmatrix_shape_max,
+            scratchpad=None,
+        )
     )
 
     # matmul_ogs expects invalid topk_weights to be -1s
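Both constructors wrap the same packed buffer: one bit per (token, expert) pair, 32 expert bits per word, which is why the logical shape is `[n_rows, bm_cols * 32]`; only the keyword surface differs (`dtype=BIT` on the new API, `scratchpad=None` on the legacy one). A hedged sketch of that packing, with int64 words standing in for the kernel's 32-bit words:

```python
import torch

n_rows, n_experts = 4, 40
mask = torch.zeros(n_rows, n_experts, dtype=torch.bool)
mask[0, 2] = mask[0, 35] = True          # token 0 -> experts 2 and 35

bm_cols = (n_experts + 31) // 32         # ceil-div: words per row -> 2
words = torch.zeros(n_rows, bm_cols, dtype=torch.int64)
for r, c in mask.nonzero().tolist():
    words[r, c // 32] |= 1 << (c % 32)   # set bit c of row r

print(words[0])  # word 0 has bit 2 set, word 1 has bit 3 (= 35 - 32)
```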