Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions docs/design/fused_moe_modular_kernel.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ FusedMoEExpertsModular performs the core of the FusedMoE operations. The various

`FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats, i.e., Contiguous / Batched format.

`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.

`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.

`FusedMoEExpertsModular::workspace_shapes()` /
Expand Down Expand Up @@ -220,8 +217,8 @@ If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsMod

1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
`Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`,
`Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
`Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`
methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)

Doing this will add the new implementation to the test suite.

Expand Down
6 changes: 0 additions & 6 deletions tests/kernels/moe/modular_kernel_tools/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ def to_quant_torch_dtype(s: str) -> torch.dtype:
"--num-experts", type=int, default=32, help="Global num experts"
)
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl.",
)

# Quant args
parser.add_argument(
Expand Down Expand Up @@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
quant_config=quant_config,
prepare_finalize_type=args.pf_type,
fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path,
)
15 changes: 0 additions & 15 deletions tests/kernels/moe/modular_kernel_tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class Config:
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEExperts

fused_moe_chunk_size: int | None
world_size: int

torch_trace_dir_path: str | None = None
Expand All @@ -89,7 +88,6 @@ def describe(self) -> str:
s += f" K={self.K}\n"
s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype}\n"
Expand Down Expand Up @@ -152,11 +150,6 @@ def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:

vllm_config.parallel_config.all2all_backend = self.all2all_backend()

if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
)

return vllm_config, env_dict

def is_fp8_block_quantized(self):
Expand Down Expand Up @@ -189,10 +182,6 @@ def is_block_quant_supported(self):
info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support

def is_fe_supports_chunking(self):
info = expert_info(self.fused_experts_type)
return info.supports_chunking

def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
Expand Down Expand Up @@ -233,10 +222,6 @@ def is_valid(self) -> tuple[bool, str | None]:
if not self.is_standard_fused_experts():
return False, "Mismatched format."

use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking():
return False, "Chunking not supported."

# Check quantization sanity
if (
int(self.is_per_act_token_quant)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ def rank_worker(
):
set_random_seed(pgi.rank)

# sanity check
from vllm import envs

if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

# get weights to this device
weights.to_current_device()

Expand Down Expand Up @@ -135,7 +129,6 @@ def add_to_results(
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None,
)

success = None
Expand Down
14 changes: 0 additions & 14 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
Expand Down Expand Up @@ -127,7 +126,6 @@ def register_experts(
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
Expand All @@ -141,7 +139,6 @@ def register_experts(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
Expand Down Expand Up @@ -176,7 +173,6 @@ def expert_info(kind) -> ExpertInfo:
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
Expand All @@ -186,7 +182,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
Expand All @@ -196,7 +191,6 @@ def expert_info(kind) -> ExpertInfo:
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)

Expand Down Expand Up @@ -262,7 +256,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
Expand All @@ -281,7 +274,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_aiter=True,
)
Expand All @@ -294,7 +286,6 @@ def expert_info(kind) -> ExpertInfo:
batched_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=False,
needs_deep_gemm=True,
Expand All @@ -304,7 +295,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
Expand All @@ -314,7 +304,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
Expand All @@ -331,15 +320,13 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
else:
Expand All @@ -354,7 +341,6 @@ def expert_info(kind) -> ExpertInfo:
standard_format,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ def rank_worker(
):
set_random_seed(pgi.rank)

# sanity check
from vllm import envs

if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

# get weights to this device
weights.to_current_device()

Expand Down
9 changes: 1 addition & 8 deletions tests/kernels/moe/test_block_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe(

torch.manual_seed(seed)

monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)

Expand Down Expand Up @@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

chunk_size = 1024

torch.manual_seed(seed)

monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_size = get_mk_alignment_for_contiguous_layout()
dtype = torch.bfloat16

Expand All @@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
# setup code in case we are able to revisit this later.
use_compile = False

use_cudagraph = (
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
)
use_cudagraph = N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()

topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

Expand Down
2 changes: 0 additions & 2 deletions tests/kernels/moe/test_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph(
ep_size: int | None = None,
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)

Expand Down Expand Up @@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
workspace_init,
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
dtype = torch.half

Expand Down
2 changes: 0 additions & 2 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100")
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=True, activation=activation
Expand Down Expand Up @@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
workspace_init,
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation
Expand Down
Loading
Loading