From 40f1d20607934676cc262b6f94f5a5d16adb68b7 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 13 Jan 2025 19:05:51 +0800 Subject: [PATCH] refactor quantization for static quantized weight loading and add deepseek_v3 (#702) Co-authored-by: baishihao --- ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + ...out_dtype=torch.bfloat16}_NVIDIA_H800.json | 1 + lightllm/common/basemodel/basemodel.py | 3 +- .../layer_weights/meta_weights/__init__.py | 2 - .../meta_weights/fused_moe_weight.py | 137 ++++- .../layer_weights/meta_weights/mm_weight.py | 495 ++++++++++++------ .../layer_weights/transformer_layer_weight.py | 8 +- .../common/fused_moe/grouped_fused_moe.py | 71 ++- lightllm/common/fused_moe/topk_select.py | 51 +- lightllm/common/quantization/__init__.py | 27 +- lightllm/common/quantization/ppl_quant.py | 10 +- lightllm/common/quantization/registry.py | 7 +- lightllm/common/quantization/torchao_quant.py | 39 +- .../quantization/triton_quant/__init__.py | 0 .../quantization/triton_quant/fp8/__init__.py | 0 .../triton_quant/fp8/fp8act_quant_kernel.py | 116 ++++ .../fp8/fp8w8a8_block_gemm_kernel.py | 242 +++++++++ .../quantization/triton_quant/triton_quant.py | 70 +++ lightllm/common/quantization/vllm_quant.py | 10 +- lightllm/common/vllm_kernel/_custom_ops.py | 35 -- .../layer_infer/transformer_layer_infer.py | 4 +- .../layer_weights/transformer_layer_weight.py | 143 +++-- .../model_infer/mode_backend/base_backend.py | 2 +- .../deepseekv3_fp8_block_gemm_tuning.py | 273 ++++++++++ 41 files changed, 1462 insertions(+), 302 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=1024,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=1152,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=1536,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=16384,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=18432,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=2048,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=2304,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 
lightllm/common/all_kernel_configs/fp8_block_mm/{K=256,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=2048,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=32768,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=4096,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=2304,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=24576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=256,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=36864,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=512,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=8072,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json create mode 100644 lightllm/common/quantization/triton_quant/__init__.py create mode 100644 lightllm/common/quantization/triton_quant/fp8/__init__.py create mode 100644 lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py create mode 100644 lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py create mode 100644 lightllm/common/quantization/triton_quant/triton_quant.py create mode 100644 test/kernel/deepseekv3_fp8_block_gemm_tuning.py diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1024,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1024,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..6c8e3ef72 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1024,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, 
"num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1152,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1152,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..2e12f7082 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1152,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1536,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1536,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..dd6dcd854 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=1536,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=16384,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=16384,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..ed8b17e92 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=16384,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, 
"num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=18432,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=18432,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..1722f24de --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=18432,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2048,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2048,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..2e12f7082 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2048,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2304,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2304,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..2e12f7082 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=2304,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 
4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=256,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=256,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..bc49d105f --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=256,N=7168,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=2048,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=2048,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..66f28032d --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=2048,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=32768,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=32768,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..cf99b86e6 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=32768,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "16": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "24": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "32": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "48": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 
4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=4096,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=4096,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..547e59134 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=512,N=4096,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..dd6dcd854 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=1536,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=2304,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=2304,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..dd6dcd854 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=2304,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, 
"num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=24576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=24576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..6e146056c --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=24576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "16": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "24": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "32": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=256,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=256,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..dd6dcd854 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=256,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=36864,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=36864,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..6e146056c --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=36864,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "4": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "8": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "16": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "24": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "32": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 
4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=512,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=512,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..1abdd5810 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=512,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json 
b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..dd6dcd854 --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=576,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=8072,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=8072,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json new file mode 100644 index 000000000..1722f24de --- /dev/null +++ b/lightllm/common/all_kernel_configs/fp8_block_mm/{K=7168,N=8072,block_size=[128,128],out_dtype=torch.bfloat16}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"1": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, "2": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "8": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "16": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "24": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "32": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "48": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "64": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, 
"num_stages": 4, "num_warps": 4}, "96": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "128": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "256": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "512": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1024": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "1536": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "2048": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "3072": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, "4096": {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}} \ No newline at end of file diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 081fba6d3..477bce588 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -99,6 +99,7 @@ def _init_config(self): repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) if self.finetune_config: self.config["vocab_size"] = self.finetune_config.vocab_size + return @final @@ -112,7 +113,7 @@ def _verify_params(self): return def _init_quant(self): - self.quant_cfg = Quantcfg(self.config["n_layer"], self.quant_type, self.quant_cfg_path) + self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 973de48da..39895dcc6 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -10,10 +10,8 @@ MultiROWMMWeightNoTP, MultiCOLMMWeight, ROWBMMWeight, - COLBMMWeight, MultiCOLMMWeightNoTp, ROWBMMWeightNoTp, - COLBMMWeightNoTp, COLMMWeightNoTp, ) from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py index 490fe6cc2..925b0b192 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py @@ -1,22 +1,40 @@ import os import torch -from .base_weight import BaseWeight -from lightllm.utils.dist_utils import get_world_size, get_rank import threading +from typing import Optional, Tuple, List, Dict, Any +from .base_weight import BaseWeight from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod - +from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.utils.dist_utils import get_world_size, get_rank from lightllm.common.vllm_kernel import _custom_ops as ops from lightllm.utils.device_utils import get_current_device_id class FusedMoeWeight(BaseWeight): def __init__( - self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type - ): + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + 
e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: super().__init__() self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name self.w3_weight_name = up_proj_name + self.weight_scale_suffix = weight_scale_suffix + self.act_scale_suffix = act_scale_suffix + self.quantized_weight = weight_scale_suffix is not None + self.static_activation = act_scale_suffix is not None + + self.e_score_correction_bias_name = e_score_correction_bias_name self.weight_prefix = weight_prefix self.n_routed_experts = n_routed_experts self.split_inter_size = split_inter_size @@ -24,27 +42,41 @@ def __init__( self.tp_rank_ = get_rank() self.experts_up_projs = [None] * self.n_routed_experts self.experts_gate_projs = [None] * self.n_routed_experts + self.experts_up_proj_scales = [None] * self.n_routed_experts + self.experts_gate_proj_scales = [None] * self.n_routed_experts self.expert_gate_up_proj_etp = None self.expert_down_proj_etp = None + self.e_score_correction_bias = None self.w2_list = [None] * self.n_routed_experts + self.w2_scale_list = [None] * self.n_routed_experts self.quant_method = None + self.scoring_func = network_config["scoring_func"] + self.w1 = [None, None] # weight, weight_scale + self.w2 = [None, None] # weight, weight_scale self.lock = threading.Lock() - def set_quant_method(self, quant_method): + def set_quant_method(self, quant_method: QuantizationMethod) -> None: + if self.quantized_weight: + self.quant_method = quant_method + return if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod): self.quant_method = quant_method if self.quant_method is not None: self.quant_method.is_moe = True def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - topk_weights, topk_ids = ops.select_experts( + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( hidden_states=input_tensor, router_logits=router_logits, + correction_bias=self.e_score_correction_bias, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, + scoring_func=self.scoring_func, ) w1, w1_scale = self.w1 w2, w2_scale = self.w2 @@ -66,6 +98,8 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t return def _fuse(self): + if self.quantized_weight: + self._fuse_weight_scale() with self.lock: if ( hasattr(self, "experts_up_projs") @@ -82,21 +116,48 @@ def _fuse(self): w1_list.append(expert_gate_up_proj) inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1] - self.w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) + w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - self.w2 = torch._utils._flatten_dense_tensors(self.w2_list).view( - len(self.w2_list), inter_shape, hidden_size - ) - if self.quant_method is not None: - self.w1 = self.quant_method.quantize(self.w1) - self.w2 = self.quant_method.quantize(self.w2) + w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) + if not self.quantized_weight and self.quant_method is not None: + self.w1 = 
self.quant_method.quantize(w1) + self.w2 = self.quant_method.quantize(w2) else: - self.w1 = [self._cuda(self.w1), None] - self.w2 = [self._cuda(self.w2), None] + self.w1[0] = self._cuda(w1) + self.w2[0] = self._cuda(w2) delattr(self, "w2_list") delattr(self, "experts_up_projs") delattr(self, "experts_gate_projs") + def _fuse_weight_scale(self): + with self.lock: + if ( + hasattr(self, "experts_up_proj_scales") + and None not in self.experts_up_proj_scales + and None not in self.experts_gate_proj_scales + and None not in self.w2_scale_list + ): + w1_scale_list = [] + for i_experts in range(self.n_routed_experts): + expert_gate_up_proj_scale = torch.cat( + [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0 + ) + w1_scale_list.append(expert_gate_up_proj_scale) + + inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1] + w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view( + len(w1_scale_list), inter_shape, hidden_size + ) + inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] + w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( + len(self.w2_scale_list), inter_shape, hidden_size + ) + self.w1[1] = self._cuda(w1_scale) + self.w2[1] = self._cuda(w2_scale) + delattr(self, "w2_scale_list") + delattr(self, "experts_up_proj_scales") + delattr(self, "experts_gate_proj_scales") + def _load_hf_weights_etp(self, weights): world_size_ = get_world_size() assert self.n_routed_experts % world_size_ == 0 @@ -105,6 +166,8 @@ def _load_hf_weights_etp(self, weights): # tp to ep here expert_gate_up_proj_last = None expert_down_proj_last = None + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias = self._cuda(self.e_score_correction_bias_name) for i_experts_ep in range(n_expert_ep): expert_up_proj = None @@ -178,11 +241,51 @@ def load_hf_weights(self, weights): self.w2_list[i_experts] = weights[w2_weight][ :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) ] - + if self.quant_method is not None: + self._load_weight_scale(weights) self._fuse() + def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: + block_size = 1 + if hasattr(self.quant_method, "block_size"): + block_size = self.quant_method.block_size + for i_experts in range(self.n_routed_experts): + w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" + if w1_scale in weights: + self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ + self.split_inter_size + // block_size + * self.tp_rank_ : self.split_inter_size + // block_size + * (self.tp_rank_ + 1), + :, + ] + if w3_scale in weights: + self.experts_up_proj_scales[i_experts] = weights[w3_scale][ + self.split_inter_size + // block_size + * self.tp_rank_ : self.split_inter_size + // block_size + * (self.tp_rank_ + 1), + :, + ] + + if w2_scale in weights: + self.w2_scale_list[i_experts] = weights[w2_scale][ + :, + self.split_inter_size + // block_size + * self.tp_rank_ : self.split_inter_size + // block_size + * (self.tp_rank_ + 1), + ] + def _cuda(self, cpu_tensor): device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.contiguous().cuda(device_id) return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) def 
verify_load(self): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py index 6bc104fa6..33fbaa5b1 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py @@ -1,32 +1,40 @@ import os import torch from .base_weight import BaseWeightTpl +from typing import Optional, Tuple, List, Dict, Any from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.common.quantization.quantize_method import QuantizationMethod -def generate_scale_name(name): - weight_scale_name = ".".join(name.split(".")[:-1] + ["weight_scale"]) - input_scale_name = ".".join(name.split(".")[:-1] + ["input_scale"]) - return weight_scale_name, input_scale_name +def generate_scale_name(name, weight_scale_suffix, act_scale_suffix): + weight_scale_name = None + act_scale_name = None + if weight_scale_suffix is not None: + weight_scale_name = ".".join(name.split(".")[:-1] + [weight_scale_suffix]) + if act_scale_suffix is not None: + act_scale_name = ".".join(name.split(".")[:-1] + [act_scale_suffix]) + return weight_scale_name, act_scale_name STATIC_QUANT = os.getenv("STATIC_QUANT", "0").upper() in ["1", "TRUE", "ON"] class MMWeightTpl(BaseWeightTpl): - def __init__(self, data_type): + def __init__(self, data_type: torch.dtype) -> None: super().__init__() self.data_type_ = data_type - self.quant_method = None - self.weight = None - self.bias = None - self.weight_scale = None - self.input_scale = None + self.quant_method: Optional[QuantizationMethod] = None + self.weight: Optional[torch.Tensor] = None + self.bias: Optional[torch.Tensor] = None + self.weight_scale: Optional[torch.Tensor] = None + self.input_scale: Optional[torch.Tensor] = None - def set_quant_method(self, quant_method): + def set_quant_method(self, quant_method: QuantizationMethod) -> None: self.quant_method = quant_method - def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True): + def mm( + self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True + ) -> torch.Tensor: if self.quant_method is not None: return self.quant_method.apply( input_tensor, self.weight, self.bias, out, use_custom_tensor_mananger=use_custom_tensor_mananger @@ -43,11 +51,34 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True): return torch.mm(input_tensor, self.weight, out=out) return torch.addmm(self.bias, input_tensor, self.weight, out=out) - def _post_load_weights(self): + def verify_load(self) -> bool: + load_ok = True + # Verify weight. The weight must be not None. + load_ok = load_ok and self.weight is not None + # Verify bias. If bias_name is set, it must be not None. 
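# Illustrative sketch (not part of the patch): how generate_scale_name derives the
# scale tensor names. It keeps the module prefix of the weight name and replaces the
# trailing "weight" with the configured suffix; the layer name below is a made-up
# example, while "weight_scale_inv" is the suffix this refactor uses for DeepSeek-V3
# style checkpoints.
name = "model.layers.0.self_attn.o_proj.weight"
weight_scale_name = ".".join(name.split(".")[:-1] + ["weight_scale_inv"])
assert weight_scale_name == "model.layers.0.self_attn.o_proj.weight_scale_inv"
# With act_scale_suffix=None (dynamic activation quantization), the returned
# act_scale_name stays None and no input_scale tensor is looked up.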
+ if self.has_bias: + load_ok = load_ok and self.bias is not None + if self.quantized_weight: + load_ok = load_ok and self.weight_scale is not None + if self.static_activation: + load_ok = load_ok and self.input_scale is not None + return load_ok + + def _post_load_weights(self) -> None: if self.quant_method is not None: - if STATIC_QUANT: - if all(w is not None for w in [self.weight, self.weight_scale, self.input_scale]): - self.weight = self.quant_method.quantize((self.weight, self.weight_scale, self.input_scale)) + if self.quantized_weight: + if ( + self.weight is not None + and self.weight_scale is not None + and (not self.static_activation or self.input_scale is not None) + ): + if self.weight_scale.ndim > 1: + self.weight_scale = self.weight_scale.transpose(0, 1).cuda(self.device_id_) + self.weight = [ + self.weight.transpose(0, 1).cuda(self.device_id_), + self.weight_scale, + self.input_scale, + ] else: self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_)) return @@ -55,29 +86,41 @@ def _post_load_weights(self): class MMWeight(MMWeightTpl): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: super().__init__(data_type) self.start = split_n_embed * self.tp_rank_ self.end = split_n_embed * (self.tp_rank_ + 1) self.weight_name = weight_name self.bias_name = bias_name - self.weight_scale_name, self.input_scale_name = generate_scale_name(weight_name) - - def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. 
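# Illustrative sketch (not part of the patch): how a row-parallel shard selects its
# slice of a block-wise weight scale, mirroring the ceil-divided bounds used in
# ROWMMWeight.load_hf_weights above. All sizes are assumed toy values.
import torch

block_size = 128
split_n_embed = 512                      # rows owned by each tp rank (assumed)
tp_rank = 1
start, end = split_n_embed * tp_rank, split_n_embed * (tp_rank + 1)

weight_scale = torch.rand(4096 // block_size, 7168 // block_size)  # one scale per 128x128 block
scale_start = (start + block_size - 1) // block_size
scale_end = (end + block_size - 1) // block_size
shard_scale = weight_scale[scale_start:scale_end]                  # scale rows 4..7 for rank 1
assert shard_scale.shape[0] == split_n_embed // block_size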
- if self.bias_name is not None: - load_ok = load_ok and self.bias is not None - return load_ok + self.has_bias = bias_name is not None + self.weight_scale_name, self.act_scale_name = generate_scale_name( + weight_name, weight_scale_suffix, act_scale_suffix + ) + self.quantized_weight = self.weight_scale_name is not None + self.static_activation = self.act_scale_name is not None class ROWMMWeight(MMWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): - super().__init__(weight_name, data_type, split_n_embed, bias_name) - - def load_hf_weights(self, weights): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, act_scale_suffix) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: weight = None weight_scale = None input_scale = None @@ -88,12 +131,25 @@ def load_hf_weights(self, weights): bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end] self.bias = bias.cuda(self.device_id_) - if STATIC_QUANT and self.weight_scale_name in weights: - weight_scale = weights[self.weight_scale_name].to(torch.float)[self.start : self.end] - self.weight_scale = weight_scale.cuda() + if self.weight_scale_name is not None and self.weight_scale_name in weights: + block_size = 1 + if self.quant_method is not None: + if hasattr(self.quant_method, "block_size"): + block_size = self.quant_method.block_size + + weight_scale = weights[self.weight_scale_name] + # per channel or block-wise + if weight_scale.shape[0] > 1: + scale_start = (self.start + block_size - 1) // block_size + scale_end = (self.end + block_size - 1) // block_size + weight_scale = weight_scale.to(torch.float)[scale_start:scale_end] + else: + # per tensor + weight_scale = weight_scale.to(torch.float) + self.weight_scale = weight_scale - if STATIC_QUANT and self.input_scale_name in weights: - input_scale = weights[self.input_scale_name].to(torch.float) + if self.act_scale_name is not None and self.act_scale_name in weights: + input_scale = weights[self.act_scale_name].to(torch.float) self.input_scale = input_scale.cuda() if weight is None and weight_scale is None and input_scale is None: @@ -103,17 +159,33 @@ def load_hf_weights(self, weights): class ROWMMWeightNoTP(ROWMMWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): - super().__init__(weight_name, data_type, split_n_embed, bias_name) + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, act_scale_suffix) self.start = 0 self.end = split_n_embed class COLMMWeight(MMWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): - super().__init__(weight_name, data_type, split_n_embed, bias_name) - - def load_hf_weights(self, weights): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, 
act_scale_suffix) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: weight = None weight_scale = None input_scale = None @@ -124,12 +196,22 @@ def load_hf_weights(self, weights): bias = weights[self.bias_name] self.bias = (bias / self.world_size_).to(self.data_type_).cuda(self.device_id_) - if STATIC_QUANT and self.weight_scale_name in weights: - weight_scale = weights[self.weight_scale_name].to(torch.float) - self.weight_scale = weight_scale.cuda() + if self.quantized_weight and self.weight_scale_name in weights: + block_size = 1 + if self.quant_method is not None: + if hasattr(self.quant_method, "block_size"): + block_size = self.quant_method.block_size + weight_scale = weights[self.weight_scale_name] + # block-wise + if weight_scale.ndim >= 2: + weight_scale = weight_scale[:, self.start // block_size : self.end // block_size].to(torch.float) + else: + # per tensor or per-channel + weight_scale = weight_scale.to(torch.float) + self.weight_scale = weight_scale - if STATIC_QUANT and self.input_scale_name in weights: - input_scale = weights[self.input_scale_name].to(torch.float) + if self.static_activation and self.act_scale_name in weights: + input_scale = weights[self.act_scale_name].to(torch.float) self.input_scale = input_scale.cuda() if weight is None and weight_scale is None and input_scale is None: @@ -138,28 +220,31 @@ def load_hf_weights(self, weights): return -class COLMMWeightNoTp(MMWeight): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): - super().__init__(weight_name, data_type, split_n_embed, bias_name) +class COLMMWeightNoTp(COLMMWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, act_scale_suffix) self.start = 0 self.end = split_n_embed - def load_hf_weights(self, weights): - weight = None - if self.weight_name in weights: - weight = weights[self.weight_name].to(self.data_type_) - self.weight = weight[:, self.start : self.end] - if self.bias_name in weights: - bias = weights[self.bias_name] - self.bias = bias.to(self.data_type_).cuda(self.tp_rank_) - if weight is None: - return - self._post_load_weights() - return - class MultiMMWeight(MMWeightTpl): - def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + split_n_embeds: List[int], + bias_names: Optional[List[str]] = [], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: super().__init__(data_type) if isinstance(split_n_embeds, int): self.split_n_embeds = [split_n_embeds] * len(weight_names) @@ -171,11 +256,13 @@ def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]): self.weight_names = weight_names self.bias_names = bias_names self.weight_scale_names = [] - self.input_scale_names = [] + self.act_scale_names = [] for weight_name in weight_names: - weight_scale_name, input_scale_name = generate_scale_name(weight_name) + weight_scale_name, act_scale_name = generate_scale_name(weight_name, weight_scale_suffix, act_scale_suffix) self.weight_scale_names.append(weight_scale_name) - self.input_scale_names.append(input_scale_name) + self.act_scale_names.append(act_scale_name) + self.quantized_weight = weight_scale_name is not None + 
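# Illustrative sketch (not part of the patch): the column-parallel counterpart.
# COLMMWeight slices the block-wise scale along dim=1 (the input dimension of the
# weight), again dividing the shard bounds by block_size. Sizes are assumed.
import torch

block_size = 128
split_n_embed, tp_rank = 1024, 1
start, end = split_n_embed * tp_rank, split_n_embed * (tp_rank + 1)

weight_scale = torch.rand(7168 // block_size, 4096 // block_size)
shard_scale = weight_scale[:, start // block_size : end // block_size].to(torch.float)
assert shard_scale.shape[1] == split_n_embed // block_size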
self.static_activation = act_scale_name is not None self.weights = [None] * len(self.weight_names) self.biases = [None] * len(self.bias_names) @@ -183,40 +270,43 @@ def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]): self.weight_scales = [None] * len(self.weight_names) self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0 - def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. - if self.has_bias: - load_ok = load_ok and self.bias is not None - return load_ok - class MultiROWMMWeight(MultiMMWeight): - def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]): - super().__init__(weight_names, data_type, split_n_embed, bias_names) - - def _fuse(self): - if self.weight is None and all(w is not None for w in self.weights): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + split_n_embeds: List[int], + bias_names: Optional[List[str]] = [], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_names, data_type, split_n_embeds, bias_names, weight_scale_suffix, act_scale_suffix) + + def _fuse(self) -> None: + if self.weight is None and (None not in self.weights): self.weight = torch.cat(self.weights, dim=0) self._post_load_weights() + delattr(self, "weights") - if self.weight_scale is None and all(w is not None for w in self.weight_scales): + if self.weight_scale is None and (None not in self.weight_scales): self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda() self._post_load_weights() + delattr(self, "weight_scales") - if self.input_scale is None and all(w is not None for w in self.input_scales): + if self.static_activation and self.input_scale is None and (None not in self.input_scales): input_scales = torch.stack(self.input_scales, dim=0) self.input_scale = torch.max(input_scales).cuda() self._post_load_weights() + delattr(self, "input_scales") if self.has_bias: - if self.bias is None and all(b is not None for b in self.biases): + if self.bias is None and (None not in self.biases): self.bias = torch.cat(self.biases, dim=0).cuda(self.device_id_) + delattr(self, "biases") return self - def load_hf_weights(self, weights): + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: weight = None for i in range(len(self.weight_names)): if self.weight_names[i] in weights: @@ -225,29 +315,56 @@ def load_hf_weights(self, weights): if self.has_bias and self.bias_names[i] in weights: bias = weights[self.bias_names[i]].to(self.data_type_) self.biases[i] = bias[self.starts[i] : self.ends[i]] - if STATIC_QUANT and self.weight_scale_names[i] in weights: - weight_scale = weights[self.weight_scale_names[i]][self.starts[i] : self.ends[i]] - self.weight_scales[i] = weight_scale.to(torch.float) - if STATIC_QUANT and self.input_scale_names[i] in weights: - input_scale = weights[self.input_scale_names[i]].to(torch.float) + if self.quantized_weight and self.weight_scale_names[i] in weights: + block_size = 1 + if self.quant_method is not None: + if hasattr(self.quant_method, "block_size"): + block_size = self.quant_method.block_size + weight_scale = weights[self.weight_scale_names[i]] + # block-wise or per-channel + if weight_scale.shape[0] > 1: + weight_scale = weight_scale[self.starts[i] // block_size : self.ends[i] // block_size].to( + torch.float + ) + else: + # per tensor + weight_scale 
= weight_scale.to(torch.float) + self.weight_scales[i] = weight_scale + if self.static_activation and self.act_scale_names[i] in weights: + input_scale = weights[self.act_scale_names[i]].to(torch.float) self.input_scales[i] = input_scale - self._fuse() return class MultiROWMMWeightNoTP(MultiROWMMWeight): - def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]): - super().__init__(weight_names, data_type, split_n_embed, bias_names) + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + split_n_embeds: List[int], + bias_names: Optional[List[str]] = [], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_names, data_type, split_n_embeds, bias_names, weight_scale_suffix, act_scale_suffix) self.starts = [0 for i in self.split_n_embeds] self.ends = [i for i in self.split_n_embeds] class MultiCOLMMWeight(MultiROWMMWeight): - def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]): - super().__init__(weight_names, data_type, split_n_embed, bias_names) - - def load_hf_weights(self, weights): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + split_n_embeds: List[int], + bias_names: Optional[List[str]] = [], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_names, data_type, split_n_embeds, bias_names, weight_scale_suffix, act_scale_suffix) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: weight = None for i in range(len(self.weight_names)): if self.weight_names[i] in weights: @@ -256,21 +373,29 @@ def load_hf_weights(self, weights): if self.has_bias and self.bias_names[i] in weights: bias = weights[self.bias_names[i]].to(self.data_type_) self.biases[i] = bias[:, self.starts[i] : self.ends[i]] - if STATIC_QUANT and self.weight_scale_names[i] in weights: + if self.quantized_weight and self.weight_scale_names[i] in weights: weight_scale = weights[self.weight_scale_names[i]] self.weight_scales[i] = weight_scale.to(torch.float) - if STATIC_QUANT and self.input_scale_names[i] in weights: - input_scale = weights[self.input_scale_names[i]].to(torch.float) + if self.static_activation and self.act_scale_names[i] in weights: + input_scale = weights[self.act_scale_names[i]].to(torch.float) self.input_scales[i] = input_scale self._fuse() return class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP): - def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]): - super().__init__(weight_names, data_type, split_n_embed, bias_names) - - def load_hf_weights(self, weights): + def __init__( + self, + weight_names: List[str], + data_type: torch.dtype, + split_n_embeds: List[int], + bias_names: Optional[List[str]] = [], + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_names, data_type, split_n_embeds, bias_names, weight_scale_suffix, act_scale_suffix) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): weight = None for i in range(len(self.weight_names)): if self.weight_names[i] in weights: @@ -283,22 +408,24 @@ def load_hf_weights(self, weights): return -class BMMWeightTpl(BaseWeightTpl): - def __init__(self, data_type): - super().__init__() - self.data_type_ = data_type - self.quant_method = None - self.weight = None - self.bias = None +class BMMWeightTpl(MMWeightTpl): + def __init__(self, data_type: torch.dtype): + super().__init__(data_type) - def 
set_quant_method(self, quant_method): - self.quant_method = None + def set_quant_method(self, quant_method: QuantizationMethod) -> None: + if self.quantized_weight: + # for the quantized fp8 weight of Deepseek v3 + self.quant_method = quant_method - def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True): + def bmm( + self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True + ) -> torch.Tensor: if self.quant_method is not None: - return self.quant_method.apply(input_tensor, self.weight, self.bias, out) + fpweight = self.dequant_weight(self.weight[0], self.weight[1]) + else: + fpweight = self.weight if out is None: - shape = (input_tensor.shape[0], input_tensor.shape[1], self.weight.shape[2]) + shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: @@ -306,22 +433,46 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True): else: out = torch.empty(shape, dtype=dtype, device=device) if self.bias is None: - return torch.bmm(input_tensor, self.weight, out=out) - return torch.addbmm(self.bias, input_tensor, self.weight, out=out) + return torch.bmm(input_tensor, fpweight, out=out) + return torch.addbmm(self.bias, input_tensor, fpweight, out=out) - def _post_load_weights(self): + def _post_load_weights(self) -> None: + if self.quant_method is not None: + if self.quantized_weight: + if ( + self.weight is not None + and self.weight_scale is not None + and (not self.static_activation or self.input_scale is not None) + ): + if self.weight_scale.ndim > 1: + self.weight_scale = self.weight_scale.cuda(self.device_id_) + self.weight = [self.weight.cuda(self.device_id_), self.weight_scale, self.input_scale] + return self.weight = self.weight.cuda(self.device_id_) class BMMWeight(BMMWeightTpl): - def __init__(self, weight_name, data_type, split_n_embed, bias_name=None): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: super().__init__(data_type) self.start = split_n_embed * self.tp_rank_ self.end = split_n_embed * (self.tp_rank_ + 1) self.weight_name = weight_name self.bias_name = bias_name + self.weight_scale_name, self.act_scale_name = generate_scale_name( + weight_name, weight_scale_suffix, act_scale_suffix + ) + self.quantized_weight = self.weight_scale_name is not None + self.static_activation = self.act_scale_name is not None - def verify_load(self): + def verify_load(self) -> None: load_ok = True # Verify weight. The weight must be not None. 
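# Illustrative sketch (not part of the patch): the semantics block-wise dequantization
# targets. An fp8 weight of shape [N, K] carries one float scale per 128x128 tile
# ([N/128, K/128]); expanding the scale over the tiles and multiplying recovers a
# full-precision weight. ROWBMMWeight.dequant_weight below aims to apply the same idea
# per attention head on a batched [heads, N, K] weight. Shapes here are toy values and
# a bf16 tensor stands in for the real fp8 weight.
import torch

bs = 128
w_q = torch.randn(256, 512, dtype=torch.bfloat16)        # stand-in for the fp8 weight
scale = torch.rand(256 // bs, 512 // bs)                  # [2, 4] tile scales
scale_full = scale.repeat_interleave(bs, dim=0).repeat_interleave(bs, dim=1)
w_deq = (w_q.float() * scale_full).to(torch.bfloat16)     # dequantized [256, 512] weight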
load_ok = load_ok and self.weight is not None @@ -332,62 +483,68 @@ def verify_load(self): class ROWBMMWeight(BMMWeight): - load_hf_weights = ROWMMWeight.load_hf_weights - def __init__( self, - weight_name, - data_type, - split_n_embed, - bias_name=None, - ): - super().__init__(weight_name, data_type, split_n_embed, bias_name) - - -class ROWBMMWeightNoTp(BMMWeight): - load_hf_weights = ROWMMWeight.load_hf_weights - - def __init__( - self, - weight_name, - data_type, - split_n_embed, - bias_name=None, - ): - super().__init__(weight_name, data_type, split_n_embed, bias_name) - self.start = 0 - self.end = split_n_embed - - -class COLBMMWeight(BMMWeight): - load_hf_weights = COLMMWeight.load_hf_weights + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, act_scale_suffix) + + def dequant_weight(self, weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + # for Deepseek v3 + # TODO a fast bmm quant kernel + weight = weight.to(self.data_type_) + block_size = weight.shape[-1] // scale.shape[-1] + w_shape = weight.shape + scale = scale.unsqueeze(-1).repeat(1, 1, 1, block_size).reshape(w_shape[0], w_shape[1], -1) + scale = scale.unsqueeze(2).repeat(1, 1, block_size, 1).reshape(w_shape) + return (weight * scale).to(self.data_type_) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + weight = None + weight_scale = None + input_scale = None + if self.weight_name in weights: + weight = weights[self.weight_name] + self.weight = weight[self.start : self.end] + if self.bias_name in weights: + bias = weights[self.bias_name].to(self.data_type_)[self.start : self.end] + self.bias = bias.cuda(self.device_id_) - def __init__( - self, - weight_name, - data_type, - split_n_embed, - bias_name=None, - ): - super().__init__(weight_name, data_type, split_n_embed, bias_name) + if self.weight_scale_name is not None and self.weight_scale_name in weights: + weight_scale = weights[self.weight_scale_name] + # per channel or block-wise + if weight_scale.shape[0] > 1: + weight_scale = weight_scale.to(torch.float)[self.start : self.end] + else: + # per tensor + weight_scale = weight_scale.to(torch.float) + self.weight_scale = weight_scale - def _post_load_weights(self): - self.weight = self.weight.transpose(0, 1).cuda(self.device_id_) + if self.act_scale_name is not None and self.act_scale_name in weights: + input_scale = weights[self.act_scale_name].to(torch.float) + self.input_scale = input_scale.cuda() + if weight is None and weight_scale is None and input_scale is None: + return + self._post_load_weights() + return -class COLBMMWeightNoTp(BMMWeight): - load_hf_weights = COLMMWeightNoTp.load_hf_weights +class ROWBMMWeightNoTp(ROWBMMWeight): def __init__( self, - weight_name, - data_type, - split_n_embed, - bias_name=None, - ): - super().__init__(weight_name, data_type, split_n_embed, bias_name) + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_scale_suffix: Optional[str] = None, + act_scale_suffix: Optional[str] = None, + ) -> None: + super().__init__(weight_name, data_type, split_n_embed, bias_name, weight_scale_suffix, act_scale_suffix) self.start = 0 self.end = split_n_embed - - def _post_load_weights(self): - self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_) diff --git 
a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 8cd293810..29bd7dcc6 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -20,8 +20,9 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo self.quant_cfg = quant_cfg self._parse_config() self._init_weight_names() + self._init_qweight_names() self._init_weight() - self.set_quantization() + self._set_quantization() return def _parse_config(self): @@ -30,6 +31,9 @@ def _parse_config(self): def _init_weight_names(self): pass + def _init_qweight_names(self): + pass + def _init_weight(self): pass @@ -45,7 +49,7 @@ def load_hf_weights(self, weights): elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) - def set_quantization(self): + def _set_quantization(self): if self.quant_cfg.quant_type is None: return mix_quant_list = self.quant_cfg.get_mixed_list(self.layer_num_) diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 43fbdb7a0..969016215 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -211,7 +211,10 @@ def grouped_matmul_kernel( expert_num, # int topk_num, # int token_scale_ptr, # [1,] - weight_scale_ptr, # [expert_num,] + weight_scale_ptr, # [expert_num,] or [export_num, n // block_size_n, k // block_size_k] + weight_scale_stride0, + weight_scale_stride1, + weight_scale_stride2, token_ptr, # [token_num, hidden_dim] token_stride_0, token_stride_1, @@ -232,6 +235,8 @@ def grouped_matmul_kernel( num_sm, # int compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, + block_size_n: tl.constexpr, + block_size_k: tl.constexpr, # tile sizes BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -282,14 +287,20 @@ def grouped_matmul_kernel( mask=offs_am < cur_m, other=0.0, ) - if use_fp8_w8a8: - a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last") - b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last") - ab_scale = a_scale * b_scale offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) + if use_fp8_w8a8: + if block_size_k > 0 and block_size_n > 0: + a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last") + offs_bsn = offs_bn // block_size_n + b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1 + else: + a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last") + b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last") + ab_scale = a_scale * b_scale + if use_fp8_w8a8: a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None] b_ptrs = ( @@ -303,7 +314,7 @@ def grouped_matmul_kernel( ) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for _ in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)): # hint to Triton compiler to do proper loop pipelining # tl.multiple_of(a_ptrs, [16, 16]) tl.multiple_of(b_ptrs, [16, 16]) @@ -316,7 +327,12 @@ def grouped_matmul_kernel( b = tl.load(b_ptrs, mask=(offs_bn[None, :] < n) & (offs_k[:, None] < k)) if use_fp8_w8a8: - accumulator = tl.dot(b, a, acc=accumulator) + if block_size_k > 0 and block_size_n > 0: + offs_ks = step_k * BLOCK_SIZE_K // block_size_k + b_scale = tl.load(b_scale_ptrs + offs_ks 
* weight_scale_stride2) + accumulator += tl.dot(b, a) * a_scale * b_scale[:, None] + else: + accumulator = tl.dot(b, a, acc=accumulator) else: accumulator += tl.dot(a, b) @@ -325,8 +341,11 @@ def grouped_matmul_kernel( offs_k += BLOCK_SIZE_K if use_fp8_w8a8: - accumulator = accumulator.T - accumulator *= ab_scale + if block_size_k > 0 and block_size_n > 0: + accumulator = accumulator.T + else: + accumulator = accumulator.T + accumulator *= ab_scale if MUL_ROUTED_WEIGHT: accumulator *= a_m_scale[:, None] @@ -363,7 +382,9 @@ def grouped_matmul( expert_to_token_num is tensor shape [expert_num], expert_to_token_index is tensor shape [expert_num, token_num * topk_num], expert_weights is tensor shape [expert_num, out_dim, hidden_dim] - expert_to_weights_scale is tensor shape [expert_num], when use_fp8_w8a8 is False, it must be None + expert_to_weights_scale is tensor shape [expert_num] or + [expert_num, out_dim // block_size_, hidden_dim // block_size_k], + when use_fp8_w8a8 is False, it must be None expert_token_limit use to limit handles token per expert. out is tensor shape [token_num * topk_num, out_dim] """ @@ -376,6 +397,14 @@ def grouped_matmul( assert expert_to_weights.is_contiguous() assert expert_weights.is_contiguous() + # for deepseek_v3 block-wise quant + block_size_n = 0 + block_size_k = 0 + if use_fp8_w8a8: + if expert_to_weights_scale.ndim == 3: + block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1] + block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2] + if not run_config: run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config( M=token_inputs.shape[0], @@ -405,6 +434,15 @@ def grouped_matmul( topk_num, token_input_scale, expert_to_weights_scale, + expert_to_weights_scale.stride(0) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2 + else 0, + expert_to_weights_scale.stride(1) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2 + else 0, + expert_to_weights_scale.stride(2) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3 + else 0, token_inputs, token_inputs.stride(0), token_inputs.stride(1), @@ -424,6 +462,8 @@ def grouped_matmul( num_sm=1, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + block_size_n=block_size_n, + block_size_k=block_size_k, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, @@ -462,6 +502,15 @@ def grouped_matmul( topk_num, token_input_scale, expert_to_weights_scale, + expert_to_weights_scale.stride(0) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2 + else 0, + expert_to_weights_scale.stride(1) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2 + else 0, + expert_to_weights_scale.stride(2) + if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3 + else 0, token_inputs, token_inputs.stride(0), token_inputs.stride(1), @@ -481,6 +530,8 @@ def grouped_matmul( num_sm=num_sm, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + block_size_n=block_size_n, + block_size_k=block_size_k, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py index 586a1b927..c1cc49732 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/fused_moe/topk_select.py @@ -19,6 +19,7 @@ import torch from lightllm.common.vllm_kernel import _custom_ops as ops +from typing import Callable, List, 
Optional, Tuple def fused_topk( @@ -53,15 +54,23 @@ def fused_topk( def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, + correction_bias: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + if scoring_func == "sigmoid": + scores = torch.sigmoid(gating_output) + else: + scores = torch.softmax(gating_output, dim=-1) + + if correction_bias is not None: + scores.add_(correction_bias) - scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] @@ -79,3 +88,43 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + correction_bias: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + custom_routing_function: Optional[Callable] = None, +): + from lightllm.common.fused_moe.topk_select import fused_topk, grouped_topk + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize + ) + + return topk_weights, topk_ids diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index e227f3dbb..6f714182e 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -4,20 +4,35 @@ from .ppl_quant import * from .torchao_quant import * from .vllm_quant import * +from .triton_quant.triton_quant import * class Quantcfg: - def __init__(self, layer_num, quant_type=None, cfg_path=None): - self.layer_num = layer_num + def __init__(self, network_config, quant_type=None, custom_cfg_path=None): + self.layer_num = network_config["n_layer"] self.quant_type = quant_type - self.parse_cfg(cfg_path) + self.network_config_ = network_config + self._parse_custom_cfg(custom_cfg_path) + self._parse_network_config(network_config) - def parse_cfg(self, cfg_path): + def _parse_network_config(self, network_config): + hf_quantization_config = network_config.get("quantization_config", None) + if hf_quantization_config is None: + self.quantized_weight = False + self.static_activation = False + self.hf_quantization_config = None + return + self.quantized_weight = True + activation_scheme = network_config.get("activation_scheme", "dynamic") + self.static_activation = activation_scheme == "static" + self.hf_quantization_config = hf_quantization_config + + def _parse_custom_cfg(self, 
custom_cfg_path): self.quant_cfg = collections.defaultdict(dict) - if cfg_path is None: + if custom_cfg_path is None: return - with open(cfg_path, "r") as file: + with open(custom_cfg_path, "r") as file: data = yaml.safe_load(file) self.quant_type = data["quant_type"] diff --git a/lightllm/common/quantization/ppl_quant.py b/lightllm/common/quantization/ppl_quant.py index 644c2174a..8d9eb5ab6 100644 --- a/lightllm/common/quantization/ppl_quant.py +++ b/lightllm/common/quantization/ppl_quant.py @@ -2,7 +2,6 @@ import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS -from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @QUANTMETHODS.register("ppl-w4a16-128") @@ -10,6 +9,9 @@ class PPLW4A16QuantizationMethod(QuantizationMethod): def __init__(self, group_size=128): super().__init__() self.group_size = group_size + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): """ @@ -31,7 +33,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ qweight is quant weight: (N//8, K) int32 (int4*8 packed with pack_order) return tensor: (M, N) float16 """ - qweight, scale_weight = weights + qweight, scale_weight = weights[:2] if workspace is None: workspace = torch.empty(size=[33554432 * 2], dtype=torch.int8, device="cuda") # 32MB workspace PPLW4A16QuantizationMethod.apply.__defaults__ = (None, None, workspace) @@ -40,7 +42,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + out = self.cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) else: out = torch.empty(shape, dtype, device=device) from lightllm_ppl_int4_kernel import matmul_i4_fp16 @@ -81,7 +83,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ """ """ from flash_llm_fp6_llm import linear_forward_cuda - qweight, scale = weights + qweight, scale = weights[:2] out = linear_forward_cuda(input_tensor, qweight, scale, 1) if self.bias: out.add_(bias) diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index f8610c5b6..3f9256dfe 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -19,12 +19,7 @@ def get(self, key, *args, **kwargs): quant_method_class = self._quant_methods.get(key) if not quant_method_class: raise ValueError(f"QuantMethod '{key}' not supported.") - tmp_key = key.split("-") - if len(tmp_key) == 2: - return quant_method_class() - else: - group_size = int(tmp_key[-1]) - return quant_method_class(group_size) + return quant_method_class() QUANTMETHODS = QuantMethodFactory() diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index b26e82352..67677b50c 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -2,7 +2,6 @@ import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS -from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager import torch.nn.functional as F try: @@ -45,17 +44,35 @@ def apply(self, input_tensor, weights, bias=None, out=None, use_custom_tensor_ma return 
F.linear(input_tensor, weights, bias) -@QUANTMETHODS.register([f"ao-w4a16-{group_size}" for group_size in [32, 64, 128, 256]]) -class AOW4A16QuantizationMethod(AOBaseQuantizationMethod): - def __init__(self, group_size=128): +@QUANTMETHODS.register(["ao-w4a16-256"]) +class AOW4A16QuantizationMethodGroup256(AOBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.group_size = 256 + self.quant_func = int4_weight_only(group_size=self.group_size) + + +@QUANTMETHODS.register(["ao-w4a16-128"]) +class AOW4A16QuantizationMethodGroup128(AOBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.group_size = 128 + self.quant_func = int4_weight_only(group_size=self.group_size) + + +@QUANTMETHODS.register(["ao-w4a16-64"]) +class AOW4A16QuantizationMethodGroup64(AOBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.group_size = 64 + self.quant_func = int4_weight_only(group_size=self.group_size) + + +@QUANTMETHODS.register(["ao-w4a16-32"]) +class AOW4A16QuantizationMethodGroup32(AOBaseQuantizationMethod): + def __init__(self): super().__init__() - assert group_size in [ - 32, - 64, - 128, - 256, - ], f"torchao int4-weightonly requires groupsize in [32,64,128,256], but gets {group_size}" - self.group_size = group_size + self.group_size = 32 self.quant_func = int4_weight_only(group_size=self.group_size) diff --git a/lightllm/common/quantization/triton_quant/__init__.py b/lightllm/common/quantization/triton_quant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/quantization/triton_quant/fp8/__init__.py b/lightllm/common/quantization/triton_quant/fp8/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py new file mode 100644 index 000000000..6edd40b63 --- /dev/null +++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py @@ -0,0 +1,116 @@ +import torch +import triton +import triton.language as tl + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple + + +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py +@triton.jit +def _per_token_group_quant_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + y_stride, + N, + eps, + fp8_min, + fp8_max, + BLOCK: tl.constexpr, +): + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + x_q: torch.Tensor, + x_s: torch.Tensor, + eps: float = 1e-10, + dtype: torch.dtype = torch.float8_e4m3fn, +): + """group-wise, per-token quantization on input tensor `x`. + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + x_q: the tensor to save the quantized result of x. + x_s: the tensor to save the scale of x. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. 
Note that only `torch.float8_e4m3fn` is supported for now. + """ + assert x.shape[-1] % group_size == 0, "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + fp8_min = -fp8_max + + M = x.numel() // group_size + N = group_size + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return + + +def torch_quant(x, group_size, dtype=torch.float8_e4m3fn): + M, N = x.shape + x_q = torch.randn((M, N)).cuda().to(torch.float8_e4m3fn) + x_s = torch.randn((M, N // group_size), dtype=torch.float32).cuda() + x = x.reshape(-1, group_size) + finfo = torch.finfo(dtype) + fp8_max = finfo.max + fp8_min = -fp8_max + + x_s = x.to(torch.float32).abs().max(-1)[0] / fp8_max + x_q = x.to(torch.float32) / x_s.reshape(-1, 1) + x_q = x_q.clamp(fp8_min, fp8_max).to(dtype) + return x_q.reshape(M, N), x_s + + +if __name__ == "__main__": + group_size = 128 + x = torch.randn((1024, 8192), dtype=torch.bfloat16).cuda() + + x_q = torch.randn((1024, 8192)).cuda().to(torch.float8_e4m3fn) + x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda() + + per_token_group_quant_fp8(x, group_size, x_q, x_s) + th_x_q, th_x_s = torch_quant(x, group_size) + print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max()) + print("th_x_q - x_q", torch.abs(th_x_q.to(torch.float32) - x_q.to(torch.float32)).max()) diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py new file mode 100644 index 000000000..06176ec34 --- /dev/null +++ b/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py @@ -0,0 +1,242 @@ +import torch +import triton +import triton.language as tl + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple +from triton import Config + + +class Fp8BlockMMKernelConfig(KernelConfigs): + kernel_name: str = "fp8_block_mm" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + M: int, + N: int, + K: int, + block_size: Tuple[int, int], + out_dtype: str, + ) -> dict: + key_params = { + "N": N, + "K": K, + "block_size": block_size, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + # find by M + config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] + return config + else: + config = { + "BLOCK_M": 64, + "BLOCK_N": block_size[0], + "BLOCK_K": block_size[1], + "GROUP_M": 32, + "num_warps": 4, + "num_stages": 3, + } + return config + + @classmethod + def save_config( + cls, N: int, K: int, block_size: Tuple[int, int], out_dtype: str, config_json: Dict[int, Dict[int, Dict]] + ): + + key_params = { + "N": N, + "K": K, + "block_size": block_size, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def grouped_launch(pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr): + + grid_m = tl.cdiv(m, block_m) + grid_n = tl.cdiv(n, 
block_n) + + width = group_m * grid_n + group_id = pid // width + group_size = tl.minimum(grid_m - group_id * group_m, group_m) + + pid_m = group_id * group_m + (pid % group_size) + pid_n = (pid % width) // group_size + + return pid_m, pid_n + + +@triton.jit +def _block_scaled_block_gemm( + A, + B, + C, + Ascale, + Bscale, + M, + N, + K, + group_n, + group_k, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_Ascale_m, + stride_Ascale_k, + stride_Bscale_k, + stride_Bscale_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + pid_m, pid_n = grouped_launch(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M) + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + Ascale_ptrs = Ascale + offs_am * stride_Ascale_m + offs_bsn = offs_bn // group_n + Bscale_ptrs = Bscale + offs_bsn * stride_Bscale_n + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + # tmp_a_s = tl.zeros((BLOCK_M,), dtype=tl.float32) + 1 + # tmp_b_s = tl.zeros((BLOCK_N,), dtype=tl.float32) + 1 + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + offs_ks = k * BLOCK_K // group_k + a_s = tl.load(Ascale_ptrs + offs_ks * stride_Ascale_k) + b_s = tl.load(Bscale_ptrs + offs_ks * stride_Bscale_k) + acc += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + acc = acc.to(C.dtype.element_ty) + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, acc, mask=mask) + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + Ascale: torch.Tensor, + Bscale: torch.Tensor, + C: torch.Tensor, + block_size: List[int], + dtype: torch.dtype = torch.bfloat16, + **run_config, +) -> torch.Tensor: + """w8a8fp8 block-wise quantization mm. + + Args: + A: Matrix A with shape of [M, K]. + B: Matrix B with shape of [K, N]. + Ascale: per-token block-wise Quantization scale for A: [M, K / block_size[0]]. + Bscale: Quantization scale for B: [K / block_size[0], M / block_size[1]]. + C: The output matrix with the shape of [M, N]. + block_size: block granularity of quantization (e.g., [128, 128]). + dtype: The data type of C. + Returns: + torch.Tensor: C. 
+ """ + assert len(block_size) == 2 + block_k, block_n = block_size[0], block_size[1] + assert A.shape[0] == Ascale.shape[0] and A.shape[-1] == B.shape[0] + assert A.is_contiguous() and C.is_contiguous() + M, K = A.shape + _, N = B.shape + assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0] + assert triton.cdiv(N, block_n) == Bscale.shape[1] + if not run_config: + run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(M, N, K, block_size, dtype) + grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) + _block_scaled_block_gemm[grid]( + A, + B, + C, + Ascale, + Bscale, + M, + N, + K, + block_n, + block_k, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + Ascale.stride(0), + Ascale.stride(1), + Bscale.stride(0), + Bscale.stride(1), + **run_config, + ) + + return C + + +if __name__ == "__main__": + import time + + block_size = 128 + output_dtype = torch.bfloat16 + M, N, K = 4096, 256, 7168 + A = torch.randn((M, K), dtype=output_dtype).cuda().to(torch.float8_e4m3fn) # Activation + B = torch.randn((K, N), dtype=output_dtype).cuda().to(torch.float8_e4m3fn) # Weight + Ascale = torch.randn((M, K // block_size)).cuda() # + 0.2 + Bscale = torch.ones((K // block_size, N // block_size)).cuda() + + C = torch.randn((M, N), dtype=output_dtype).cuda() # weight + B = B.T.contiguous().T + # warmup + + w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, (block_size, block_size), output_dtype) + + #### groud truth + print(Ascale.unsqueeze(-1).repeat(1, 1, block_size).reshape(M, K).to(output_dtype).shape) + d_A = A.to(output_dtype) * (Ascale.unsqueeze(-1).repeat(1, 1, block_size).reshape(M, K).to(output_dtype)) + d_B = B.to(output_dtype).contiguous() + + gt_C = d_A.mm(d_B) + # caluate the simlarity + import torch.nn.functional as F + + cosine_sim = F.cosine_similarity(C.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) + + print(f"Cosine Similarity between C and gt_C: {cosine_sim.item()}") + + fn2 = lambda: torch.mm(d_A, d_B, out=gt_C) + ms2 = triton.testing.do_bench(fn2) + print(f"bf16 time : {ms2} ms") + + fn2 = lambda: w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, (block_size, block_size), output_dtype) + ms2 = triton.testing.do_bench_cudagraph(fn2) + print(f"fp8 time : {ms2} ms") diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py new file mode 100644 index 000000000..d005a950c --- /dev/null +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -0,0 +1,70 @@ +import os +import torch +import torch.nn.functional as F +from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.common.quantization.registry import QUANTMETHODS +from .fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul +from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 + + +class TritonBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + + def quantize(self, weight: torch.Tensor): + """ """ + pass + + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None): + """ """ + pass + + +@QUANTMETHODS.register(["triton-fp8w8a8-block128"]) +class TritonFP8w8a8QuantizationMethod(TritonBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.is_moe = False + self.block_size = 128 + + def 
quantize(self, weight: torch.Tensor): + # TODO block-wise quant kernel + pass + + def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True): + qweight, weight_scale, input_scale = weights + m, k = input_tensor.shape + n = qweight.shape[1] + if input_scale is None: + input_scale = self.cache_manager.alloc_tensor( + (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False + ) + input_tensor_q = self.cache_manager.alloc_tensor( + (m, k), qweight.dtype, device=qweight.device, is_graph_out=False + ) + per_token_group_quant_fp8(input_tensor, self.block_size, input_tensor_q, input_scale) + else: + # TODO + raise NotImplementedError("static input scale is not supported by triton fp8 block gemm kernel.") + m = input_tensor.shape[0] + n = qweight.shape[1] + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor( + (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False + ) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + w8a8_block_fp8_matmul( + input_tensor_q, + qweight, + input_scale, + weight_scale, + out, + (self.block_size, self.block_size), + dtype=input_tensor.dtype, + ) + return out diff --git a/lightllm/common/quantization/vllm_quant.py b/lightllm/common/quantization/vllm_quant.py index bfd29fc3a..89456fe3d 100644 --- a/lightllm/common/quantization/vllm_quant.py +++ b/lightllm/common/quantization/vllm_quant.py @@ -2,7 +2,6 @@ import torch from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS -from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager import torch.nn.functional as F try: @@ -16,6 +15,9 @@ class vLLMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() assert HAS_VLLM, "vllm is not installed, you can't use quant api of it" + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager def quantize(self, weight: torch.Tensor): """ """ @@ -54,7 +56,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ n = qweight.shape[1] if out is None: if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor( + out = self.cache_manager.alloc_tensor( (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False ) else: @@ -122,7 +124,7 @@ def apply_scaled_mm_fp8( n = weights[0].shape[1] if out is None: if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor( + out = self.cache_manager.alloc_tensor( (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False ) else: @@ -139,7 +141,7 @@ def apply_pingpong_fp8( n = weights[0].shape[1] if out is None: if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor( + out = self.cache_manager.alloc_tensor( (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False ) else: diff --git a/lightllm/common/vllm_kernel/_custom_ops.py b/lightllm/common/vllm_kernel/_custom_ops.py index 7de42c3c6..fd3d5e229 100644 --- a/lightllm/common/vllm_kernel/_custom_ops.py +++ b/lightllm/common/vllm_kernel/_custom_ops.py @@ -12,41 +12,6 @@ try: from lightllm.common.vllm_kernel._ops import * - def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, 
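# Illustrative sketch (not part of the patch): the shape contract that
# TritonFP8w8a8QuantizationMethod.apply relies on when it calls w8a8_block_fp8_matmul.
# Per the kernel's assertions, Ascale is a per-token-group scale ([M, K/128]) and
# Bscale holds one scale per 128x128 weight block ([K/128, N/128]). The sizes are
# assumed, and a CUDA GPU with fp8 support plus Triton is required to actually run it.
import torch
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul

M, K, N, bs = 16, 7168, 1536, 128
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)          # quantized activation
B = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)          # quantized weight
Ascale = torch.rand(M, K // bs, device="cuda", dtype=torch.float32)   # per token-group scale
Bscale = torch.rand(K // bs, N // bs, device="cuda", dtype=torch.float32)
C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, [bs, bs], dtype=torch.bfloat16)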
- ): - from lightllm.common.fused_moe.topk_select import fused_topk, grouped_topk - - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) - elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize - ) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize - ) - - return topk_weights, topk_ids - except ImportError: logger.error("vllm or lightllm_kernel is not installed, you can't use custom ops") diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 87777463b..1c84f8bcf 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -171,9 +171,9 @@ def _context_attention_kernel_with_CC( ) compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank) wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank) - wv = layer_weight.v_b_proj_.weight.transpose(1, 2).view(-1, layer_weight.kv_lora_rank) + wv = layer_weight.v_b_proj_.weight.transpose(0, 1).reshape(layer_weight.kv_lora_rank, -1) torch.mm(compressed_kv, wk.transpose(0, 1), out=k_nope.reshape(compressed_kv.shape[0], -1)) - torch.mm(compressed_kv, wv.transpose(0, 1), out=v.reshape(compressed_kv.shape[0], -1)) + torch.mm(compressed_kv, wv, out=v.reshape(compressed_kv.shape[0], -1)) q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index a28277dbf..5b58f8b04 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -16,8 +16,6 @@ FusedMoeWeight, ROWBMMWeight, ROWBMMWeightNoTp, - COLBMMWeight, - COLBMMWeightNoTp, ) from functools import partial @@ -93,6 +91,15 @@ def _init_weight_names(self): self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" else: self.rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight" + self.e_score_correction_bias_name = f"model.layers.{self.layer_num_}.mlp.gate.e_score_correction_bias" + + def _init_qweight_names(self): + self.act_scale_suffix = None + self.weight_scale_suffix = None + if self.quant_cfg.static_activation: + self.act_scale_suffix = "input_scale" + if self.quant_cfg.quantized_weight: + self.weight_scale_suffix = "weight_scale_inv" def _init_weight(self): if not self.enable_dp: @@ -122,33 +129,63 @@ def _load_q_rope(self, q_weight_): return q_rope_proj_.reshape(-1, self.qk_rope_head_dim * self.tp_q_head_num_).transpose(0, 1).contiguous() def _load_kb(self, kv_b_proj_): - kv_b_proj_ = kv_b_proj_ k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[ :, : self.qk_nope_head_dim, : ] return k_b_proj_.contiguous().to(self.data_type_) + def 
_load_kb_scale(self, kv_b_proj_, block_size): + k_b_proj_scale_ = kv_b_proj_.view( + self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size + )[:, : self.qk_nope_head_dim // block_size, :] + return k_b_proj_scale_.contiguous().to(self.data_type_) + def _load_vb(self, kv_b_proj_): - kv_b_proj_ = kv_b_proj_ - v_b_proj_ = kv_b_proj_.T.view( - self.kv_lora_rank, - self.num_attention_heads, - self.qk_nope_head_dim * 2, - )[:, :, self.qk_nope_head_dim :] + v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[ + :, :, self.qk_nope_head_dim : + ].transpose(0, 1) return v_b_proj_.contiguous().to(self.data_type_) + def _load_vb_scale(self, kv_b_proj_scale_, block_size): + v_b_proj_scale_ = kv_b_proj_scale_.T.view( + self.kv_lora_rank // block_size, + self.num_attention_heads, + self.qk_nope_head_dim * 2 // block_size, + )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) + return v_b_proj_scale_.contiguous().to(self.data_type_) + def load_hf_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) - if self.rope_weight_name in weights: - rope_weight_ = weights[self.rope_weight_name] - weights[f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight"] = self._load_q_rope(rope_weight_) + if ( + self.quant_cfg.quantized_weight + and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights + ): + kv_b_proj_scale_ = weights[ + f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix + ] + block_size = 1 + if self.quant_cfg is not None: + hf_quantization_config = self.quant_cfg.hf_quantization_config + block_size = hf_quantization_config.get("weight_block_size", [128, 128])[0] + weights[ + f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + self.weight_scale_suffix + ] = self._load_kb_scale(kv_b_proj_scale_, block_size) + weights[ + f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix + ] = self._load_vb_scale(kv_b_proj_scale_, block_size) return super().load_hf_weights(weights) + def _set_quantization(self): + super()._set_quantization() + # moe_gate of deepseek always stays in bf16/fp16. 
+ if self.is_moe: + self.moe_gate.quant_method = None + def _init_qkvo(self): q_split_n_embed = self.qk_nope_head_dim * self.tp_q_head_num_ q_split_n_embed_with_rope = ( @@ -159,45 +196,53 @@ def _init_qkvo(self): f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", self.data_type_, q_split_n_embed_with_rope, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) else: self.q_a_proj_ = ROWMMWeightNoTP( f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", self.data_type_, self.q_lora_rank, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.q_b_proj_ = ROWMMWeight( f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", self.data_type_, q_split_n_embed_with_rope, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.q_rope_proj_ = ROWMMWeightNoTP( - f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight", - self.data_type_, - self.qk_rope_head_dim * self.tp_q_head_num_, - ) - self.kv_a_proj_with_mqa_ = ROWMMWeightNoTP( f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", self.data_type_, self.kv_lora_rank + self.qk_rope_head_dim, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.k_b_proj_ = ROWBMMWeight( f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", self.data_type_, split_n_embed=self.tp_q_head_num_, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.v_b_proj_ = COLBMMWeight( + self.v_b_proj_ = ROWBMMWeight( f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", self.data_type_, split_n_embed=self.tp_q_head_num_, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.o_weight_ = COLMMWeight( f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", self.data_type_, q_split_n_embed, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) def _init_qkvo_dp(self): @@ -208,65 +253,97 @@ def _init_qkvo_dp(self): f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", self.data_type_, q_split_n_embed_with_rope, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) else: self.q_a_proj_ = ROWMMWeightNoTP( f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", self.data_type_, self.q_lora_rank, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.q_b_proj_ = ROWMMWeightNoTP( f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", self.data_type_, q_split_n_embed_with_rope, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.q_rope_proj_ = ROWMMWeightNoTP( - f"model.layers.{self.layer_num_}.self_attn.q_rope_proj.weight", - self.data_type_, - self.qk_rope_head_dim * self.tp_q_head_num_, - ) - self.kv_a_proj_with_mqa_ = ROWMMWeightNoTP( f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", self.data_type_, self.kv_lora_rank + self.qk_rope_head_dim, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.k_b_proj_ = ROWBMMWeightNoTp( f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", self.data_type_, split_n_embed=self.tp_q_head_num_, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.v_b_proj_ = COLBMMWeightNoTp( + self.v_b_proj_ = ROWBMMWeightNoTp( 
f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", self.data_type_, split_n_embed=self.tp_q_head_num_, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) self.o_weight_ = COLMMWeightNoTp( f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", self.data_type_, q_split_n_embed, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) def _load_mlp(self, mlp_prefix, split_inter_size, no_tp=False): if no_tp: self.gate_up_proj = MultiROWMMWeightNoTP( - [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size + [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], + self.data_type_, + split_inter_size, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, + ) + self.down_proj = COLMMWeightNoTp( + f"{mlp_prefix}.down_proj.weight", + self.data_type_, + split_inter_size, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.down_proj = COLMMWeightNoTp(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size) else: self.gate_up_proj = MultiROWMMWeight( - [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], self.data_type_, split_inter_size + [f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], + self.data_type_, + split_inter_size, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, + ) + self.down_proj = COLMMWeight( + f"{mlp_prefix}.down_proj.weight", + self.data_type_, + split_inter_size, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) - self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size) def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeightNoTP( - f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size + f"model.layers.{self.layer_num_}.mlp.gate.weight", + self.data_type_, + moe_intermediate_size, + weight_scale_suffix=None, + act_scale_suffix=None, ) shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"] shared_split_inter_size = shared_intermediate_size // self.world_size_ @@ -276,10 +353,14 @@ def _init_moe(self): gate_proj_name="gate_proj", down_proj_name="down_proj", up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", n_routed_experts=self.n_routed_experts, split_inter_size=moe_intermediate_size // self.world_size_, data_type=self.data_type_, + network_config=self.network_config_, + weight_scale_suffix=self.weight_scale_suffix, + act_scale_suffix=self.act_scale_suffix, ) def _init_ffn(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index dc0afeff1..e7fe3728a 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -185,7 +185,7 @@ def init_model(self, kvargs): self.model = CohereTpPartModel(model_kvargs) elif self.model_type == "phi3": self.model = Phi3TpPartModel(model_kvargs) - elif self.model_type == "deepseek_v2": + elif self.model_type in ["deepseek_v2", "deepseek_v3"]: self.model = Deepseek2TpPartModel(model_kvargs) elif self.model_type == "internvl_chat": llm_model_type 
= model_cfg.get("llm_config").get("model_type") diff --git a/test/kernel/deepseekv3_fp8_block_gemm_tuning.py b/test/kernel/deepseekv3_fp8_block_gemm_tuning.py new file mode 100644 index 000000000..366ab8a18 --- /dev/null +++ b/test/kernel/deepseekv3_fp8_block_gemm_tuning.py @@ -0,0 +1,273 @@ +import torch +import time +import os +import torch.multiprocessing as mp +from typing import List +from lightllm.utils.log_utils import init_logger +from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul +from lightllm.utils.watchdog_utils import Watchdog + +logger = init_logger(__name__) + + +def set_seed(): + import torch + import random + import numpy as np + + seed = 42 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return + + +@torch.no_grad() +def test_fp8_block_gemm( + M: int, + N: int, + K: int, + block_size: int, + dtype: torch.dtype, + test_count: int = 20, + **run_config, +): + set_seed() + + input_tuples = [] + for _ in range(test_count): + A = torch.randn((M, K), dtype=torch.float32).cuda().to(torch.float8_e4m3fn) # Activation + B = torch.randn((K, N), dtype=torch.float32).cuda().to(torch.float8_e4m3fn) # Weight + Ascale = torch.ones((M, (K + block_size - 1) // block_size)).cuda() + Bscale = torch.ones(((K + block_size - 1) // block_size, (N + block_size - 1) // block_size)).cuda() + C = torch.randn((M, N), dtype=dtype).cuda() # Output buffer + input_tuples.append((A, B, Ascale, Bscale, C)) + w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, (block_size, block_size), dtype, **run_config) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for index in range(test_count): + A, B, Ascale, Bscale, C = input_tuples[index] + w8a8_block_fp8_matmul( + A, + B, + Ascale, + Bscale, + C, + (block_size, block_size), + **run_config, + ) + + graph.replay() + torch.cuda.synchronize() + start = time.time() + graph.replay() + torch.cuda.synchronize() + cost_time = (time.time() - start) * 1000 + logger.info(f"fp8 mm {M} {N} {K} block {block_size} cost time: {cost_time} ms") + return cost_time + + +def worker( + M: int, + N: int, + K: int, + block_size: int, + dtype: torch.dtype, + test_count: int, + test_configs, + queue, +): + dog = Watchdog(timeout=10) + dog.start() + + try: + for index in range(len(test_configs)): + tuning_config = test_configs[index] + cost_time = test_fp8_block_gemm( + M=M, + N=N, + K=K, + block_size=block_size, + dtype=dtype, + test_count=test_count, + **tuning_config, + ) + dog.heartbeat() + queue.put(cost_time) # Put result in queue + except Exception as ex: + logger.exception(str(ex) + f" config {tuning_config}") + import sys + + sys.exit(-1) + pass + + +def get_test_configs(split_id, split_count): + fp8_gemm_configs = [ + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 3, "num_warps": 8}, + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, + {"BLOCK_M": 32, 
"BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8, "num_stages": 5, "num_warps": 2}, + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 3, "num_warps": 8}, + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 3, "num_warps": 8}, + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8, "num_stages": 4, "num_warps": 4}, + ] + index = 0 + for cfg in fp8_gemm_configs: + if index % split_count == split_id: + yield cfg + index += 1 + else: + index += 1 + + +def tuning_configs( + device_id: int, # use for mult mp tunning + device_count: int, + M: int, + N: int, + K: int, + block_size: int, + dtype: torch.dtype, + test_count: int, +): + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) + best_config, best_cost_time = None, 10000000 + queue = mp.Queue() + test_configs = [] + for t_config in get_test_configs(device_id, device_count): + test_configs.append(t_config) + if len(test_configs) < 64: + continue + + p = mp.Process( + target=worker, + args=( + M, + N, + K, + block_size, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + while len(test_configs) != 0: + p = mp.Process( + target=worker, + args=( + M, + N, + K, + block_size, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + logger.info(f"{best_config} best cost: {best_cost_time}") + return best_config, best_cost_time + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + from lightllm.utils.tuning_utils import mp_tuning + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import Fp8BlockMMKernelConfig + import collections + + block_size = 128 + store_json_ans = collections.defaultdict(dict) + for N, K in [ + (256, 7168), + (512, 7168), + (576, 7168), + (1536, 1536), + (1536, 7168), + (2048, 512), + (2304, 7168), + (8072, 7168), + (4096, 512), + (7168, 256), + (7168, 1024), + (7168, 1152), + (7168, 2048), + (7168, 2304), + (7168, 16384), + (7168, 18432), + (24576, 7168), + (32768, 512), + (36864, 7168), + ]: + for M in [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096]: + ans = mp_tuning( + tuning_configs, + { + "M": M, + "N": N, + "K": K, + 
"block_size": block_size, + "dtype": torch.bfloat16, + "test_count": 4, + }, + ) + store_json_ans[M] = ans + + Fp8BlockMMKernelConfig.save_config( + N=N, + K=K, + block_size=[block_size, block_size], + out_dtype=torch.bfloat16, + config_json=store_json_ans, + ) + + pass