Skip to content

Commit 8918a13

Browse files
authored
[FP6-LLM] Port splitK map from DeepSpeed (pytorch#283)
1 parent 635f890 commit 8918a13

File tree

1 file changed

+145
-2
lines changed

1 file changed

+145
-2
lines changed

torchao/quantization/fp6_llm.py

+145-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from typing import Optional
23

34
import torch
@@ -111,6 +112,143 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor
111112
return from_float6_e3m2(tensor_fp6, no_bit_packing=True, dtype=dtype)
112113

113114

115+
# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
116+
_SPLIT_K_MAP = [
117+
{ # tokens: [1, 64]
118+
3072: 18,
119+
4096: 13,
120+
5120: 10,
121+
6144: 9,
122+
8192: 6,
123+
10240: 5,
124+
14336: 7,
125+
28672: 7,
126+
57344: 7
127+
},
128+
{ # tokens: [65:128]
129+
3072: 9,
130+
4096: 6,
131+
5120: 5,
132+
6144: 9,
133+
8192: 3,
134+
10240: 5,
135+
14336: 7,
136+
28672: 7,
137+
57344: 6
138+
},
139+
{ # tokens: [129:192]
140+
3072: 6,
141+
4096: 4,
142+
5120: 7,
143+
6144: 3,
144+
8192: 2,
145+
10240: 5,
146+
14336: 5,
147+
28672: 5,
148+
57344: 4
149+
},
150+
{ # tokens: [193:256]
151+
3072: 9,
152+
4096: 3,
153+
5120: 5,
154+
6144: 2,
155+
8192: 5,
156+
10240: 4,
157+
14336: 8,
158+
28672: 6,
159+
57344: 4
160+
},
161+
{ # tokens: [257:320]
162+
3072: 7,
163+
4096: 5,
164+
5120: 2,
165+
6144: 5,
166+
8192: 4,
167+
10240: 1,
168+
14336: 3,
169+
28672: 3,
170+
57344: 4
171+
},
172+
{ # tokens: [321:384]
173+
3072: 3,
174+
4096: 2,
175+
5120: 5,
176+
6144: 3,
177+
8192: 1,
178+
10240: 8,
179+
14336: 3,
180+
28672: 4,
181+
57344: 3
182+
},
183+
{ # tokens: [385:448]
184+
3072: 5,
185+
4096: 7,
186+
5120: 3,
187+
6144: 5,
188+
8192: 7,
189+
10240: 3,
190+
14336: 1,
191+
28672: 1,
192+
57344: 3
193+
},
194+
{ # tokens: [449:512]
195+
3072: 2,
196+
4096: 5,
197+
5120: 4,
198+
6144: 1,
199+
8192: 5,
200+
10240: 2,
201+
14336: 6,
202+
28672: 4,
203+
57344: 1
204+
},
205+
{ # tokens: [513:576]
206+
3072: 2,
207+
4096: 3,
208+
5120: 1,
209+
6144: 1,
210+
8192: 3,
211+
10240: 3,
212+
14336: 3,
213+
28672: 1,
214+
57344: 1
215+
},
216+
{ # tokens: [577:640]
217+
3072: 5,
218+
4096: 4,
219+
5120: 1,
220+
6144: 4,
221+
8192: 2,
222+
10240: 1,
223+
14336: 1,
224+
28672: 1,
225+
57344: 1
226+
},
227+
{ # tokens: [641:704]
228+
3072: 3,
229+
4096: 1,
230+
5120: 2,
231+
6144: 2,
232+
8192: 1,
233+
10240: 2,
234+
14336: 1,
235+
28672: 1,
236+
57344: 1
237+
},
238+
{ # tokens: [705:768]
239+
3072: 3,
240+
4096: 1,
241+
5120: 3,
242+
6144: 2,
243+
8192: 1,
244+
10240: 1,
245+
14336: 1,
246+
28672: 1,
247+
57344: 1
248+
}
249+
]
250+
251+
114252
class Fp6LlmLinear(nn.Module):
115253
"""FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
116254
"""
@@ -124,12 +262,17 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
124262
self.in_features = weight.shape[1] * 16 // 3
125263

126264
def forward(self, x: Tensor) -> Tensor:
    """Apply the FP6-LLM linear transform to *x*.

    The input is flattened to 2D and cast to fp16 for the kernel, then the
    result is reshaped back to the input's leading dimensions and cast to
    the input's original dtype.
    """
    batch_shape = x.shape[:-1]
    # Flatten all leading dims into one "token" dim; the kernel wants fp16.
    act = x.view(-1, self.in_features).half()
    # Pick the empirically tuned splitK for this (token count, out_dim) pair.
    split_k = self.get_split_k(math.prod(batch_shape), self.out_features)
    out = fp16act_fp6weight_linear(act, self.weight, self.scales, splitK=split_k)
    if self.bias is not None:
        out = out + self.bias
    return out.view(*batch_shape, self.out_features).to(x.dtype)
132270

271+
@staticmethod
def get_split_k(bsize: int, out_dim: int) -> int:
    """Return the splitK factor for the FP6 GEMM kernel.

    Heuristic ported from DeepSpeed:
    https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py

    Args:
        bsize: number of input rows (tokens) fed to the kernel.
        out_dim: output feature dimension of the layer.

    Returns:
        The tuned splitK for (bsize, out_dim) when tabulated, else 1.
    """
    # Guard non-positive sizes: for bsize == 0 the original expression
    # evaluates (0 - 1) // 64 == -1 and silently indexes the LAST bucket
    # (the 705-768 token table). Untabulated sizes above 768 also get 1.
    if bsize <= 0 or bsize > 768:
        return 1
    # Buckets are 64 tokens wide: bucket i covers token counts (64*i, 64*(i+1)].
    return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1)
275+
133276
@classmethod
134277
def from_float(cls, linear: nn.Linear):
135278
assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0)

0 commit comments

Comments (0)