From e00d01d5f6882e4578bcfadd3f62d17b6ef202ea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 20 May 2024 17:54:31 +0000 Subject: [PATCH 01/14] [Kernel] Improve benchmark_moe script --- benchmarks/kernels/benchmark_moe.py | 361 ++++++++++++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 benchmarks/kernels/benchmark_moe.py diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py new file mode 100644 index 000000000000..badb94035a49 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe.py @@ -0,0 +1,361 @@ +import argparse +import time +from datetime import datetime +from typing import Any, Dict, List, Tuple + +import ray +import torch +import torch.nn.functional as F +import triton.language as tl +from transformers import AutoConfig + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.fused_moe import * + +logger = init_logger(__name__) + + +def benchmark_config( + config: Dict[str, int], + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: torch.dtype, + num_iters: int = 100, +) -> float: + x = torch.randn(M, K, dtype=dtype) + w = torch.randn(E, N, K, dtype=dtype) + o = torch.empty(M, topk, N, dtype=dtype) + gating = torch.randn(num_iters, M, E, dtype=dtype) + + compute_type = tl.bfloat16 if x.dtype == torch.bfloat16 else tl.float16 + routing_weights = F.softmax(gating, dim=-1, dtype=torch.float32) + topk_weights, input_topk_ids = torch.topk(routing_weights, topk, dim=-1) + topk_ids = input_topk_ids[0] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E) + + def prepare(i: int): + topk_ids.copy_(input_topk_ids[i]) + outputs = moe_align_block_size(topk_ids, config["BLOCK_SIZE_M"], E) + sorted_token_ids.copy_(outputs[0]) + expert_ids.copy_(outputs[1]) + num_tokens_post_padded.copy_(outputs[2]) + + def run(): + invoke_fused_moe_kernel( + x, + w, + o, + None, + None, + topk_weights[0], + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk, + config, + compute_type=compute_type, + use_fp8=False, # TODO(woosuk): Support FP8. + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + configs = [] + for num_stages in [4]: + for block_m in [64, 128, 256]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + for group_size in [1, 8, 64]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "SPLIT_K": 1, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + return configs + + +def get_configs_io_bound() -> List[Dict[str, int]]: + # Adapted from https://github.com/openai/triton/blob/22af8d80458ee4e6269779dae0a3c34b755aade2/python/triton/ops/matmul.py#L36 + # TODO(woosuk): Implement a performance model to prune the search space. + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + for group_size in [1, 8]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "SPLIT_K": 1, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + # Split-K + for split_k in [2, 4, 8, 16]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "SPLIT_K": split_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) + return configs + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(seed) + self.seed = seed + self.device_properties = torch.cuda.get_device_properties("cuda") + + def benchmark( + self, + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: torch.dtype, + ) -> Tuple[Dict[str, int], float]: + torch.cuda.manual_seed_all(self.seed) + + op_config = get_op_config(E, N, K, topk, str(dtype)) + if op_config is None: + config = get_default_config(M, E, N, K, topk, str(dtype)) + else: + config = op_config[min(op_config.keys(), key=lambda x: abs(x - M))] + kernel_time = benchmark_config(config, M, E, N, K, topk, dtype) + return config, kernel_time + + def tune( + self, + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: torch.dtype, + search_space: List[Dict[str, int]], + ) -> Dict[str, int]: + best_config = None + best_time = float("inf") + for config in search_space: + # A simple heuristic to prune the search space. + # TODO(woosuk): Remove this once we have a performance model. + split_k = config["SPLIT_K"] + if split_k > 1: + block_size_m = config["BLOCK_SIZE_M"] + block_size_n = config["BLOCK_SIZE_N"] + num_m_blocks = triton.cdiv(M * topk + E * (block_size_m - 1), + block_size_m) + num_n_blocks = triton.cdiv(N, block_size_n) + num_total_blocks = num_m_blocks * num_n_blocks * split_k + + num_sms = self.device_properties.multi_processor_count + if num_total_blocks > 2 * num_sms: + # Sufficient number of blocks. Split-K is not beneficial. + continue + + try: + kernel_time = benchmark_config(config, + M, + E, + N, + K, + topk, + dtype, + num_iters=10) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for M={M}") + return best_config + + +def sort_config(config: Dict[str, int]) -> Dict[str, int]: + return { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "SPLIT_K": config["SPLIT_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + + +def save_configs( + configs: Dict[int, Dict[str, int]], + E: int, + N: int, + K: int, + topk: int, + dtype: str, +) -> None: + filename = get_config_file_name(E, N, K, topk, dtype) + logger.info("writing config to file %s", filename) + with open(filename, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args: argparse.Namespace): + logger.info(args) + + config = AutoConfig.from_pretrained(args.model) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + intermediate_size = config.ffn_config.ffn_hidden_size + else: + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + + hidden_size = config.hidden_size + shard_intermediate_size = intermediate_size // args.tp_size + dtype = config.torch_dtype + + if args.batch_size is None: + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: List[Any]) -> List[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + w2_batch_sizes = [batch_size * topk for batch_size in batch_sizes] + if args.tune: + search_space = get_configs_compute_bound() + get_configs_io_bound() + + def _tune(Ms: List[int], N: int, K: int, topk_experts: int): + configs = _distribute( + "tune", + [(M, E, N, K, topk_experts, dtype, search_space) for M in Ms]) + best_configs = { + M: sort_config(config) + for M, config in zip(Ms, configs) + } + save_configs(best_configs, E, N, K, topk_experts, str(dtype)) + + logger.info("Start tuning over %d configurations...", + len(search_space)) + # w1 + start = time.time() + _tune(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) + end = time.time() + logger.info("W1 tuning took %.2f seconds", end - start) + + # w2 + start = time.time() + _tune(w2_batch_sizes, hidden_size, shard_intermediate_size, 1) + end = time.time() + logger.info("W2 tuning took %.2f seconds", end - start) + else: + + def _benchmark(Ms: List[int], N: int, K: int, topk_experts: int): + return _distribute("benchmark", + [(M, E, N, K, topk_experts, dtype) for M in Ms]) + + # w1 + outputs = _benchmark(batch_sizes, 2 * shard_intermediate_size, + hidden_size, topk) + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + logger.info("W1 batch size: %d, config: %s", batch_size, config) + logger.info("Kernel time: %.2f us", kernel_time) + + # w2 + outputs = _benchmark(w2_batch_sizes, hidden_size, + shard_intermediate_size, 1) + for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): + # NOTE(woosuk): Here the batch size is the number of input tokens + # to the MoE block. This is not the batch size of the w2 layer. + # The actual batch size of the w2 layer is batch_size * topk. + logger.info("W2 batch size: %d, config: %s", batch_size, config) + logger.info("Kernel time: %.2f us", kernel_time) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--tune", action="store_true") + args = parser.parse_args() + + main(args) From 717e8c784b1ffdccb6246ea483e307544bbf4cf4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 20 May 2024 17:54:47 +0000 Subject: [PATCH 02/14] Delete benchmark_mixtral_moe --- benchmarks/kernels/benchmark_mixtral_moe.py | 215 -------------------- 1 file changed, 215 deletions(-) delete mode 100644 benchmarks/kernels/benchmark_mixtral_moe.py diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py deleted file mode 100644 index 5280b214144c..000000000000 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ /dev/null @@ -1,215 +0,0 @@ -import argparse -import json -import os -import sys - -import torch -import torch.nn.functional as F -import triton -from tqdm import tqdm - -from vllm.model_executor.layers.fused_moe import (fused_moe, - get_config_file_name) - -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - - -def main(dtype: str): - method = fused_moe - for bs in [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 - ]: - run_grid(bs, method=method, dtype=dtype) - - -def run_grid(bs, method, dtype: str): - d_model = 4096 - num_total_experts = 8 - top_k = 2 - tp_size = 2 - model_intermediate_size = 14336 - num_layers = 32 - num_calls = 100 - - num_warmup_trials = 1 - num_trials = 1 - - configs = [] - - for block_size_n in [32, 64, 128, 256]: - for block_size_m in [16, 32, 64, 128, 256]: - for block_size_k in [64, 128, 256]: - for group_size_m in [1, 16, 32, 64]: - for num_warps in [4, 8]: - for num_stages in [2, 3, 4, 5]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": num_stages, - }) - - best_config = None - best_time_us = 1e20 - - print(f'{tp_size=} {bs=}') - - for config in tqdm(configs): - # warmup - try: - for _ in range(num_warmup_trials): - run_timing( - num_calls=num_calls, - bs=bs, - d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - except triton.runtime.autotuner.OutOfResources: - continue - - # trial - for _ in range(num_trials): - kernel_dur_ms = run_timing( - num_calls=num_calls, - bs=bs, - d_model=d_model, - num_total_experts=num_total_experts, - top_k=top_k, - tp_size=tp_size, - model_intermediate_size=model_intermediate_size, - method=method, - config=config, - dtype=dtype, - ) - - kernel_dur_us = 1000 * kernel_dur_ms - model_dur_ms = kernel_dur_ms * num_layers - - if kernel_dur_us < best_time_us: - best_config = config - best_time_us = kernel_dur_us - - tqdm.write( - f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') - - print("best_time_us", best_time_us) - print("best_config", best_config) - - # holds Dict[str, Dict[str, int]] - filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size, - "float8" if dtype == "float8" else None) - print(f"writing config to file {filename}") - existing_content = {} - if os.path.exists(filename): - with open(filename, "r") as f: - existing_content = json.load(f) - existing_content[str(bs)] = best_config - with open(filename, "w") as f: - json.dump(existing_content, f, indent=4) - f.write("\n") - - -def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, - top_k: int, tp_size: int, model_intermediate_size: int, method, - config, dtype: str) -> float: - shard_intermediate_size = model_intermediate_size // tp_size - - hidden_states = torch.rand( - (bs, d_model), - device="cuda:0", - dtype=torch.float16, - ) - - w1 = torch.rand( - (num_total_experts, 2 * shard_intermediate_size, d_model), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w2 = torch.rand( - (num_total_experts, d_model, shard_intermediate_size), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - w1_scale = None - w2_scale = None - a1_scale = None - a2_scale = None - - if dtype == "float8": - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - w1_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - w2_scale = torch.ones(num_total_experts, - device=hidden_states.device, - dtype=torch.float32) - a1_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - a2_scale = torch.ones(1, - device=hidden_states.device, - dtype=torch.float32) - - gating_output = F.softmax(torch.rand( - (num_calls, bs, num_total_experts), - device=hidden_states.device, - dtype=torch.float32, - ), - dim=-1) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for i in range(num_calls): - hidden_states = method( - hidden_states=hidden_states, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - gating_output=gating_output[i], - topk=2, - renormalize=True, - inplace=True, - override_config=config, - use_fp8=dtype == "float8", - ) - end_event.record() - end_event.synchronize() - - dur_ms = start_event.elapsed_time(end_event) / num_calls - return dur_ms - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog='benchmark_mixtral_moe', - description='Benchmark and tune the fused_moe kernel', - ) - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['float8', 'float16'], - help='Data type used for fused_moe kernel computations', - ) - args = parser.parse_args() - sys.exit(main(args.dtype)) From 1bbed443808f4dd4b4a1f942e8baeebb724ef13c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 21 May 2024 01:56:29 +0000 Subject: [PATCH 03/14] Minor fixes --- benchmarks/kernels/benchmark_moe.py | 48 ++++++++--------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index badb94035a49..83240979871b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -113,7 +113,6 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k, - "SPLIT_K": 1, "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, @@ -135,22 +134,10 @@ def get_configs_io_bound() -> List[Dict[str, int]]: "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k, - "SPLIT_K": 1, "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, }) - # Split-K - for split_k in [2, 4, 8, 16]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "SPLIT_K": split_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) return configs @@ -237,7 +224,6 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]: "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], - "SPLIT_K": config["SPLIT_K"], "GROUP_SIZE_M": config["GROUP_SIZE_M"], "num_warps": config["num_warps"], "num_stages": config["num_stages"], @@ -268,13 +254,19 @@ def main(args: argparse.Namespace): topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size else: + # Default: Mixtral. E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size hidden_size = config.hidden_size shard_intermediate_size = intermediate_size // args.tp_size - dtype = config.torch_dtype + if args.dtype == "auto": + dtype = config.torch_dtype + elif args.dtype == "fp8": + dtype = torch.float8_e4m3fn + else: + raise ValueError(f"Invalid dtype: {args.dtype}") if args.batch_size is None: batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] @@ -296,7 +288,6 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: worker_idx = (worker_idx + 1) % num_gpus return ray.get(outputs) - w2_batch_sizes = [batch_size * topk for batch_size in batch_sizes] if args.tune: search_space = get_configs_compute_bound() + get_configs_io_bound() @@ -312,40 +303,23 @@ def _tune(Ms: List[int], N: int, K: int, topk_experts: int): logger.info("Start tuning over %d configurations...", len(search_space)) - # w1 - start = time.time() - _tune(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) - end = time.time() - logger.info("W1 tuning took %.2f seconds", end - start) - # w2 start = time.time() - _tune(w2_batch_sizes, hidden_size, shard_intermediate_size, 1) + _tune(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) end = time.time() - logger.info("W2 tuning took %.2f seconds", end - start) + logger.info("Tuning took %.2f seconds", end - start) else: def _benchmark(Ms: List[int], N: int, K: int, topk_experts: int): return _distribute("benchmark", [(M, E, N, K, topk_experts, dtype) for M in Ms]) - # w1 outputs = _benchmark(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): logger.info("W1 batch size: %d, config: %s", batch_size, config) logger.info("Kernel time: %.2f us", kernel_time) - # w2 - outputs = _benchmark(w2_batch_sizes, hidden_size, - shard_intermediate_size, 1) - for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): - # NOTE(woosuk): Here the batch size is the number of input tokens - # to the MoE block. This is not the batch size of the w2 layer. - # The actual batch size of the w2 layer is batch_size * topk. - logger.info("W2 batch size: %d, config: %s", batch_size, config) - logger.info("Kernel time: %.2f us", kernel_time) - if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -353,6 +327,10 @@ def _benchmark(Ms: List[int], N: int, K: int, topk_experts: int): type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1") parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--dtype", + type=str, + choices=["auto", "fp8"], + default="auto") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") From 1324af4a03df1d7843b8cf4205eca43ecd5c703e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 21 May 2024 02:00:50 +0000 Subject: [PATCH 04/14] Fix --- benchmarks/kernels/benchmark_moe.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 83240979871b..11e68c6cba20 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -182,22 +182,6 @@ def tune( best_config = None best_time = float("inf") for config in search_space: - # A simple heuristic to prune the search space. - # TODO(woosuk): Remove this once we have a performance model. - split_k = config["SPLIT_K"] - if split_k > 1: - block_size_m = config["BLOCK_SIZE_M"] - block_size_n = config["BLOCK_SIZE_N"] - num_m_blocks = triton.cdiv(M * topk + E * (block_size_m - 1), - block_size_m) - num_n_blocks = triton.cdiv(N, block_size_n) - num_total_blocks = num_m_blocks * num_n_blocks * split_k - - num_sms = self.device_properties.multi_processor_count - if num_total_blocks > 2 * num_sms: - # Sufficient number of blocks. Split-K is not beneficial. - continue - try: kernel_time = benchmark_config(config, M, @@ -261,10 +245,11 @@ def main(args: argparse.Namespace): hidden_size = config.hidden_size shard_intermediate_size = intermediate_size // args.tp_size + dtype = config.torch_dtype if args.dtype == "auto": - dtype = config.torch_dtype + w_dtype = dtype elif args.dtype == "fp8": - dtype = torch.float8_e4m3fn + w_dtype = torch.float16 else: raise ValueError(f"Invalid dtype: {args.dtype}") From 66f418a07968cb45169dafe223c38937089768a1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 May 2024 01:19:41 +0000 Subject: [PATCH 05/14] Minor --- benchmarks/kernels/benchmark_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 11e68c6cba20..28d12a0709c0 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -302,7 +302,7 @@ def _benchmark(Ms: List[int], N: int, K: int, topk_experts: int): outputs = _benchmark(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): - logger.info("W1 batch size: %d, config: %s", batch_size, config) + logger.info("Batch size: %d, config: %s", batch_size, config) logger.info("Kernel time: %.2f us", kernel_time) From 9088e17d40180c1e90523968c3fb3dfb887e0f4f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:05:12 +0000 Subject: [PATCH 06/14] Update --- benchmarks/kernels/benchmark_moe.py | 215 +++++++++++++--------------- 1 file changed, 100 insertions(+), 115 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 28d12a0709c0..6b55de9f51ad 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -5,63 +5,67 @@ import ray import torch -import torch.nn.functional as F -import triton.language as tl from transformers import AutoConfig -from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.fused_moe import * -logger = init_logger(__name__) - def benchmark_config( config: Dict[str, int], - M: int, - E: int, - N: int, - K: int, + num_tokens: int, + num_experts: int, + fused_intermediate_size: int, + hidden_size: int, topk: int, dtype: torch.dtype, + use_fp8: bool, num_iters: int = 100, ) -> float: - x = torch.randn(M, K, dtype=dtype) - w = torch.randn(E, N, K, dtype=dtype) - o = torch.empty(M, topk, N, dtype=dtype) - gating = torch.randn(num_iters, M, E, dtype=dtype) - - compute_type = tl.bfloat16 if x.dtype == torch.bfloat16 else tl.float16 - routing_weights = F.softmax(gating, dim=-1, dtype=torch.float32) - topk_weights, input_topk_ids = torch.topk(routing_weights, topk, dim=-1) - topk_ids = input_topk_ids[0] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E) + init_dtype = torch.float16 if use_fp8 else dtype + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + w1 = torch.randn(num_experts, + fused_intermediate_size, + hidden_size, + dtype=init_dtype).to(dtype) + w2 = torch.randn(num_experts, + hidden_size, + fused_intermediate_size, + dtype=init_dtype).to(dtype) + gating_output = torch.randn(num_iters, + num_tokens, + num_experts, + dtype=torch.float32) + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + if use_fp8: + w1_scale = torch.ones(num_experts, dtype=torch.float32) + w2_scale = torch.ones(num_experts, dtype=torch.float32) + a1_scale = torch.ones(1, dtype=torch.float32) + a2_scale = torch.ones(1, dtype=torch.float32) + + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) def prepare(i: int): - topk_ids.copy_(input_topk_ids[i]) - outputs = moe_align_block_size(topk_ids, config["BLOCK_SIZE_M"], E) - sorted_token_ids.copy_(outputs[0]) - expert_ids.copy_(outputs[1]) - num_tokens_post_padded.copy_(outputs[2]) + input_gating.copy_(gating_output[i]) def run(): - invoke_fused_moe_kernel( + fused_moe( x, - w, - o, - None, - None, - topk_weights[0], - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, + w1, + w2, + input_gating, topk, - config, - compute_type=compute_type, - use_fp8=False, # TODO(woosuk): Support FP8. + renormalize=True, + inplace=True, + override_config=config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, ) # JIT compilation & warmup @@ -104,32 +108,11 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: # prune the search space. configs = [] for num_stages in [4]: - for block_m in [64, 128, 256]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: + for block_m in [64]: + for block_k in [32]: + for block_n in [32]: num_warps = 2 if block_n <= 64 else 4 - for group_size in [1, 8, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) - return configs - - -def get_configs_io_bound() -> List[Dict[str, int]]: - # Adapted from https://github.com/openai/triton/blob/22af8d80458ee4e6269779dae0a3c34b755aade2/python/triton/ops/matmul.py#L36 - # TODO(woosuk): Implement a performance model to prune the search space. - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - for group_size in [1, 8]: + for group_size in [1]: configs.append({ "BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, @@ -152,44 +135,55 @@ def __init__(self, seed: int) -> None: def benchmark( self, - M: int, - E: int, - N: int, - K: int, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, topk: int, dtype: torch.dtype, + use_fp8: bool, ) -> Tuple[Dict[str, int], float]: torch.cuda.manual_seed_all(self.seed) - op_config = get_op_config(E, N, K, topk, str(dtype)) + dtype_str = "float8" if use_fp8 else str(dtype) + op_config = get_moe_configs(num_experts, shard_intermediate_size, + dtype_str) if op_config is None: - config = get_default_config(M, E, N, K, topk, str(dtype)) + config = get_default_config(num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype_str) else: - config = op_config[min(op_config.keys(), key=lambda x: abs(x - M))] - kernel_time = benchmark_config(config, M, E, N, K, topk, dtype) + config = op_config[min(op_config.keys(), + key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config(config, num_tokens, num_experts, + shard_intermediate_size, hidden_size, + topk, dtype, use_fp8) return config, kernel_time def tune( self, - M: int, - E: int, - N: int, - K: int, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, topk: int, dtype: torch.dtype, + use_fp8: bool, search_space: List[Dict[str, int]], ) -> Dict[str, int]: best_config = None best_time = float("inf") for config in search_space: + print(config) try: kernel_time = benchmark_config(config, - M, - E, - N, - K, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, topk, dtype, + use_fp8, num_iters=10) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. @@ -199,7 +193,7 @@ def tune( best_time = kernel_time best_config = config now = datetime.now() - print(f"{now.ctime()}] Completed tuning for M={M}") + print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") return best_config @@ -223,35 +217,31 @@ def save_configs( dtype: str, ) -> None: filename = get_config_file_name(E, N, K, topk, dtype) - logger.info("writing config to file %s", filename) + print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) f.write("\n") def main(args: argparse.Namespace): - logger.info(args) + print(args) config = AutoConfig.from_pretrained(args.model) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral. E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - shard_intermediate_size = intermediate_size // args.tp_size dtype = config.torch_dtype - if args.dtype == "auto": - w_dtype = dtype - elif args.dtype == "fp8": - w_dtype = torch.float16 - else: - raise ValueError(f"Invalid dtype: {args.dtype}") + use_fp8 = args.dtype == "fp8" if args.batch_size is None: batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] @@ -274,36 +264,31 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: return ray.get(outputs) if args.tune: - search_space = get_configs_compute_bound() + get_configs_io_bound() - - def _tune(Ms: List[int], N: int, K: int, topk_experts: int): - configs = _distribute( - "tune", - [(M, E, N, K, topk_experts, dtype, search_space) for M in Ms]) - best_configs = { - M: sort_config(config) - for M, config in zip(Ms, configs) - } - save_configs(best_configs, E, N, K, topk_experts, str(dtype)) - - logger.info("Start tuning over %d configurations...", - len(search_space)) + search_space = get_configs_compute_bound() + print(f"Start tuning over {len(search_space)} configurations...") start = time.time() - _tune(batch_sizes, 2 * shard_intermediate_size, hidden_size, topk) + configs = _distribute( + "tune", [(batch_size, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8, search_space) + for batch_size in batch_sizes]) + best_configs = { + M: sort_config(config) + for M, config in zip(batch_sizes, configs) + } + save_configs(best_configs, E, shard_intermediate_size, hidden_size, + topk, str(dtype)) end = time.time() - logger.info("Tuning took %.2f seconds", end - start) + print(f"Tuning took {end - start:.2f} seconds") else: + outputs = _distribute("benchmark", + [(batch_size, E, shard_intermediate_size, + hidden_size, topk, dtype, use_fp8) + for batch_size in batch_sizes]) - def _benchmark(Ms: List[int], N: int, K: int, topk_experts: int): - return _distribute("benchmark", - [(M, E, N, K, topk_experts, dtype) for M in Ms]) - - outputs = _benchmark(batch_sizes, 2 * shard_intermediate_size, - hidden_size, topk) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): - logger.info("Batch size: %d, config: %s", batch_size, config) - logger.info("Kernel time: %.2f us", kernel_time) + print(f"Batch size: {batch_size}, config: {config}") + print(f"Kernel time: {kernel_time:.2f} us") if __name__ == "__main__": From 2e428720d1c855869cbeb52e6e4ab5a4dfb8483c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:18:17 +0000 Subject: [PATCH 07/14] Update --- benchmarks/kernels/benchmark_moe.py | 32 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6b55de9f51ad..6e246ff19220 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -14,7 +14,7 @@ def benchmark_config( config: Dict[str, int], num_tokens: int, num_experts: int, - fused_intermediate_size: int, + shard_intermediate_size: int, hidden_size: int, topk: int, dtype: torch.dtype, @@ -24,12 +24,12 @@ def benchmark_config( init_dtype = torch.float16 if use_fp8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) w1 = torch.randn(num_experts, - fused_intermediate_size, + shard_intermediate_size, hidden_size, dtype=init_dtype).to(dtype) w2 = torch.randn(num_experts, hidden_size, - fused_intermediate_size, + shard_intermediate_size // 2, dtype=init_dtype).to(dtype) gating_output = torch.randn(num_iters, num_tokens, @@ -41,10 +41,10 @@ def benchmark_config( a1_scale = None a2_scale = None if use_fp8: - w1_scale = torch.ones(num_experts, dtype=torch.float32) - w2_scale = torch.ones(num_experts, dtype=torch.float32) - a1_scale = torch.ones(1, dtype=torch.float32) - a2_scale = torch.ones(1, dtype=torch.float32) + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -145,7 +145,7 @@ def benchmark( ) -> Tuple[Dict[str, int], float]: torch.cuda.manual_seed_all(self.seed) - dtype_str = "float8" if use_fp8 else str(dtype) + dtype_str = "float8" if use_fp8 else None op_config = get_moe_configs(num_experts, shard_intermediate_size, dtype_str) if op_config is None: @@ -174,7 +174,6 @@ def tune( best_config = None best_time = float("inf") for config in search_space: - print(config) try: kernel_time = benchmark_config(config, num_tokens, @@ -210,13 +209,16 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]: def save_configs( configs: Dict[int, Dict[str, int]], - E: int, - N: int, - K: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, topk: int, - dtype: str, + dtype: torch.dtype, + use_fp8: bool, ) -> None: - filename = get_config_file_name(E, N, K, topk, dtype) + dtype_str = "float8" if use_fp8 else None + filename = get_config_file_name(num_experts, shard_intermediate_size, + dtype_str) print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) @@ -277,7 +279,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: for M, config in zip(batch_sizes, configs) } save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, str(dtype)) + topk, dtype, use_fp8) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: From e5e46e8383eaa530e3921b2b9b8de2710417e8e4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:19:23 +0000 Subject: [PATCH 08/14] Add get_default_configs --- .../layers/fused_moe/fused_moe.py | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 20a3c9f6f893..1c6947137a1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,6 +308,30 @@ def get_moe_configs(E: int, N: int, return None +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + if M <= E: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + return config + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -382,20 +406,9 @@ def fused_experts(hidden_states: torch.Tensor, config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - } - - if M <= E: - config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 - } + config = get_default_config(M, E, N, w1.shape[2], + topk_ids.shape[1], + "float8" if use_fp8 else None) intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, From 005b07cd58ceb821fccf29b4cbe95e9f3acaad66 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:19:37 +0000 Subject: [PATCH 09/14] Minor --- benchmarks/kernels/benchmark_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6e246ff19220..d294b9908222 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -298,7 +298,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--tp-size", "-tp", type=int, default=2) parser.add_argument("--dtype", type=str, choices=["auto", "fp8"], From 2b05e509524d8a4245f9387678ac4fa062a24247 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:22:17 +0000 Subject: [PATCH 10/14] Fix search space --- benchmarks/kernels/benchmark_moe.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index d294b9908222..d707337dad9b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -107,20 +107,20 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: # TODO(woosuk): Increase the search space and use a performance model to # prune the search space. configs = [] - for num_stages in [4]: - for block_m in [64]: - for block_k in [32]: - for block_n in [32]: - num_warps = 2 if block_n <= 64 else 4 - for group_size in [1]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + }) return configs From 6f472bb7ecd4c7681af002cbd9faf57b2e3dceac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 18:35:07 +0000 Subject: [PATCH 11/14] Minor --- benchmarks/kernels/benchmark_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index d707337dad9b..6796ea401fe6 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -5,6 +5,7 @@ import ray import torch +import triton from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * From fe1c0265076dbe80ffa000339e87b16e6e759606 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 4 Jun 2024 02:00:45 +0000 Subject: [PATCH 12/14] Address comments --- benchmarks/kernels/benchmark_moe.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6796ea401fe6..4b792d406e31 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -27,11 +27,11 @@ def benchmark_config( w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, - dtype=init_dtype).to(dtype) + dtype=init_dtype) w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, - dtype=init_dtype).to(dtype) + dtype=init_dtype) gating_output = torch.randn(num_iters, num_tokens, num_experts, @@ -47,6 +47,9 @@ def benchmark_config( a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) def prepare(i: int): @@ -132,7 +135,6 @@ def __init__(self, seed: int) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(seed) self.seed = seed - self.device_properties = torch.cuda.get_device_properties("cuda") def benchmark( self, @@ -147,7 +149,9 @@ def benchmark( torch.cuda.manual_seed_all(self.seed) dtype_str = "float8" if use_fp8 else None - op_config = get_moe_configs(num_experts, shard_intermediate_size, + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, dtype_str) if op_config is None: config = get_default_config(num_tokens, num_experts, @@ -218,7 +222,9 @@ def save_configs( use_fp8: bool, ) -> None: dtype_str = "float8" if use_fp8 else None - filename = get_config_file_name(num_experts, shard_intermediate_size, + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. + filename = get_config_file_name(num_experts, shard_intermediate_size // 2, dtype_str) print(f"Writing best config to {filename}...") with open(filename, "w") as f: From 04c3d1126418163cdc6ad472e5a9410fc2dfc95f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 19:59:42 -0700 Subject: [PATCH 13/14] Add Ray tqdm --- benchmarks/kernels/benchmark_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 4b792d406e31..9b8b8947a924 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple import ray +from ray.experimental.tqdm_ray import tqdm import torch import triton from transformers import AutoConfig @@ -178,7 +179,7 @@ def tune( ) -> Dict[str, int]: best_config = None best_time = float("inf") - for config in search_space: + for config in tqdm(search_space): try: kernel_time = benchmark_config(config, num_tokens, From f78a73ba36b0373e48385194ebd19c366da2b6f0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 3 Jun 2024 20:00:32 -0700 Subject: [PATCH 14/14] isort --- benchmarks/kernels/benchmark_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 9b8b8947a924..d6fa39a4d30e 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Tuple import ray -from ray.experimental.tqdm_ray import tqdm import torch import triton +from ray.experimental.tqdm_ray import tqdm from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import *