@@ -74,7 +74,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
74
74
}
75
75
76
76
// Fused 4/2-bit rowwise -> FP32/FP16 kernel
77
- template <typename output_t >
77
+ template <typename output_t , bool scale_bias_last >
78
78
__global__ inline void _fusednbitrowwise_to_float_cuda_kernel (
79
79
const int bit_rate,
80
80
const std::uint8_t * input,
@@ -83,7 +83,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
83
83
output_t * const output) {
84
84
const int num_elem_per_byte = 8 / bit_rate;
85
85
const int output_columns = (ncols - 2 * sizeof (__half)) * num_elem_per_byte;
86
-
87
86
int row = (int )blockIdx .y * blockDim .y + threadIdx .y ;
88
87
const int col = (int )blockIdx .x * blockDim .x + threadIdx .x ;
89
88
const int row_incre = blockDim .y * gridDim .y ;
@@ -92,9 +91,14 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
92
91
const std::uint8_t * input_row = input + row * ncols;
93
92
const __half* input_row_scale_bias = reinterpret_cast <const __half*>(
94
93
input_row +
95
- (output_columns + num_elem_per_byte - 1 ) / num_elem_per_byte);
94
+ (!scale_bias_last
95
+ ? 0
96
+ : (output_columns + num_elem_per_byte - 1 ) / num_elem_per_byte));
96
97
float scale = __half2float (input_row_scale_bias[0 ]);
97
98
float bias = __half2float (input_row_scale_bias[1 ]);
99
+ if constexpr (!scale_bias_last) {
100
+ input_row += 2 * sizeof (__half);
101
+ }
98
102
output_t * output_row = output + row * output_columns;
99
103
100
104
std::uint8_t quantized = input_row[col / num_elem_per_byte];
@@ -215,7 +219,8 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
215
219
template <typename output_t >
216
220
Tensor _fusednbitrowwise_to_float_gpu_t (
217
221
const Tensor& input,
218
- const int64_t bit_rate) {
222
+ const int64_t bit_rate,
223
+ const bool scale_bias_last) {
219
224
TENSOR_ON_CUDA_GPU (input);
220
225
TENSOR_NDIM_EQUALS (input, 2 );
221
226
CUDA_DEVICE_GUARD (input);
@@ -245,7 +250,9 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
245
250
{nrows, output_columns}, // 2 = sizeof(bfloat16)
246
251
input.options ().dtype (at::kBFloat16 ));
247
252
} else {
248
- TORCH_CHECK (false , " Unsupported output dtype" );
253
+ TORCH_CHECK (
254
+ false ,
255
+ " Unsupported output dtype within _fusednbitrowwise_to_float_gpu_t" );
249
256
}
250
257
251
258
if (nrows == 0 || output_columns == 0 ) {
@@ -260,18 +267,25 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
260
267
const auto gridDim_y = cuda_calc_block_count (nrows, blockDim .y );
261
268
const dim3 gridDim (gridDim_x, gridDim_y);
262
269
270
+ #define DEQUANT_LAUNCH_NBIT (scale_bias_last ) \
271
+ _fusednbitrowwise_to_float_cuda_kernel<scalar_t , scale_bias_last> \
272
+ <<<gridDim , blockDim , 0 , at::cuda::getCurrentCUDAStream()>>> ( \
273
+ bit_rate, \
274
+ input.data_ptr <std::uint8_t >(), \
275
+ nrows, \
276
+ ncols, \
277
+ output.data_ptr <scalar_t >())
278
+
263
279
FBGEMM_DISPATCH_FLOATING_TYPES (
264
280
output.scalar_type (), " fusednbitrowwise_to_float_cuda_kernel" , [&] {
265
- _fusednbitrowwise_to_float_cuda_kernel<scalar_t >
266
- <<<gridDim , blockDim , 0 , at::cuda::getCurrentCUDAStream()>>> (
267
- bit_rate,
268
- input.data_ptr <uint8_t >(),
269
- nrows,
270
- ncols,
271
- output.data_ptr <scalar_t >());
281
+ if (scale_bias_last) {
282
+ DEQUANT_LAUNCH_NBIT (true );
283
+ } else {
284
+ DEQUANT_LAUNCH_NBIT (false );
285
+ }
272
286
C10_CUDA_KERNEL_LAUNCH_CHECK ();
273
287
});
274
-
288
+ # undef DEQUANT_LAUNCH_NBIT
275
289
return output;
276
290
}
277
291
@@ -286,7 +300,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
286
300
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu (
287
301
const at::Tensor& input,
288
302
const int64_t bit_rate) {
289
- return _fusednbitrowwise_to_float_gpu_t <float >(input, bit_rate);
303
+ return _fusednbitrowwise_to_float_gpu_t <float >(
304
+ input, bit_rate, true /* scale_bias_last */ );
290
305
}
291
306
292
307
// / @ingroup quantize-ops-cuda
@@ -301,7 +316,8 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
301
316
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu (
302
317
const at::Tensor& input,
303
318
const int64_t bit_rate) {
304
- return _fusednbitrowwise_to_float_gpu_t <at::Half>(input, bit_rate);
319
+ return _fusednbitrowwise_to_float_gpu_t <at::Half>(
320
+ input, bit_rate, true /* scale_bias_last */ );
305
321
}
306
322
307
323
// / @ingroup quantize-ops-cuda
@@ -321,19 +337,23 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
321
337
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu (
322
338
const at::Tensor& input,
323
339
const int64_t bit_rate,
324
- const int64_t output_dtype) {
340
+ const int64_t output_dtype,
341
+ const bool scale_bias_last) {
325
342
Tensor output;
326
343
327
344
SparseType output_sparse_dtype = static_cast <SparseType>(output_dtype);
328
345
switch (output_sparse_dtype) {
329
346
case SparseType::FP32:
330
- output = _fusednbitrowwise_to_float_gpu_t <float >(input, bit_rate);
347
+ output = _fusednbitrowwise_to_float_gpu_t <float >(
348
+ input, bit_rate, scale_bias_last);
331
349
break ;
332
350
case SparseType::FP16:
333
- output = _fusednbitrowwise_to_float_gpu_t <at::Half>(input, bit_rate);
351
+ output = _fusednbitrowwise_to_float_gpu_t <at::Half>(
352
+ input, bit_rate, scale_bias_last);
334
353
break ;
335
354
case SparseType::BF16:
336
- output = _fusednbitrowwise_to_float_gpu_t <at::BFloat16>(input, bit_rate);
355
+ output = _fusednbitrowwise_to_float_gpu_t <at::BFloat16>(
356
+ input, bit_rate, scale_bias_last);
337
357
break ;
338
358
default :
339
359
TORCH_CHECK (false );
0 commit comments