1
+ import math
1
2
from typing import Optional
2
3
3
4
import torch
@@ -111,6 +112,143 @@ def from_tc_float6_e3m2(tensor: Tensor, M: int, N: int, dtype: torch.dtype = tor
111
112
return from_float6_e3m2 (tensor_fp6 , no_bit_packing = True , dtype = dtype )
112
113
113
114
115
+ # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
116
+ _SPLIT_K_MAP = [
117
+ { # tokens: [1, 64]
118
+ 3072 : 18 ,
119
+ 4096 : 13 ,
120
+ 5120 : 10 ,
121
+ 6144 : 9 ,
122
+ 8192 : 6 ,
123
+ 10240 : 5 ,
124
+ 14336 : 7 ,
125
+ 28672 : 7 ,
126
+ 57344 : 7
127
+ },
128
+ { # tokens: [65:128]
129
+ 3072 : 9 ,
130
+ 4096 : 6 ,
131
+ 5120 : 5 ,
132
+ 6144 : 9 ,
133
+ 8192 : 3 ,
134
+ 10240 : 5 ,
135
+ 14336 : 7 ,
136
+ 28672 : 7 ,
137
+ 57344 : 6
138
+ },
139
+ { # tokens: [129:192]
140
+ 3072 : 6 ,
141
+ 4096 : 4 ,
142
+ 5120 : 7 ,
143
+ 6144 : 3 ,
144
+ 8192 : 2 ,
145
+ 10240 : 5 ,
146
+ 14336 : 5 ,
147
+ 28672 : 5 ,
148
+ 57344 : 4
149
+ },
150
+ { # tokens: [193:256]
151
+ 3072 : 9 ,
152
+ 4096 : 3 ,
153
+ 5120 : 5 ,
154
+ 6144 : 2 ,
155
+ 8192 : 5 ,
156
+ 10240 : 4 ,
157
+ 14336 : 8 ,
158
+ 28672 : 6 ,
159
+ 57344 : 4
160
+ },
161
+ { # tokens: [257:320]
162
+ 3072 : 7 ,
163
+ 4096 : 5 ,
164
+ 5120 : 2 ,
165
+ 6144 : 5 ,
166
+ 8192 : 4 ,
167
+ 10240 : 1 ,
168
+ 14336 : 3 ,
169
+ 28672 : 3 ,
170
+ 57344 : 4
171
+ },
172
+ { # tokens: [321:384]
173
+ 3072 : 3 ,
174
+ 4096 : 2 ,
175
+ 5120 : 5 ,
176
+ 6144 : 3 ,
177
+ 8192 : 1 ,
178
+ 10240 : 8 ,
179
+ 14336 : 3 ,
180
+ 28672 : 4 ,
181
+ 57344 : 3
182
+ },
183
+ { # tokens: [385:448]
184
+ 3072 : 5 ,
185
+ 4096 : 7 ,
186
+ 5120 : 3 ,
187
+ 6144 : 5 ,
188
+ 8192 : 7 ,
189
+ 10240 : 3 ,
190
+ 14336 : 1 ,
191
+ 28672 : 1 ,
192
+ 57344 : 3
193
+ },
194
+ { # tokens: [449:512]
195
+ 3072 : 2 ,
196
+ 4096 : 5 ,
197
+ 5120 : 4 ,
198
+ 6144 : 1 ,
199
+ 8192 : 5 ,
200
+ 10240 : 2 ,
201
+ 14336 : 6 ,
202
+ 28672 : 4 ,
203
+ 57344 : 1
204
+ },
205
+ { # tokens: [513:576]
206
+ 3072 : 2 ,
207
+ 4096 : 3 ,
208
+ 5120 : 1 ,
209
+ 6144 : 1 ,
210
+ 8192 : 3 ,
211
+ 10240 : 3 ,
212
+ 14336 : 3 ,
213
+ 28672 : 1 ,
214
+ 57344 : 1
215
+ },
216
+ { # tokens: [577:640]
217
+ 3072 : 5 ,
218
+ 4096 : 4 ,
219
+ 5120 : 1 ,
220
+ 6144 : 4 ,
221
+ 8192 : 2 ,
222
+ 10240 : 1 ,
223
+ 14336 : 1 ,
224
+ 28672 : 1 ,
225
+ 57344 : 1
226
+ },
227
+ { # tokens: [641:704]
228
+ 3072 : 3 ,
229
+ 4096 : 1 ,
230
+ 5120 : 2 ,
231
+ 6144 : 2 ,
232
+ 8192 : 1 ,
233
+ 10240 : 2 ,
234
+ 14336 : 1 ,
235
+ 28672 : 1 ,
236
+ 57344 : 1
237
+ },
238
+ { # tokens: [705:768]
239
+ 3072 : 3 ,
240
+ 4096 : 1 ,
241
+ 5120 : 3 ,
242
+ 6144 : 2 ,
243
+ 8192 : 1 ,
244
+ 10240 : 1 ,
245
+ 14336 : 1 ,
246
+ 28672 : 1 ,
247
+ 57344 : 1
248
+ }
249
+ ]
250
+
251
+
114
252
class Fp6LlmLinear (nn .Module ):
115
253
"""FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112.
116
254
"""
@@ -124,12 +262,17 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
124
262
self .in_features = weight .shape [1 ] * 16 // 3
125
263
126
264
def forward (self , x : Tensor ) -> Tensor :
127
- # TODO: splitK map
128
- out = fp16act_fp6weight_linear (x .view (- 1 , self .in_features ).half (), self .weight , self .scales , splitK = 1 )
265
+ splitK = self . get_split_k ( math . prod ( x . shape [: - 1 ]), self . out_features )
266
+ out = fp16act_fp6weight_linear (x .view (- 1 , self .in_features ).half (), self .weight , self .scales , splitK = splitK )
129
267
if self .bias is not None :
130
268
out = out + self .bias
131
269
return out .view (* x .shape [:- 1 ], self .out_features ).to (x .dtype )
132
270
271
+ @staticmethod
272
+ def get_split_k (bsize : int , out_dim : int ) -> int :
273
+ # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
274
+ return _SPLIT_K_MAP [(bsize - 1 ) // 64 ].get (out_dim , 1 ) if bsize <= 768 else 1
275
+
133
276
@classmethod
134
277
def from_float (cls , linear : nn .Linear ):
135
278
assert (linear .in_features % 64 == 0 ) and (linear .out_features % 256 == 0 )
0 commit comments