[FP6-LLM] Port splitK map from DeepSpeed (#283)
gau-nernst authored May 29, 2024
1 parent 42c2376 commit 6dd63b8
Showing 1 changed file with 145 additions and 2 deletions.
torchao/quantization/fp6_llm.py (145 additions, 2 deletions)
@@ -1,3 +1,4 @@
import math
from typing import Optional

import torch
@@ -111,6 +112,143 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor
return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)


# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
_SPLIT_K_MAP = [
{ # tokens: [1, 64]
3072: 18,
4096: 13,
5120: 10,
6144: 9,
8192: 6,
10240: 5,
14336: 7,
28672: 7,
57344: 7
},
{ # tokens: [65:128]
3072: 9,
4096: 6,
5120: 5,
6144: 9,
8192: 3,
10240: 5,
14336: 7,
28672: 7,
57344: 6
},
{ # tokens: [129:192]
3072: 6,
4096: 4,
5120: 7,
6144: 3,
8192: 2,
10240: 5,
14336: 5,
28672: 5,
57344: 4
},
{ # tokens: [193:256]
3072: 9,
4096: 3,
5120: 5,
6144: 2,
8192: 5,
10240: 4,
14336: 8,
28672: 6,
57344: 4
},
{ # tokens: [257:320]
3072: 7,
4096: 5,
5120: 2,
6144: 5,
8192: 4,
10240: 1,
14336: 3,
28672: 3,
57344: 4
},
{ # tokens: [321:384]
3072: 3,
4096: 2,
5120: 5,
6144: 3,
8192: 1,
10240: 8,
14336: 3,
28672: 4,
57344: 3
},
{ # tokens: [385:448]
3072: 5,
4096: 7,
5120: 3,
6144: 5,
8192: 7,
10240: 3,
14336: 1,
28672: 1,
57344: 3
},
{ # tokens: [449:512]
3072: 2,
4096: 5,
5120: 4,
6144: 1,
8192: 5,
10240: 2,
14336: 6,
28672: 4,
57344: 1
},
{ # tokens: [513:576]
3072: 2,
4096: 3,
5120: 1,
6144: 1,
8192: 3,
10240: 3,
14336: 3,
28672: 1,
57344: 1
},
{ # tokens: [577:640]
3072: 5,
4096: 4,
5120: 1,
6144: 4,
8192: 2,
10240: 1,
14336: 1,
28672: 1,
57344: 1
},
{ # tokens: [641:704]
3072: 3,
4096: 1,
5120: 2,
6144: 2,
8192: 1,
10240: 2,
14336: 1,
28672: 1,
57344: 1
},
{ # tokens: [705:768]
3072: 3,
4096: 1,
5120: 3,
6144: 2,
8192: 1,
10240: 1,
14336: 1,
28672: 1,
57344: 1
}
]


class Fp6LlmLinear(nn.Module):
"""FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
"""
@@ -124,12 +262,17 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
self.in_features = weight.shape[1] * 16 // 3

def forward(self, x: Tensor) -> Tensor:
- # TODO: splitK map
- out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
+ splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features)
+ out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK)
if self.bias is not None:
out = out + self.bias
return out.view(*x.shape[:-1], self.out_features).to(x.dtype)

@staticmethod
def get_split_k(bsize: int, out_dim: int) -> int:
# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1

@classmethod
def from_float(cls, linear: nn.Linear):
assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)
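For reference, the new get_split_k helper buckets the token count into 64-token ranges (index (tokens - 1) // 64) and looks up the GEMM output dimension in _SPLIT_K_MAP, falling back to splitK = 1 for unlisted output dimensions or for more than 768 tokens. A minimal sketch of the lookup, assuming the file above is importable as torchao.quantization.fp6_llm:

from torchao.quantization.fp6_llm import Fp6LlmLinear  # assumed import path

# bucket 0 covers 1-64 tokens, bucket 1 covers 65-128 tokens, and so on
assert Fp6LlmLinear.get_split_k(64, 4096) == 13    # bucket 0, out_dim 4096
assert Fp6LlmLinear.get_split_k(65, 4096) == 6     # bucket 1, out_dim 4096
assert Fp6LlmLinear.get_split_k(100, 2048) == 1    # out_dim not in the map -> default 1
assert Fp6LlmLinear.get_split_k(1000, 4096) == 1   # more than 768 tokens -> default 1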

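And a rough end-to-end usage sketch for the updated layer (hypothetical shapes; this assumes the FP6-LLM CUDA kernel backing fp16act_fp6weight_linear is built, a CUDA device is available, and that from_float/forward handle the device placement shown here):

import torch
from torch import nn
from torchao.quantization.fp6_llm import Fp6LlmLinear  # assumed import path

# from_float asserts in_features % 64 == 0 and out_features % 256 == 0
fp32_linear = nn.Linear(4096, 4096, device="cuda")
fp6_linear = Fp6LlmLinear.from_float(fp32_linear)

x = torch.randn(8, 4096, dtype=torch.half, device="cuda")  # 8 tokens
out = fp6_linear(x)  # forward() now picks splitK = get_split_k(8, 4096) == 13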