Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from enum import IntEnum
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

Expand All @@ -34,7 +34,13 @@
from ..jit import env as jit_env
from ..jit import gen_jit_spec, setup_cubin_loader, sm100a_nvcc_flags
from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
from ..utils import _check_shape_dtype_device, register_custom_op, register_fake_op
from ..utils import (
_check_shape_dtype_device,
get_shuffle_matrix_a_row_indices,
get_shuffle_matrix_sf_a_row_indices,
register_custom_op,
register_fake_op,
)
from .utils import (
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
Expand Down Expand Up @@ -69,6 +75,56 @@ class WeightLayout(IntEnum):
BlockMajorK = 2


def _maybe_get_cached_w3_w1_permute_indices(
    _cache_permute_indices,
    dst_w3_w1_weight: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: Optional[int] = None,
) -> torch.Tensor:
    """Return (and memoize) the composed row-permute indices for a fused
    w3/w1 weight or scale-factor tensor.

    The returned index tensor chains two reorderings:
      1. row reordering for the fused gated-activation GEMM
         (``get_reorder_rows_for_gated_act_gemm_row_indices``), and
      2. the shuffle for transposed MMA output
         (``get_shuffle_matrix_a_row_indices``, or the scale-factor variant
         ``get_shuffle_matrix_sf_a_row_indices`` when ``num_elts_per_sf``
         is given).

    Args:
        _cache_permute_indices: memoization dict, mutated in place.
            Recomputing the indices is **very** costly, hence the cache.
        dst_w3_w1_weight: destination weight (or scale-factor) tensor; used
            to derive the indices and their target device.
        epilogue_tile_m: epilogue tile size along M, forwarded to the
            shuffle-index helpers.
        num_elts_per_sf: if not ``None``, compute indices for a scale-factor
            matrix with this many elements per scale factor.

    Returns:
        Index tensor on ``dst_w3_w1_weight.device`` suitable for row-gathering
        the tensor (``t[indices]``).
    """
    # Key on everything that determines the result, not just the shape:
    # the same shape with a different epilogue_tile_m / num_elts_per_sf —
    # or the w2 helper sharing this cache dict — must not collide.
    cache_key = (
        "w3_w1",
        dst_w3_w1_weight.shape,
        epilogue_tile_m,
        num_elts_per_sf,
    )
    if cache_key not in _cache_permute_indices:
        # Get permute indices and chain them together
        permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
        if num_elts_per_sf is None:
            permute1 = get_shuffle_matrix_a_row_indices(
                dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m
            )
        else:
            permute1 = get_shuffle_matrix_sf_a_row_indices(
                dst_w3_w1_weight,
                epilogue_tile_m=epilogue_tile_m,
                num_elts_per_sf=num_elts_per_sf,
            )
        # Memoize permute indices as recompute is **very** costly
        _cache_permute_indices[cache_key] = permute0[permute1].to(
            dst_w3_w1_weight.device
        )
    return _cache_permute_indices[cache_key]
Comment thread
aleozlx marked this conversation as resolved.


def _maybe_get_cached_w2_permute_indices(
    _cache_permute_indices,
    dst_w2_weight: torch.Tensor,
    epilogue_tile_m: int,
    num_elts_per_sf: Optional[int] = None,
) -> torch.Tensor:
    """Return (and memoize) the row-permute indices for a w2 weight or
    scale-factor tensor (shuffle for transposed MMA output).

    Args:
        _cache_permute_indices: memoization dict, mutated in place.
            Recomputing the indices is **very** costly, hence the cache.
        dst_w2_weight: destination weight (or scale-factor) tensor; used to
            derive the indices and their target device.
        epilogue_tile_m: epilogue tile size along M, forwarded to the
            shuffle-index helpers.
        num_elts_per_sf: if not ``None``, compute indices for a scale-factor
            matrix with this many elements per scale factor (uses
            ``get_shuffle_matrix_sf_a_row_indices`` instead of the plain
            weight variant).

    Returns:
        Index tensor on ``dst_w2_weight.device`` suitable for row-gathering
        the tensor (``t[indices]``).
    """
    # Key on everything that determines the result, not just the shape:
    # the same shape with a different epilogue_tile_m / num_elts_per_sf —
    # or the w3_w1 helper sharing this cache dict — must not collide.
    cache_key = (
        "w2",
        dst_w2_weight.shape,
        epilogue_tile_m,
        num_elts_per_sf,
    )
    if cache_key not in _cache_permute_indices:
        if num_elts_per_sf is None:
            permute_indices = get_shuffle_matrix_a_row_indices(
                dst_w2_weight, epilogue_tile_m
            ).to(dst_w2_weight.device)
        else:
            permute_indices = get_shuffle_matrix_sf_a_row_indices(
                dst_w2_weight,
                epilogue_tile_m=epilogue_tile_m,
                num_elts_per_sf=num_elts_per_sf,
            ).to(dst_w2_weight.device)
        # Memoize permute indices as recompute is **very** costly
        _cache_permute_indices[cache_key] = permute_indices
    return _cache_permute_indices[cache_key]
Comment thread
aleozlx marked this conversation as resolved.


def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor:
"""
Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
Expand Down
94 changes: 62 additions & 32 deletions tests/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Literal
from typing import Dict, Literal

import pytest
import torch
Expand All @@ -30,15 +30,19 @@
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe import (
WeightLayout,
convert_to_block_layout,
trtllm_fp4_block_scale_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices,
)


def check_cuda(err):
Expand Down Expand Up @@ -386,50 +390,67 @@ def prepare_static_weights_for_kernel(
num_experts, hidden_size, intermediate_size // 16
) # fp8 scaling factors

# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
)
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
)

# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
gemm1_weights_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
gemm1_scales_fp4_interleaved = torch.stack(
gemm1_scales_fp4_interleaved
).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)

# Shuffle weights and scaling factors for transposed mma output
# Using cached permute index calculation can speed up weights preprocessing
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
for i in range(num_experts):
# Calculate the permute indices for the following:
# 1. Reorder rows of W1 and scales for fused gated activation
# 2. Shuffle weights and scaling factors for transposed mma output
# for both w3_w1 and w2 weights and scale factors
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
)
gemm1_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
.contiguous()
)

permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
self._cache_permute_indices,
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm1_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm1_scales_linear_fp4.device)
]
.contiguous()
)
)

permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(
gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
)
gemm2_weights_fp4[i]
.view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
.contiguous()
)

permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_fp4_shuffled.append(
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
nvfp4_block_scale_interleave(
gemm2_scales_linear_fp4[i]
.view(torch.uint8)[
permute_sf_indices.to(gemm2_scales_linear_fp4.device)
]
.contiguous()
)
)

Expand Down Expand Up @@ -1627,6 +1648,12 @@ def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) ->
return tile_tokens_dim


@pytest.fixture(scope="module")
def cache_permute_indices():
    """Provide one permute-index cache shared by all tests in this module.

    Module scope lets the costly permute-index computation be reused across
    parametrized test cases instead of being redone per test.
    """
    cache: Dict[torch.Size, torch.Tensor] = {}
    return cache


@pytest.mark.parametrize("num_tokens", [1, 1024])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [1024, 768, 384])
Expand Down Expand Up @@ -1758,6 +1785,7 @@ def test_moe_quantization_classes(
moe_impl,
routing_config,
weight_processing,
cache_permute_indices,
):
"""
Test MoE implementations using separated quantization workflow.
Expand All @@ -1778,6 +1806,8 @@ def test_moe_quantization_classes(
f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
)

moe_impl._cache_permute_indices = cache_permute_indices

seed = 0
torch.random.manual_seed(seed)

Expand Down