From 6dd63b89b9b9988845abe8cd372f6a3615aee68e Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Wed, 29 May 2024 10:45:36 +0800
Subject: [PATCH] [FP6-LLM] Port splitK map from DeepSpeed (#283)

---
# Reviewer notes. Text between the "---" marker and the first "diff --git"
# line is free-form and is not part of the commit message.
#
# * Adds a module-level table _SPLIT_K_MAP: a list of 12 dicts. Entry i is
#   selected for token counts 64*i + 1 .. 64*(i + 1) and maps an output
#   dimension to a splitK value. Per the in-code comment and the subject
#   line, the table is ported from DeepSpeed's cuda_linear.py.
# * Fp6LlmLinear.forward no longer hard-codes splitK=1. It computes the
#   token count as math.prod(x.shape[:-1]) and passes the result of the new
#   static method get_split_k(bsize, out_dim) to fp16act_fp6weight_linear.
# * get_split_k returns 1 when bsize > 768, or when out_dim is not a key of
#   the selected dict.
# * NOTE(review): for bsize == 0 the index (bsize - 1) // 64 evaluates to
#   -1, so the lookup reads the last table entry (tokens 705-768) instead of
#   falling back to 1. Whether an empty batch can reach forward() is not
#   visible in this patch -- confirm against callers.
 torchao/quantization/fp6_llm.py | 147 +++++++++++++++++++++++++++++++-
 1 file changed, 145 insertions(+), 2 deletions(-)

diff --git a/torchao/quantization/fp6_llm.py b/torchao/quantization/fp6_llm.py
index 9f559d4164..0fb0f7dd98 100644
--- a/torchao/quantization/fp6_llm.py
+++ b/torchao/quantization/fp6_llm.py
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -111,6 +112,143 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor
     return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)
 
 
+# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
+_SPLIT_K_MAP = [
+    {  # tokens: [1, 64]
+        3072: 18,
+        4096: 13,
+        5120: 10,
+        6144: 9,
+        8192: 6,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 7
+    },
+    {  # tokens: [65:128]
+        3072: 9,
+        4096: 6,
+        5120: 5,
+        6144: 9,
+        8192: 3,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 6
+    },
+    {  # tokens: [129:192]
+        3072: 6,
+        4096: 4,
+        5120: 7,
+        6144: 3,
+        8192: 2,
+        10240: 5,
+        14336: 5,
+        28672: 5,
+        57344: 4
+    },
+    {  # tokens: [193:256]
+        3072: 9,
+        4096: 3,
+        5120: 5,
+        6144: 2,
+        8192: 5,
+        10240: 4,
+        14336: 8,
+        28672: 6,
+        57344: 4
+    },
+    {  # tokens: [257:320]
+        3072: 7,
+        4096: 5,
+        5120: 2,
+        6144: 5,
+        8192: 4,
+        10240: 1,
+        14336: 3,
+        28672: 3,
+        57344: 4
+    },
+    {  # tokens: [321:384]
+        3072: 3,
+        4096: 2,
+        5120: 5,
+        6144: 3,
+        8192: 1,
+        10240: 8,
+        14336: 3,
+        28672: 4,
+        57344: 3
+    },
+    {  # tokens: [385:448]
+        3072: 5,
+        4096: 7,
+        5120: 3,
+        6144: 5,
+        8192: 7,
+        10240: 3,
+        14336: 1,
+        28672: 1,
+        57344: 3
+    },
+    {  # tokens: [449:512]
+        3072: 2,
+        4096: 5,
+        5120: 4,
+        6144: 1,
+        8192: 5,
+        10240: 2,
+        14336: 6,
+        28672: 4,
+        57344: 1
+    },
+    {  # tokens: [513:576]
+        3072: 2,
+        4096: 3,
+        5120: 1,
+        6144: 1,
+        8192: 3,
+        10240: 3,
+        14336: 3,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [577:640]
+        3072: 5,
+        4096: 4,
+        5120: 1,
+        6144: 4,
+        8192: 2,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [641:704]
+        3072: 3,
+        4096: 1,
+        5120: 2,
+        6144: 2,
+        8192: 1,
+        10240: 2,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [705:768]
+        3072: 3,
+        4096: 1,
+        5120: 3,
+        6144: 2,
+        8192: 1,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    }
+]
+
+
 class Fp6LlmLinear(nn.Module):
     """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
     """
@@ -124,12 +262,17 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
         self.in_features = weight.shape[1] * 16 // 3
 
     def forward(self, x: Tensor) -> Tensor:
-        # TODO: splitK map
-        out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
+        splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features)
+        out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
         if self.bias is not None:
             out = out + self.bias
         return out.view(*x.shape[:-1], self.out_features).to(x.dtype)
 
+    @staticmethod
+    def get_split_k(bsize: int, out_dim: int) -> int:
+        # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
+        return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1
+
     @classmethod
     def from_float(cls, linear: nn.Linear):
         assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)