
Commit 867b7f7

faran928 authored and facebook-github-bot committed
Support INT4 Dequant onto GPU for Seq INT TBE look up (pytorch#3584)
Summary: Seq INT4 -> INT4 STBE lookup is supported in the diff stack: https://www.internalfb.com/diff/D61305978. This diff adds:

1. Dequantization of the INT4 -> INT4 STBE lookup output on CUDA, for all float output types.
2. Dequantization of the INT4 -> INT4 STBE lookup output on CPU, extended to BF16.

The main gap on the GPU side is handling dequantization when the scale and bias of the INT4 quantized tensor are stored at the front of each row; on the CPU side we only need to add BF16 dequantization based on the output dtype. This lets us reduce the network overhead to the remote embedding server as well as the D2H data transfer on the GPU host.

Differential Revision: D68187234
1 parent 3e0db25 commit 867b7f7
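For orientation, a minimal usage sketch (not part of the commit) of the updated dequantization path: it round-trips a float matrix through the fused int4 rowwise format and then dequantizes it on GPU with the new scale_bias_last flag. The tensor shape and the manual byte shuffle that fabricates the scale/bias-front layout are illustrative assumptions; output_dtype=0 selecting FP32 matches the updated test below.

import torch
import fbgemm_gpu  # noqa: F401  (loads the torch.ops.fbgemm operators)

x = torch.randn(8, 16, dtype=torch.float32)

# Fused int4 rowwise format: per row, 8 packed data bytes followed by an
# fp16 scale and an fp16 bias (scale/bias LAST).
q_last = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(x, bit_rate=4)

# Mimic the scale/bias-FRONT layout produced by the sequential INT4 TBE
# lookup this commit targets by moving the trailing 4 bytes to the front.
q_front = torch.cat([q_last[:, -4:], q_last[:, :-4]], dim=1)

if torch.cuda.is_available():
    # Existing behavior: scale/bias at the end of each row (the default).
    y_last = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
        q_last.cuda(), bit_rate=4, output_dtype=0  # 0 = FP32
    )
    # New in this commit: dequantize the scale/bias-front layout on GPU.
    y_front = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
        q_front.cuda(), bit_rate=4, output_dtype=0, scale_bias_last=False
    )
    torch.testing.assert_close(y_last.cpu(), y_front.cpu())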

9 files changed: +186 -55 lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h (+2 -1)

@@ -432,7 +432,8 @@ at::Tensor fusednbitrowwise_to_half_cpu(
 at::Tensor fusednbitrowwise_to_float_or_half_cpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype);
+    const int64_t output_dtype,
+    const bool scale_bias_last);

 at::Tensor quantize_mx_cuda(
     const at::Tensor& input,

fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu (+39 -19)

@@ -74,7 +74,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
 }

 // Fused 4/2-bit rowwise -> FP32/FP16 kernel
-template <typename output_t>
+template <typename output_t, bool scale_bias_last>
 __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const int bit_rate,
     const std::uint8_t* input,
@@ -83,7 +83,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     output_t* const output) {
   const int num_elem_per_byte = 8 / bit_rate;
   const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;
-
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const int row_incre = blockDim.y * gridDim.y;
@@ -92,9 +91,14 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
     const std::uint8_t* input_row = input + row * ncols;
     const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
         input_row +
-        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
+        (!scale_bias_last
+             ? 0
+             : (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
     float scale = __half2float(input_row_scale_bias[0]);
     float bias = __half2float(input_row_scale_bias[1]);
+    if constexpr (!scale_bias_last) {
+      input_row += 2 * sizeof(__half);
+    }
     output_t* output_row = output + row * output_columns;

     std::uint8_t quantized = input_row[col / num_elem_per_byte];
@@ -215,7 +219,8 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
 template <typename output_t>
 Tensor _fusednbitrowwise_to_float_gpu_t(
     const Tensor& input,
-    const int64_t bit_rate) {
+    const int64_t bit_rate,
+    const bool scale_bias_last) {
   TENSOR_ON_CUDA_GPU(input);
   TENSOR_NDIM_EQUALS(input, 2);
   CUDA_DEVICE_GUARD(input);
@@ -245,7 +250,9 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
         {nrows, output_columns}, // 2 = sizeof(bfloat16)
         input.options().dtype(at::kBFloat16));
   } else {
-    TORCH_CHECK(false, "Unsupported output dtype");
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype within _fusednbitrowwise_to_float_gpu_t");
   }

   if (nrows == 0 || output_columns == 0) {
@@ -260,18 +267,25 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);

+#define DEQUANT_LAUNCH_NBIT(scale_bias_last)                          \
+  _fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last>   \
+      <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(   \
+          bit_rate,                                                   \
+          input.data_ptr<std::uint8_t>(),                             \
+          nrows,                                                      \
+          ncols,                                                      \
+          output.data_ptr<scalar_t>())
+
   FBGEMM_DISPATCH_FLOATING_TYPES(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
-        _fusednbitrowwise_to_float_cuda_kernel<scalar_t>
-            <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
-                bit_rate,
-                input.data_ptr<uint8_t>(),
-                nrows,
-                ncols,
-                output.data_ptr<scalar_t>());
+        if (scale_bias_last) {
+          DEQUANT_LAUNCH_NBIT(true);
+        } else {
+          DEQUANT_LAUNCH_NBIT(false);
+        }
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
-
+#undef DEQUANT_LAUNCH_NBIT
   return output;
 }

@@ -286,7 +300,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<float>(
+      input, bit_rate, true /* scale_bias_last */);
 }

 /// @ingroup quantize-ops-cuda
@@ -301,7 +316,8 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
     const at::Tensor& input,
     const int64_t bit_rate) {
-  return _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+  return _fusednbitrowwise_to_float_gpu_t<at::Half>(
+      input, bit_rate, true /* scale_bias_last */);
 }

 /// @ingroup quantize-ops-cuda
@@ -321,19 +337,23 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
 DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
     const at::Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    const bool scale_bias_last) {
   Tensor output;

   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
   switch (output_sparse_dtype) {
     case SparseType::FP32:
-      output = _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<float>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::FP16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::Half>(
+          input, bit_rate, scale_bias_last);
       break;
     case SparseType::BF16:
-      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(
+          input, bit_rate, scale_bias_last);
       break;
     default:
       TORCH_CHECK(false);
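The kernel change above comes down to where the two fp16 scale/bias values live within each fused row and where the packed data starts. A small NumPy sketch (an illustration under the same layout assumptions, not FBGEMM code) of the int4 case the kernel handles for both layouts:

import numpy as np

def dequant_int4_row(row: np.ndarray, scale_bias_last: bool = True) -> np.ndarray:
    # `row` is one row of the fused tensor as a 1-D uint8 array: `row.size - 4`
    # packed int4 data bytes plus 4 bytes holding an fp16 scale and an fp16
    # bias, either at the end (default) or at the front (the new case).
    data_bytes = row.size - 4
    sb_off = data_bytes if scale_bias_last else 0
    scale, bias = row[sb_off:sb_off + 4].view(np.float16).astype(np.float32)
    data = row[:data_bytes] if scale_bias_last else row[4:]
    out = np.empty(data_bytes * 2, dtype=np.float32)  # 2 int4 values per byte
    out[0::2] = (data & 0x0F).astype(np.float32)      # even columns: low nibble
    out[1::2] = (data >> 4).astype(np.float32)        # odd columns: high nibble
    return out * scale + bias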

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp (+54 -13)

@@ -150,7 +150,9 @@ Tensor _fusednbitrowwise_to_float_cpu(
   return output;
 }

-Tensor _fusednbitrowwise_sbfront_to_float_cpu(
+// Both float16 and bfloat16 are of same type uint16_t
+template <typename output_t>
+Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate) {
   TENSOR_ON_CPU(input);
@@ -165,15 +167,36 @@ Tensor _fusednbitrowwise_sbfront_to_float_cpu(
       (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;

   Tensor output;
-  output = at::empty(
-      {nrows, output_columns}, // 4 = sizeof(float)
-      input.options().dtype(at::kFloat));
+  if (std::is_same<output_t, float>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 4 = sizeof(float)
+        input.options().dtype(at::kFloat));
+  } else if (std::is_same<output_t, at::Half>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kHalf));
+  } else if (std::is_same<output_t, at::BFloat16>::value) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(half)
+        input.options().dtype(at::kBFloat16));
+  } else {
+    TORCH_CHECK(
+        false,
+        "Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
+  }

-  float* output_data = static_cast<float*>(
+  using output_ty = std::conditional_t<
+      std::is_same<output_t, float>::value,
+      float,
+      fbgemm::float16>;
+  output_ty* output_data = static_cast<output_ty*>(
      output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
                          // unresolved data_ptr symbol.

-  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
+  constexpr bool is_float16_bf16 = std::is_same<output_t, at::BFloat16>::value;
+  fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
+      output_ty,
+      is_float16_bf16>(
       bit_rate,
       input.data_ptr<uint8_t>(),
       nrows,
@@ -311,7 +334,7 @@ Tensor fusednbitrowwise_to_float_cpu(

 /// @ingroup quantize-data-cpu
 /// @brief Dequantize int4/int2 rows with scale and bias stored in the front
-/// into float32.
+/// into float32/float16/BFloat16.
 /// @param input Tensor of int4/int2 rows with scale and bias stored in the
 /// front.
 /// @param bit_rate Bit rate of each element. Should be 4 or 2.
@@ -323,8 +346,25 @@ Tensor fusednbitrowwise_to_float_cpu(
 /// purpose because its kernel is reference implementation and not optimized.
 Tensor fusednbitrowwise_sbfront_to_float_cpu(
     const Tensor& input,
-    const int64_t bit_rate) {
-  return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
+    const int64_t bit_rate,
+    const int64_t output_dtype) {
+  SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
+  switch (output_sparse_dtype) {
+    case SparseType::FP32:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<float>(
+          input, bit_rate);
+      break;
+    case SparseType::FP16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::Half>(
+          input, bit_rate);
+      break;
+    case SparseType::BF16:
+      return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::BFloat16>(
+          input, bit_rate);
+      break;
+    default:
+      TORCH_CHECK(false);
+  }
 }

 /// @ingroup quantize-data-cpu
@@ -340,7 +380,8 @@ Tensor fusednbitrowwise_to_half_cpu(
 Tensor fusednbitrowwise_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   Tensor output;

   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
@@ -520,11 +561,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(Tensor input, int bit_rate, int output_dtype) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
-      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0) -> Tensor");
+      "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
   m.def(
       "FloatToHFP8Quantized(Tensor input, int ebits, int exponent_bias, float max_pos) -> Tensor");
   m.def(
@@ -542,7 +583,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {

 TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
   DISPATCH_TO_QUANTIZED_CPU(
-      "FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
+      "FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf",
       fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
 }


fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp (+2 -1)

@@ -72,7 +72,8 @@ Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) {
 Tensor fusednbitrowwise_to_float_or_half_meta(
     const Tensor& input,
     const int64_t bit_rate,
-    const int64_t output_dtype) {
+    const int64_t output_dtype,
+    [[maybe_unused]] const bool scale_bias_last) {
   const at::SymIntArrayRef input_sizes = input.sym_sizes();
   const at::SymInt nrows = input_sizes[0];
   // Here we want the number of bytes in a row

fbgemm_gpu/test/tbe/inference/common.py (+4 -2)

@@ -351,8 +351,10 @@ def execute_nbit_forward_( # noqa C901
         f = torch.cat(fs, dim=0).view(-1, D)

         if fc2.dtype == torch.quint4x2:
-            fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
-                fc2.cpu(), bit_rate=4
+            fc2_float = (
+                torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
+                    fc2.cpu(), bit_rate=4, output_dtype=0
+                )
             )
         else:
             fc2_float = fc2.float()
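A hedged variant of the call above, not in this diff: the same test site could request a BF16 (or FP16) output by passing the corresponding SparseType integer instead of 0. SparseType and its as_int() helper are assumed to come from fbgemm_gpu.split_embedding_configs, and fc2 is the quint4x2 tensor from the surrounding test.

from fbgemm_gpu.split_embedding_configs import SparseType

fc2_bf16 = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
    fc2.cpu(), bit_rate=4, output_dtype=SparseType.BF16.as_int()
)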

fbgemm_gpu/test/tbe/inference/failures_dict_fast.json (+10 -1)

@@ -7,7 +7,8 @@
     "fbgemm::FloatToHFP8Quantized": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloat": {},
     "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf": {},
-    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf": {},
+    "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf": {},
     "fbgemm::HFP8QuantizedToFloat": {},
     "fbgemm::asynchronous_complete_cumsum": {},
     "fbgemm::bounds_check_indices": {},
@@ -44,9 +45,17 @@
       "comment": "",
       "status": "xsuccess"
     },
+    "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": {
+      "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+      "status": "xfail"
+    },
     "NBitFowardTest.test_faketensor__test_nbit_forward_cpu_seq_int4": {
       "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
       "status": "xfail"
+    },
+    "NBitFowardTest.test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": {
+      "comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
+      "status": "xfail"
     }
   },
   "fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
