From 97924d72f8d1442af76a630643ea988c30dd0bf5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 00:45:38 +0800 Subject: [PATCH 01/80] add fp16_to_fp6 prototype --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index d29f70be0c..e7ece50eed 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -18,6 +18,47 @@ #include #include #include +#include + + +// inspired by __internal_float2half() and float2half() from "cuda_fp16.h" +unsigned char fp16_to_fp6(const __half a) { + unsigned short fp16_bits; + std::memcpy(&fp16_bits, &a, sizeof(a)); + + unsigned short result; + unsigned short remainder = 0u; + unsigned short sign = (fp16_bits >> 15u) << 5u; + fp16_bits &= 0x7FFFu; // clear sign bit + + if (fp16_bits >= 0b0'11111'0000000000u) { + throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); + } else if (fp16_bits >= 0b0'10011'1110000000u) { // FP6 overflow + result = sign | 0b0'111'11; + } else if (fp16_bits >= 0b0'01101'0000000000u) { // FP6 normal number + remainder = fp16_bits << 8u; // truncated mantissa bits + fp16_bits -= 0b0'01100'0000000000u; // update exponent bits + fp16_bits >>= 8u; // truncate mantissa bits + result = sign | fp16_bits; // add sign bit + } else if (fp16_bits >= 0'01111010'0000000001u) { // FP6 subnormal number + unsigned short fp16_exp_bits = fp16_bits >> 10u; + unsigned short shift = 0xEu - fp16_exp_bits; + unsigned short fp16_man_bits = fp16_bits & 0x3FFu; + fp16_man_bits |= 0x400u; // add implicit 1 to mantissa + remainder = fp16_man_bits << (16u - shift); + result = sign | (fp16_man_bits >> shift); + result &= 0x3Fu; + } else { // FP6 underflow + result = sign; + } + + // round to nearest even + if ((remainder > 0x80u) || ((remainder == 0x80u) && ((result & 1u) == 1u))) { + result += 1; + } + + return result; +} /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. 
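A note on validating the conversion introduced above: the FP16 -> FP6 routine works by direct bit manipulation with round-to-nearest-even, which is easy to get subtly wrong (see the rounding fix in PATCH 03/80 below). One simple cross-check is a brute-force reference that decodes every 6-bit code and picks the closest one. The host-only sketch below assumes the same FP6 layout the kernel targets (1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits); the names fp6_code_to_float and nearest_fp6 are illustrative only and are not part of the patch series.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode a 6-bit FP6 code (S EEE MM, exponent bias 3) into a float.
    static float fp6_code_to_float(uint8_t code) {
        int sign = (code >> 5) & 0x1;
        int exp  = (code >> 2) & 0x7;   // 3 exponent bits
        int man  = code & 0x3;          // 2 mantissa bits
        float mag;
        if (exp == 0)                   // subnormal: (man/4) * 2^(1-3)
            mag = (man / 4.0f) * 0.25f;
        else                            // normal: (1 + man/4) * 2^(exp-3)
            mag = (1.0f + man / 4.0f) * std::ldexp(1.0f, exp - 3);
        return sign ? -mag : mag;
    }

    // Brute-force nearest FP6 code, breaking ties toward an even mantissa LSB
    // (the same round-to-nearest-even behaviour the kernel implements).
    static uint8_t nearest_fp6(float x) {
        uint8_t best = 0;
        float best_err = INFINITY;
        for (int code = 0; code < 64; ++code) {
            float err = std::fabs(fp6_code_to_float((uint8_t)code) - x);
            if (err < best_err || (err == best_err && (code & 1) == 0)) {
                best_err = err;
                best = (uint8_t)code;
            }
        }
        return best;
    }

    int main() {
        // FP6 E3M2 spans [-28, 28]; 0.0625 is the smallest positive subnormal.
        const float samples[] = {0.0625f, 0.1f, 1.0f, 1.3f, -3.5f, 28.0f};
        for (float x : samples) {
            uint8_t code = nearest_fp6(x);
            std::printf("%9.4f -> 0x%02X (decodes to %g)\n",
                        x, code, fp6_code_to_float(code));
        }
        return 0;
    }

Because FP16 has only 65536 bit patterns, the same reference can be used to exhaustively compare against the bit-twiddling fp16_to_fp6() for every finite, in-range input.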
From 8bf081c8c9ea8e7c36bda53ac4a7d5f17a017b9a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 00:48:58 +0800 Subject: [PATCH 02/80] minor rename --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index e7ece50eed..3ac20b8773 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -36,17 +36,17 @@ unsigned char fp16_to_fp6(const __half a) { } else if (fp16_bits >= 0b0'10011'1110000000u) { // FP6 overflow result = sign | 0b0'111'11; } else if (fp16_bits >= 0b0'01101'0000000000u) { // FP6 normal number - remainder = fp16_bits << 8u; // truncated mantissa bits - fp16_bits -= 0b0'01100'0000000000u; // update exponent bits - fp16_bits >>= 8u; // truncate mantissa bits + remainder = fp16_bits << 8u; // truncated mantissa + fp16_bits -= 0b0'01100'0000000000u; // update exponent + fp16_bits >>= 8u; // truncate mantissa result = sign | fp16_bits; // add sign bit } else if (fp16_bits >= 0'01111010'0000000001u) { // FP6 subnormal number - unsigned short fp16_exp_bits = fp16_bits >> 10u; - unsigned short shift = 0xEu - fp16_exp_bits; - unsigned short fp16_man_bits = fp16_bits & 0x3FFu; - fp16_man_bits |= 0x400u; // add implicit 1 to mantissa - remainder = fp16_man_bits << (16u - shift); - result = sign | (fp16_man_bits >> shift); + unsigned short fp16_exp = fp16_bits >> 10u; + unsigned short shift = 0xEu - fp16_exp; + unsigned short fp16_man = fp16_bits & 0x3FFu; + fp16_man |= 0x400u; // add implicit 1 to mantissa + remainder = fp16_man << (16u - shift); + result = sign | (fp16_man >> shift); result &= 0x3Fu; } else { // FP6 underflow result = sign; From 314e9f6541ae7c95a17ba19524efb16493d9f7e1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 01:18:23 +0000 Subject: [PATCH 03/80] fix rounding issue --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 3ac20b8773..d82584cf6b 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -53,7 +53,7 @@ unsigned char fp16_to_fp6(const __half a) { } // round to nearest even - if ((remainder > 0x80u) || ((remainder == 0x80u) && ((result & 1u) == 1u))) { + if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { result += 1; } From 79ce0db5ea9c1d52414ebf6ac25224ef40362007 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 06:50:43 +0000 Subject: [PATCH 04/80] update quant --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 41 +++++++++++------------ 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index d82584cf6b..1b5f67bae7 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -23,31 +23,30 @@ // inspired by __internal_float2half() and float2half() from "cuda_fp16.h" unsigned char fp16_to_fp6(const __half a) { - unsigned short fp16_bits; - std::memcpy(&fp16_bits, &a, sizeof(a)); + unsigned short bits; + std::memcpy(&bits, &a, sizeof(a)); - unsigned short result; unsigned short remainder = 0u; - unsigned short sign = (fp16_bits >> 15u) << 5u; - fp16_bits &= 0x7FFFu; // clear sign bit + unsigned short sign = bits >> 15u << 5u; + bits &= 
0x7FFFu; // clear sign bit + unsigned short result; - if (fp16_bits >= 0b0'11111'0000000000u) { + if (bits >= 0b11111'0000000000u) { throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - } else if (fp16_bits >= 0b0'10011'1110000000u) { // FP6 overflow - result = sign | 0b0'111'11; - } else if (fp16_bits >= 0b0'01101'0000000000u) { // FP6 normal number - remainder = fp16_bits << 8u; // truncated mantissa - fp16_bits -= 0b0'01100'0000000000u; // update exponent - fp16_bits >>= 8u; // truncate mantissa - result = sign | fp16_bits; // add sign bit - } else if (fp16_bits >= 0'01111010'0000000001u) { // FP6 subnormal number - unsigned short fp16_exp = fp16_bits >> 10u; - unsigned short shift = 0xEu - fp16_exp; - unsigned short fp16_man = fp16_bits & 0x3FFu; - fp16_man |= 0x400u; // add implicit 1 to mantissa - remainder = fp16_man << (16u - shift); - result = sign | (fp16_man >> shift); - result &= 0x3Fu; + } else if (bits >= 0b10011'1110000000u) { // FP6 overflow. clamp to max + result = sign | 0b111'11u; + } else if (bits >= 0b01101'0000000000u) { // FP6 normal number + remainder = bits << 8u; + bits -= (0b01100u << 10u); // update exponent + result = sign | (bits >> 8u); + } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number + unsigned short exp = bits >> 10u; + unsigned short man = bits & 0x3FFu; + unsigned short shift = 0b01111u - 0b011u + 1u + 8u - exp; + man |= 0x400u; // set implicit 1 to mantissa + remainder = man << (16u - shift); + man >>= shift; + result = sign | man; } else { // FP6 underflow result = sign; } From 45a92f3b5628338151856f17e7bf8c3f9d851d39 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 07:31:46 +0000 Subject: [PATCH 05/80] add unpacked version --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 38 +++++++++++++++++------ torchao/csrc/fp6_llm/fp6_llm.cpp | 1 + 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 1b5f67bae7..73da3eef62 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -22,27 +22,27 @@ // inspired by __internal_float2half() and float2half() from "cuda_fp16.h" -unsigned char fp16_to_fp6(const __half a) { - unsigned short bits; +uint8_t fp16_to_fp6(const __half a) { + uint16_t bits; std::memcpy(&bits, &a, sizeof(a)); - unsigned short remainder = 0u; - unsigned short sign = bits >> 15u << 5u; + uint16_t remainder = 0u; + uint16_t sign = bits >> 15u << 5u; bits &= 0x7FFFu; // clear sign bit - unsigned short result; + uint16_t result; if (bits >= 0b11111'0000000000u) { throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - } else if (bits >= 0b10011'1110000000u) { // FP6 overflow. clamp to max - result = sign | 0b111'11u; + } else if (bits >= 0b10011'1110000000u) { // FP6 overflow + throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); } else if (bits >= 0b01101'0000000000u) { // FP6 normal number remainder = bits << 8u; bits -= (0b01100u << 10u); // update exponent result = sign | (bits >> 8u); } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number - unsigned short exp = bits >> 10u; - unsigned short man = bits & 0x3FFu; - unsigned short shift = 0b01111u - 0b011u + 1u + 8u - exp; + uint16_t exp = bits >> 10u; + uint16_t man = bits & 0x3FFu; + uint16_t shift = 0b01111u - 0b011u + 1u + 8u - exp; man |= 0x400u; // set implicit 1 to mantissa remainder = man << (16u - shift); man >>= shift; @@ -251,9 +251,27 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal return fp16_tensor; } +// this is used for debugging +at::Tensor _fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); + + __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + + for (int i = 0; i < fp16_tensor.numel(); i++) { + fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); + } + + return fp6_tensor; +} + TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); + m.impl("torchao::_fp16_to_fp6_unpacked", &_fp16_to_fp6_unpacked_cpu); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 794c79df11..014aaff04e 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -7,5 +7,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); + m.def("_fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } From a8555e3d9c9fc906ac71cbf1ed9679545ace3fa1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 07:43:32 +0000 Subject: [PATCH 06/80] remove unnecessary comment --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 73da3eef62..90ac280c2d 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -33,7 +33,7 @@ uint8_t fp16_to_fp6(const __half a) { if (bits >= 0b11111'0000000000u) { throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - } else if (bits >= 0b10011'1110000000u) { // FP6 overflow + } else if (bits >= 0b10011'1110000000u) { throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); } else if (bits >= 0b01101'0000000000u) { // FP6 normal number remainder = bits << 8u; From 012176e3f195031dacf53a0852f085b745aa98d5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 20:37:36 +0800 Subject: [PATCH 07/80] add CUDA version --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 42 ++++++++++++++++++++--- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 +- torchao/ops.py | 4 +++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 90ac280c2d..51ce166008 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -22,7 +22,7 @@ // inspired by __internal_float2half() and float2half() from "cuda_fp16.h" -uint8_t fp16_to_fp6(const __half a) { +__device__ __host__ uint8_t fp16_to_fp6(const __half a) { uint16_t bits; std::memcpy(&bits, &a, sizeof(a)); @@ -32,9 +32,13 @@ uint8_t fp16_to_fp6(const __half a) { uint16_t result; if (bits >= 0b11111'0000000000u) { +#ifndef __CUDACC__ throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); +#endif } else if (bits >= 0b10011'1110000000u) { +#ifndef __CUDACC__ throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); +#endif } else if (bits >= 0b01101'0000000000u) { // FP6 normal number remainder = bits << 8u; bits -= (0b01100u << 10u); // update exponent @@ -252,13 +256,14 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal } // this is used for debugging -at::Tensor _fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { +at::Tensor fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_cpu()); at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); - __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); for (int i = 0; i < fp16_tensor.numel(); i++) { @@ -268,10 +273,39 @@ at::Tensor _fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { return fp6_tensor; } +__global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + fp6_ptr[tid] = fp16_to_fp6(fp16_ptr[tid]); + } +} + +at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_cuda()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); + + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); + + int block_size = 256; + int grid_size = (n + block_size - 1) / block_size; + fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + + return fp6_tensor; +} + TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); - m.impl("torchao::_fp16_to_fp6_unpacked", &_fp16_to_fp6_unpacked_cpu); + m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + 
m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cuda); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 014aaff04e..4abf655fce 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -7,6 +7,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); - m.def("_fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); + m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 3a25dbf6db..ae07b13d7c 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -42,6 +42,10 @@ def _(fp6_weight): return torch.empty_like(fp6_weight) +def fp16_to_fp6_unpacked(fp16_tensor: Tensor) -> Tensor: + return torch.ops.torchao.fp16_to_fp6_unpacked.default(fp16_tensor) + + def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. From d4b8681d68acbc0e4d95bf42998008f8992d78c4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 20:58:45 +0800 Subject: [PATCH 08/80] add fp6 packed cpu --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 33 +++++++++++++++++++++++ torchao/csrc/fp6_llm/fp6_llm.cpp | 1 + torchao/ops.py | 4 +++ 3 files changed, 38 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 51ce166008..1d8043ddce 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -258,6 +258,7 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal // this is used for debugging at::Tensor fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); TORCH_CHECK(fp16_tensor.is_cpu()); at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); @@ -282,6 +283,7 @@ __global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6 at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); TORCH_CHECK(fp16_tensor.is_cuda()); at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); @@ -298,10 +300,41 @@ at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { return fp6_tensor; } +at::Tensor fp16_to_fp6_packed_cpu(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); + TORCH_CHECK(fp16_tensor.is_cpu()); + TORCH_CHECK(fp16_tensor.ndimension() == 2); + + int M = fp16_tensor.size(0); + int N = fp16_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); + + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + + for (int i = 0, j = 0; i < fp16_tensor.numel(); i += 4, j += 3) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); + uint8_t val2 = 
fp16_to_fp6(fp16_ptr[i + 2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); + + fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 + } + + return fp6_tensor; +} + TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cpu); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 4abf655fce..a8bd97b27d 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -8,5 +8,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); + m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index ae07b13d7c..1a15f85f1c 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -46,6 +46,10 @@ def fp16_to_fp6_unpacked(fp16_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp16_to_fp6_unpacked.default(fp16_tensor) +def fp16_to_fp6_packed(fp16_tensor: Tensor) -> Tensor: + return torch.ops.torchao.fp16_to_fp6_packed.default(fp16_tensor) + + def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. From f0f3101388466f922ad9769fc4d8c6ec6eb89569 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 23:12:55 +0800 Subject: [PATCH 09/80] add CUDA for packed --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 51 ++++++++++++++++++++--- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 +- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 1d8043ddce..a55d55490e 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -21,7 +21,7 @@ #include -// inspired by __internal_float2half() and float2half() from "cuda_fp16.h" +// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" __device__ __host__ uint8_t fp16_to_fp6(const __half a) { uint16_t bits; std::memcpy(&bits, &a, sizeof(a)); @@ -32,11 +32,11 @@ __device__ __host__ uint8_t fp16_to_fp6(const __half a) { uint16_t result; if (bits >= 0b11111'0000000000u) { -#ifndef __CUDACC__ +#ifndef __CUDA_ARCH__ throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); #endif } else if (bits >= 0b10011'1110000000u) { -#ifndef __CUDACC__ +#ifndef __CUDA_ARCH__ throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); #endif } else if (bits >= 0b01101'0000000000u) { // FP6 normal number @@ -206,7 +206,7 @@ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t namespace torchao { // https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 -at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) +at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); @@ -293,7 +293,7 @@ at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp16_tensor.numel(); - int block_size = 256; + constexpr int block_size = 256; int grid_size = (n + block_size - 1) / block_size; fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); @@ -330,8 +330,46 @@ at::Tensor fp16_to_fp6_packed_cpu(at::Tensor fp16_tensor) { return fp6_tensor; } +__global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + if (idx < n) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[idx]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[idx + 1]); + uint8_t val2 = fp16_to_fp6(fp16_ptr[idx + 2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[idx + 3]); + + fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } +} + +at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); + TORCH_CHECK(fp16_tensor.is_cuda()); + TORCH_CHECK(fp16_tensor.ndimension() == 2); + + int M = fp16_tensor.size(0); + int N = fp16_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); + + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); + + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / block_size * 4; + fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + + return fp6_tensor; +} + TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); + m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cpu); @@ -339,6 +377,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cuda); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cuda); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index a8bd97b27d..5cc853e0f7 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -6,7 +6,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); 
m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); - m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); + m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); From f542eb167dc732cf6f74200248b261fb6fd227c6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 23:26:12 +0800 Subject: [PATCH 10/80] some rename --- test/test_ops.py | 2 +- torchao/ops.py | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e260e86f0f..cfd904d349 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -76,7 +76,7 @@ def test_fp16_to_fp6(self): fp16_weight[fp16_weight.abs() < fp6_absmin] = 0 # smoke test - torchao.ops.fp16_to_fp6(fp16_weight) + torchao.ops.fp16_to_fp6_original(fp16_weight) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] diff --git a/torchao/ops.py b/torchao/ops.py index 1a15f85f1c..d608aceddf 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -46,18 +46,32 @@ def fp16_to_fp6_unpacked(fp16_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp16_to_fp6_unpacked.default(fp16_tensor) +@torch.library.impl_abstract("torchao::fp16_to_fp6_unpacked") +def _(fp16_tensor): + return torch.empty_like(fp16_tensor, dtype=torch.uint8) + + def fp16_to_fp6_packed(fp16_tensor: Tensor) -> Tensor: - return torch.ops.torchao.fp16_to_fp6_packed.default(fp16_tensor) + *leading_dims, last_dim = fp16_tensor.shape + return torch.ops.torchao.fp16_to_fp6_packed.default(fp16_tensor.view(-1, last_dim)).view(*leading_dims, -1) + + +@torch.library.impl_abstract("torchao::fp16_to_fp6_packed") +def _(fp16_tensor): + torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") + *leading_dims, last_dim = fp16_tensor.shape + torch._check(last_dim % 4 == 0, lambda: f"last dimension must be a multiple of 4, got {last_dim}") + return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp16_tensor.device, dtype=torch.uint8) -def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: +def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. 
""" - return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor) + return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) -@torch.library.impl_abstract("torchao::fp16_to_fp6") +@torch.library.impl_abstract("torchao::fp16_to_fp6_original") def _(fp16_tensor): torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") From 40dc7256e85ed230bea2a7552ff8e9fd5556c682 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 16 May 2024 23:27:34 +0800 Subject: [PATCH 11/80] update name --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index cfd904d349..31cb8fb970 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -80,7 +80,7 @@ def test_fp16_to_fp6(self): # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils) + opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp16act_fp6weight_linear(self): From eef2f95b2add59f977dc311eda68abfc70a474e3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 17 May 2024 02:17:26 +0000 Subject: [PATCH 12/80] add OpenMP --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index a55d55490e..bcc4f8d45e 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -266,8 +266,10 @@ at::Tensor fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); - for (int i = 0; i < fp16_tensor.numel(); i++) { +#pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) { fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); } @@ -315,13 +317,16 @@ at::Tensor fp16_to_fp6_packed_cpu(at::Tensor fp16_tensor) { const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); - for (int i = 0, j = 0; i < fp16_tensor.numel(); i += 4, j += 3) { +#pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) { uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); + int j = i / 4 * 3; fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 From f61aa37382be5c12ed0246427a75623590b4e9d6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 17 May 2024 05:59:21 +0000 Subject: [PATCH 13/80] fix CUDA bug --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index bcc4f8d45e..6a9f3471f6 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -297,7 +297,7 @@ at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { constexpr int block_size = 256; int 
grid_size = (n + block_size - 1) / block_size; - fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); return fp6_tensor; } @@ -367,8 +367,8 @@ at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) { int n = fp16_tensor.numel(); constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / block_size * 4; - fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); return fp6_tensor; } From 1640bbf904890972a8e9de42385e7775818310cf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 17 May 2024 09:59:53 +0000 Subject: [PATCH 14/80] add fp6->fp16 --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 6a9f3471f6..b0099b8387 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -63,6 +63,27 @@ __device__ __host__ uint8_t fp16_to_fp6(const __half a) { return result; } +// assume the lower 6 bits contain the data +__device__ __host__ __half fp6_to_fp16(const uint8_t a) { + // we shift the bits so that sign, exponent, and mantissa bits are in their + // correct positions in FP16 + // FP6: SE EEMM + // FP16: S00E EEMM 0000 0000 + uint16_t bits = a; + uint16_t sign = (a << 10u) & 0x8000u; + uint16_t exp_and_man = (a & 0x1Fu) << 8u; + uint16_t result_bits = sign | exp_and_man; + + // the result will be off by the difference in exponent bias + // FP6: Ebias = 011 = 2^3 + // FP16: Ebias = 01111 = 2^15 + // correction = 2^12 = 4096 + // we can correct this by direct FP16 multiplication + __half result; + std::memcpy(&result, &result_bits, sizeof(result)); + return result * __float2half(4096.0f); +} + /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. 
*/ @@ -373,11 +394,32 @@ at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) { return fp6_tensor; } +at::Tensor fp6_unpacked_to_fp16_cpu(at::Tensor fp6_tensor) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat16).device(fp6_tensor.device()); + at::Tensor fp16_tensor = at::empty(fp6_tensor.sizes(), options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + __half *fp16_ptr = reinterpret_cast<__half *>(fp16_tensor.data_ptr()); + int n = fp6_tensor.numel(); + +#pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) { + fp16_ptr[i] = fp6_to_fp16(fp6_ptr[i]); + } + + return fp16_tensor; +} + TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cpu); + m.impl("torchao::fp6_unpacked_to_fp16", &fp6_unpacked_to_fp16_cpu); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { From 7a00b3126afebecb4f94e6ebcd63e356465b70b2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 17 May 2024 22:58:37 +0800 Subject: [PATCH 15/80] add FP6->FP32 --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 225 +++++++++++++--------- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 + torchao/ops.py | 9 + 3 files changed, 141 insertions(+), 95 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index b0099b8387..c9577d1c12 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -22,7 +22,7 @@ // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -__device__ __host__ uint8_t fp16_to_fp6(const __half a) { +__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { uint16_t bits; std::memcpy(&bits, &a, sizeof(a)); @@ -64,24 +64,24 @@ __device__ __host__ uint8_t fp16_to_fp6(const __half a) { } // assume the lower 6 bits contain the data -__device__ __host__ __half fp6_to_fp16(const uint8_t a) { - // we shift the bits so that sign, exponent, and mantissa bits are in their - // correct positions in FP16 - // FP6: SE EEMM - // FP16: S00E EEMM 0000 0000 - uint16_t bits = a; - uint16_t sign = (a << 10u) & 0x8000u; - uint16_t exp_and_man = (a & 0x1Fu) << 8u; - uint16_t result_bits = sign | exp_and_man; +__device__ __host__ static float fp6_to_fp32(const uint8_t a) { + // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. + // this also handles subnormal numbers correctly. + // FP6: SE EEMM + // FP32: S000 00EE EMM0 0000 0000 0000 0000 0000 + uint32_t bits = a; + uint32_t sign = bits >> 5u << 31u; + uint32_t exp_and_man = (bits & 0x1Fu) << 21u; + uint32_t result_bits = sign | exp_and_man; // the result will be off by the difference in exponent bias - // FP6: Ebias = 011 = 2^3 - // FP16: Ebias = 01111 = 2^15 - // correction = 2^12 = 4096 - // we can correct this by direct FP16 multiplication - __half result; + // FP6: Ebias = 3 + // FP32: Ebias = 127 + // correction = 2^(127-3) + // we can correct this by direct FP32 multiplication, which also handles subnormal numbers correctly. 
+ float result; std::memcpy(&result, &result_bits, sizeof(result)); - return result * __float2half(4096.0f); + return result * 0x1p124; } /* @@ -276,38 +276,17 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal return fp16_tensor; } -// this is used for debugging -at::Tensor fp16_to_fp6_unpacked_cpu(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); - - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - -#pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i++) { - fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); - } - - return fp6_tensor; -} - __global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < n) { + if (tid < n) fp6_ptr[tid] = fp16_to_fp6(fp16_ptr[tid]); - } } -at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { +// this is useful for debugging +at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cuda()); + TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); @@ -316,41 +295,14 @@ at::Tensor fp16_to_fp6_unpacked_cuda(at::Tensor fp16_tensor) { uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp16_tensor.numel(); - constexpr int block_size = 256; - int grid_size = (n + block_size - 1) / block_size; - fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); - - return fp6_tensor; -} - -at::Tensor fp16_to_fp6_packed_cpu(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu()); - TORCH_CHECK(fp16_tensor.ndimension() == 2); - - int M = fp16_tensor.size(0); - int N = fp16_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - -#pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); - - int j = i / 4 * 3; - fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 + if (fp16_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size - 1) / block_size; + fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); } return fp6_tensor; @@ -370,10 +322,10 @@ 
__global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_p } } -at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) { +at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cuda()); + TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); TORCH_CHECK(fp16_tensor.ndimension() == 2); int M = fp16_tensor.size(0); @@ -387,44 +339,127 @@ at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) { uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp16_tensor.numel(); - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + if (fp16_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); + uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); + + int j = i / 4 * 3; + fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 + } + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + } return fp6_tensor; } -at::Tensor fp6_unpacked_to_fp16_cpu(at::Tensor fp6_tensor) { +__global__ void fp6_unpacked_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) + fp32_ptr[idx] = fp6_to_fp32(fp6_ptr[idx]); +} + +at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu()); + TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); - at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat16).device(fp6_tensor.device()); - at::Tensor fp16_tensor = at::empty(fp6_tensor.sizes(), options); + at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); + at::Tensor fp32_tensor = at::empty(fp6_tensor.sizes(), options); const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - __half *fp16_ptr = reinterpret_cast<__half *>(fp16_tensor.data_ptr()); + float *fp32_ptr = fp32_tensor.data_ptr(); int n = fp6_tensor.numel(); -#pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i++) { - fp16_ptr[i] = fp6_to_fp16(fp6_ptr[i]); + if (fp6_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp32_ptr[i] = fp6_to_fp32(fp6_ptr[i]); + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); } - return fp16_tensor; + return fp32_tensor; +} + +__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; + if (idx < n) { + uint8_t bits0 = fp6_ptr[idx]; // 0000 0011 + uint8_t bits1 = fp6_ptr[idx + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[idx + 2]; // 2233 3333 + + int j = idx / 3 * 4; + fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); + 
} +} + +at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); + TORCH_CHECK(fp6_tensor.ndimension() == 2); + + int M = fp6_tensor.size(0); + int N = fp6_tensor.size(1); + TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); + at::Tensor fp32_tensor = at::empty({M, N / 3 * 4}, options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + float *fp32_ptr = fp32_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (fp6_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 3) { + uint8_t bits0 = fp6_ptr[i]; // 0000 0011 + uint8_t bits1 = fp6_ptr[i + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[i + 2]; // 2233 3333 + + int j = i / 3 * 4; + fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); + } + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 3 - 1) / (block_size * 3); + fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + } + + return fp32_tensor; } TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cpu); - m.impl("torchao::fp6_unpacked_to_fp16", &fp6_unpacked_to_fp16_cpu); + m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cuda); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cuda); + m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 5cc853e0f7..065ecd8d2b 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -9,5 +9,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); + m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); + m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index d608aceddf..3f3403f93e 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -64,6 +64,15 @@ def _(fp16_tensor): return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp16_tensor.device, dtype=torch.uint8) +def fp6_unpacked_to_fp32(fp6_tensor: Tensor) -> Tensor: + return torch.ops.torchao.fp6_unpacked_to_fp32.default(fp6_tensor) + + +def fp6_packed_to_fp32(fp6_tensor: Tensor) -> Tensor: + 
*leading_dims, last_dim = fp6_tensor.shape + return torch.ops.torchao.fp6_packed_to_fp32.default(fp6_tensor.view(-1, last_dim)).view(*leading_dims, -1) + + def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. From b2fcc6ce0a21a7830c10812b15faab93f04eba52 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 17 May 2024 23:02:00 +0800 Subject: [PATCH 16/80] move files around --- torchao/csrc/cuda/fp6_llm/fp6.cu | 260 ++++++++++++++++++++++ torchao/csrc/cuda/fp6_llm/weight_quant.cu | 244 -------------------- 2 files changed, 260 insertions(+), 244 deletions(-) create mode 100644 torchao/csrc/cuda/fp6_llm/fp6.cu diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu new file mode 100644 index 0000000000..9b6c7aa235 --- /dev/null +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -0,0 +1,260 @@ +#include +#include +#include +#include + + +// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" +__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { + uint16_t bits; + std::memcpy(&bits, &a, sizeof(a)); + + uint16_t remainder = 0u; + uint16_t sign = bits >> 15u << 5u; + bits &= 0x7FFFu; // clear sign bit + uint16_t result; + + if (bits >= 0b11111'0000000000u) { +#ifndef __CUDA_ARCH__ + throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); +#endif + } else if (bits >= 0b10011'1110000000u) { +#ifndef __CUDA_ARCH__ + throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); +#endif + } else if (bits >= 0b01101'0000000000u) { // FP6 normal number + remainder = bits << 8u; + bits -= (0b01100u << 10u); // update exponent + result = sign | (bits >> 8u); + } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number + uint16_t exp = bits >> 10u; + uint16_t man = bits & 0x3FFu; + uint16_t shift = 0b01111u - 0b011u + 1u + 8u - exp; + man |= 0x400u; // set implicit 1 to mantissa + remainder = man << (16u - shift); + man >>= shift; + result = sign | man; + } else { // FP6 underflow + result = sign; + } + + // round to nearest even + if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { + result += 1; + } + + return result; +} + +// assume the lower 6 bits contain the data +__device__ __host__ static float fp6_to_fp32(const uint8_t a) { + // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. + // this also handles subnormal numbers correctly. + // FP6: SE EEMM + // FP32: S000 00EE EMM0 0000 0000 0000 0000 0000 + uint32_t bits = a; + uint32_t sign = bits >> 5u << 31u; + uint32_t exp_and_man = (bits & 0x1Fu) << 21u; + uint32_t result_bits = sign | exp_and_man; + + // the result will be off by the difference in exponent bias + // FP6: Ebias = 3 + // FP32: Ebias = 127 + // correction = 2^(127-3) + // we can correct this by direct FP32 multiplication, which also handles subnormal numbers correctly. 
+ float result; + std::memcpy(&result, &result_bits, sizeof(result)); + return result * 0x1p124; +} + +#include +#include +#include + +namespace torchao { + +__global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) + fp6_ptr[tid] = fp16_to_fp6(fp16_ptr[tid]); +} + +// this is useful for debugging +at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); + TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); + + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); + + if (fp16_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size - 1) / block_size; + fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + } + + return fp6_tensor; +} + +__global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + if (idx < n) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[idx]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[idx + 1]); + uint8_t val2 = fp16_to_fp6(fp16_ptr[idx + 2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[idx + 3]); + + fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } +} + +at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { + TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); + TORCH_CHECK(fp16_tensor.is_contiguous()); + TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); + TORCH_CHECK(fp16_tensor.ndimension() == 2); + + int M = fp16_tensor.size(0); + int N = fp16_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); + + const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp16_tensor.numel(); + + if (fp16_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); + uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); + + int j = i / 4 * 3; + fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 + } + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + } + + return fp6_tensor; +} + +__global__ void fp6_unpacked_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) + fp32_ptr[idx] = fp6_to_fp32(fp6_ptr[idx]); +} + +at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { + 
TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); + at::Tensor fp32_tensor = at::empty(fp6_tensor.sizes(), options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + float *fp32_ptr = fp32_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (fp6_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp32_ptr[i] = fp6_to_fp32(fp6_ptr[i]); + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + } + + return fp32_tensor; +} + +__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; + if (idx < n) { + uint8_t bits0 = fp6_ptr[idx]; // 0000 0011 + uint8_t bits1 = fp6_ptr[idx + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[idx + 2]; // 2233 3333 + + int j = idx / 3 * 4; + fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); + } +} + +at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); + TORCH_CHECK(fp6_tensor.ndimension() == 2); + + int M = fp6_tensor.size(0); + int N = fp6_tensor.size(1); + TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); + at::Tensor fp32_tensor = at::empty({M, N / 3 * 4}, options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + float *fp32_ptr = fp32_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (fp6_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 3) { + uint8_t bits0 = fp6_ptr[i]; // 0000 0011 + uint8_t bits1 = fp6_ptr[i + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[i + 2]; // 2233 3333 + + int j = i / 3 * 4; + fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); + } + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 3 - 1) / (block_size * 3); + fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + } + + return fp32_tensor; +} + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); +} + +} diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index c9577d1c12..e6220c1a24 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu 
+++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -21,69 +21,6 @@ #include -// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { - uint16_t bits; - std::memcpy(&bits, &a, sizeof(a)); - - uint16_t remainder = 0u; - uint16_t sign = bits >> 15u << 5u; - bits &= 0x7FFFu; // clear sign bit - uint16_t result; - - if (bits >= 0b11111'0000000000u) { -#ifndef __CUDA_ARCH__ - throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); -#endif - } else if (bits >= 0b10011'1110000000u) { -#ifndef __CUDA_ARCH__ - throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); -#endif - } else if (bits >= 0b01101'0000000000u) { // FP6 normal number - remainder = bits << 8u; - bits -= (0b01100u << 10u); // update exponent - result = sign | (bits >> 8u); - } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number - uint16_t exp = bits >> 10u; - uint16_t man = bits & 0x3FFu; - uint16_t shift = 0b01111u - 0b011u + 1u + 8u - exp; - man |= 0x400u; // set implicit 1 to mantissa - remainder = man << (16u - shift); - man >>= shift; - result = sign | man; - } else { // FP6 underflow - result = sign; - } - - // round to nearest even - if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { - result += 1; - } - - return result; -} - -// assume the lower 6 bits contain the data -__device__ __host__ static float fp6_to_fp32(const uint8_t a) { - // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. - // this also handles subnormal numbers correctly. - // FP6: SE EEMM - // FP32: S000 00EE EMM0 0000 0000 0000 0000 0000 - uint32_t bits = a; - uint32_t sign = bits >> 5u << 31u; - uint32_t exp_and_man = (bits & 0x1Fu) << 21u; - uint32_t result_bits = sign | exp_and_man; - - // the result will be off by the difference in exponent bias - // FP6: Ebias = 3 - // FP32: Ebias = 127 - // correction = 2^(127-3) - // we can correct this by direct FP32 multiplication, which also handles subnormal numbers correctly. - float result; - std::memcpy(&result, &result_bits, sizeof(result)); - return result * 0x1p124; -} - /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. 
*/ @@ -276,190 +213,9 @@ at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scal return fp16_tensor; } -__global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < n) - fp6_ptr[tid] = fp16_to_fp6(fp16_ptr[tid]); -} - -// this is useful for debugging -at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); - - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - - if (fp16_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i++) - fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); - } else { - constexpr int block_size = 256; - int grid_size = (n + block_size - 1) / block_size; - fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); - } - - return fp6_tensor; -} - -__global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[idx]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[idx + 1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[idx + 2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[idx + 3]); - - fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } -} - -at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); - TORCH_CHECK(fp16_tensor.ndimension() == 2); - - int M = fp16_tensor.size(0); - int N = fp16_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - - if (fp16_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); - - int j = i / 4 * 3; - fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 - } - } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); - } - - return fp6_tensor; -} - -__global__ void fp6_unpacked_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) - fp32_ptr[idx] = fp6_to_fp32(fp6_ptr[idx]); -} - -at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { - TORCH_CHECK(fp6_tensor.dtype() 
== torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); - at::Tensor fp32_tensor = at::empty(fp6_tensor.sizes(), options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - float *fp32_ptr = fp32_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (fp6_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i++) - fp32_ptr[i] = fp6_to_fp32(fp6_ptr[i]); - } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); - } - - return fp32_tensor; -} - -__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - if (idx < n) { - uint8_t bits0 = fp6_ptr[idx]; // 0000 0011 - uint8_t bits1 = fp6_ptr[idx + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[idx + 2]; // 2233 3333 - - int j = idx / 3 * 4; - fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); - } -} - -at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); - TORCH_CHECK(fp6_tensor.ndimension() == 2); - - int M = fp6_tensor.size(0); - int N = fp6_tensor.size(1); - TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); - at::Tensor fp32_tensor = at::empty({M, N / 3 * 4}, options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - float *fp32_ptr = fp32_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (fp6_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 3) { - uint8_t bits0 = fp6_ptr[i]; // 0000 0011 - uint8_t bits1 = fp6_ptr[i + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[i + 2]; // 2233 3333 - - int j = i / 3 * 4; - fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); - } - } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 3 - 1) / (block_size * 3); - fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); - } - - return fp32_tensor; -} - TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } } From d9ca476c9b6561d7512c97c911ead64acb91d123 Mon Sep 17 00:00:00 2001 From: Thien 
Tran Date: Fri, 17 May 2024 23:11:19 +0800 Subject: [PATCH 17/80] rearrange stuff --- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 2 -- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index e6220c1a24..9aa78858fe 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -18,8 +18,6 @@ #include #include #include -#include - /* * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 065ecd8d2b..18e1600532 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -7,9 +7,9 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); + m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); - m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } From ba89a0b08d2ecf0c16b9a973430bc83bd33b41d6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 10:02:48 +0800 Subject: [PATCH 18/80] add more things --- torchao/csrc/cuda/fp6_llm/fp6.cu | 135 +++++++++++++++++++++++++++++-- torchao/csrc/fp6_llm/fp6_llm.cpp | 1 + torchao/ops.py | 5 ++ 3 files changed, 135 insertions(+), 6 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 9b6c7aa235..08a2b2af71 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -14,15 +14,14 @@ __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { bits &= 0x7FFFu; // clear sign bit uint16_t result; - if (bits >= 0b11111'0000000000u) { #ifndef __CUDA_ARCH__ + if (bits >= 0b11111'0000000000u) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); -#endif - } else if (bits >= 0b10011'1110000000u) { -#ifndef __CUDA_ARCH__ + if (bits >= 0b10011'1110000000u) throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); #endif - } else if (bits >= 0b01101'0000000000u) { // FP6 normal number + + if (bits >= 0b01101'0000000000u) { // FP6 normal number remainder = bits << 8u; bits -= (0b01100u << 10u); // update exponent result = sign | (bits >> 8u); @@ -46,6 +45,75 @@ __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { return result; } +__device__ __host__ static uint8_t fp32_to_fp6_v1(float a) { +#ifndef __CUDA_ARCH__ + if (std::isnan(a) | std::isinf(a)) + throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); + if (std::abs(a) >= 30.0f) // 2^4 * (1 + 0.5 + 0.25 + 0.125) + throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); +#endif + + a *= 0x1p-124; + uint32_t bits; + std::memcpy(&bits, &a, sizeof(a)); + + uint8_t sign = bits >> 31u << 5u; + uint8_t exp_and_man = (bits >> 21u) & 0x1Fu; + uint8_t result = sign | exp_and_man; + + // round to nearest even + uint32_t remainder = bits << 11u; + if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && ((result & 1u) == 1u))) { + result += 1; + } + + return result; +} + +// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" +__device__ __host__ static uint8_t fp32_to_fp6_v2(const float a) { + uint32_t bits; + std::memcpy(&bits, &a, sizeof(a)); + + uint32_t remainder = 0u; + uint32_t sign = bits >> 31u << 5u; + bits &= 0x7FFF'FFFFu; // clear sign bit + uint32_t result; + +#ifndef __CUDA_ARCH__ + constexpr uint32_t exp_max = 7u + 127u - 3u; + if (bits >= 0x7F80'0000u) + throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); + if (bits >= ((exp_max) << 23u) | (0b111 << 20u)) + throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); +#endif + + if (bits >= ((1u + 15u - 3u) << 23u)) { // FP6 normal number + remainder = bits << 8u; + bits -= (0b01100u << 10u); // update exponent + result = sign | (bits >> 8u); + } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number + uint32_t exp = bits >> 10u; + uint32_t man = bits & 0x3FFu; + uint32_t shift = 0b01111u - 0b011u + 1u + 8u - exp; + man |= 0x400u; // set implicit 1 to mantissa + remainder = man << (16u - shift); + man >>= shift; + result = sign | man; + } else { // FP6 underflow + result = sign; + } + + // round to nearest even + if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { + result += 1; + } + + return result; +} + +#define fp32_to_fp6 fp32_to_fp6_v1 + // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. 
@@ -158,6 +226,59 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { return fp6_tensor; } +__global__ void fp32_to_fp6_packed_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + if (idx < n) { + uint8_t val0 = fp32_to_fp6(fp32_ptr[idx]); + uint8_t val1 = fp32_to_fp6(fp32_ptr[idx + 1]); + uint8_t val2 = fp32_to_fp6(fp32_ptr[idx + 2]); + uint8_t val3 = fp32_to_fp6(fp32_ptr[idx + 3]); + + fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } +} + +at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { + TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); + TORCH_CHECK(fp32_tensor.is_contiguous()); + TORCH_CHECK(fp32_tensor.is_cpu() || fp32_tensor.is_cuda()); + TORCH_CHECK(fp32_tensor.ndimension() == 2); + + int M = fp32_tensor.size(0); + int N = fp32_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); + + const float *fp32_ptr = fp32_tensor.data_ptr(); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp32_tensor.numel(); + + if (fp32_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) { + uint8_t val0 = fp32_to_fp6(fp32_ptr[i]); + uint8_t val1 = fp32_to_fp6(fp32_ptr[i + 1]); + uint8_t val2 = fp32_to_fp6(fp32_ptr[i + 2]); + uint8_t val3 = fp32_to_fp6(fp32_ptr[i + 3]); + + int j = i / 4 * 3; + fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 + } + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + fp32_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); + } + + return fp6_tensor; +} + __global__ void fp6_unpacked_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) @@ -237,7 +358,7 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { } else { constexpr int block_size = 256; int grid_size = (n + block_size * 3 - 1) / (block_size * 3); - fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + fp6_packed_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); } return fp32_tensor; @@ -246,6 +367,7 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } @@ -253,6 +375,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 18e1600532..b7acdc5779 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ 
b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -10,6 +10,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); + m.def("fp32_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 3f3403f93e..dea92041f7 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -64,6 +64,11 @@ def _(fp16_tensor): return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp16_tensor.device, dtype=torch.uint8) +def fp32_to_fp6_packed(fp32_tensor: Tensor) -> Tensor: + *leading_dims, last_dim = fp32_tensor.shape + return torch.ops.torchao.fp32_to_fp6_packed.default(fp32_tensor.view(-1, last_dim)).view(*leading_dims, -1) + + def fp6_unpacked_to_fp32(fp6_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp6_unpacked_to_fp32.default(fp6_tensor) From e7b3135a8b17d71fad5e612aaf697d88e2fe397f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 13:17:13 +0800 Subject: [PATCH 19/80] update --- torchao/csrc/cuda/fp6_llm/fp6.cu | 48 +++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 08a2b2af71..2aa1d8ddf3 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -80,32 +80,46 @@ __device__ __host__ static uint8_t fp32_to_fp6_v2(const float a) { bits &= 0x7FFF'FFFFu; // clear sign bit uint32_t result; + constexpr uint32_t EXP_BIAS_DIFF = 127u - 3u; + + // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ - constexpr uint32_t exp_max = 7u + 127u - 3u; - if (bits >= 0x7F80'0000u) + // all exponent bits are 1s + if (bits >= (255u << 23u)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - if (bits >= ((exp_max) << 23u) | (0b111 << 20u)) + + // FP6 overflow when FP32 value is more than (or equal to) half way above max FP6 value + // max FP6 is E=111, M=11. add extra 1 to M to get half way above it. + if (bits >= (((EXP_BIAS_DIFF + 7u) << 23u) | (0b111 << 20u))) throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); #endif - if (bits >= ((1u + 15u - 3u) << 23u)) { // FP6 normal number - remainder = bits << 8u; - bits -= (0b01100u << 10u); // update exponent - result = sign | (bits >> 8u); - } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number - uint32_t exp = bits >> 10u; - uint32_t man = bits & 0x3FFu; - uint32_t shift = 0b01111u - 0b011u + 1u + 8u - exp; - man |= 0x400u; // set implicit 1 to mantissa - remainder = man << (16u - shift); + // min FP6 subnormal number is 2^(-2) * 2^(-2) + + if (bits >= ((EXP_BIAS_DIFF + 1u) << 23u)) { // FP6 normal number (E>=001) + remainder = bits << (1u + 8u + 2u); + bits -= (EXP_BIAS_DIFF << 23u); // update exponent + result = sign | (bits >> 21u); + } else if (bits > ((EXP_BIAS_DIFF - 2u) << 23u)) { // FP6 subnormal number + uint32_t exp = bits >> 23u; + uint32_t man = bits & 0x7F'FFFFu; + + // to make subnormal FP6 from normal FP16 + // step 1: add implicit 1 to mantissa + man |= 0x80'0000u; + + // step 2: shift mantissa right so that exponent value is equal to + // FP6 subnormal exponent value, which is -2 + uint32_t shift = 127u - 2u - exp; + remainder = man << (1u + 8u + 2u + shift); man >>= shift; - result = sign | man; - } else { // FP6 underflow - result = sign; + result = sign | (man >> 21u); // implicit E=000 + } else { // FP6 underflow + result = sign; // implicit E=000 and M=00 } // round to nearest even - if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { + if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && ((result & 1u) == 1u))) { result += 1; } From 8b3ac041237a3dc9104701f688a8a0ef3b5ef32d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 18:29:21 +0800 Subject: [PATCH 20/80] update. add comments --- torchao/csrc/cuda/fp6_llm/fp6.cu | 267 +++++++++++++++---------------- torchao/ops.py | 6 + 2 files changed, 135 insertions(+), 138 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 2aa1d8ddf3..464bfd7871 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -3,6 +3,31 @@ #include #include +// reference implementation. this doesn't have a lot of bit manipulation, so it's less error-prone +__device__ __host__ static uint8_t fp32_to_fp6_ref(float a) { +#ifndef __CUDA_ARCH__ + if (std::isnan(a) | std::isinf(a)) + throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); + if (std::abs(a) >= 30.0f) + throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); +#endif + + a *= 0x1p-124; // 2^(127-3) + uint32_t bits; + std::memcpy(&bits, &a, sizeof(a)); + + uint8_t sign = bits >> 31u << 5u; + uint8_t exp_and_man = (bits >> 21u) & 0x1Fu; + uint8_t result = sign | exp_and_man; + + // round to nearest even + uint32_t remainder = bits << 11u; + if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 1u))) { + result += 1; + } + + return result; +} // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { @@ -14,64 +39,53 @@ __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { bits &= 0x7FFFu; // clear sign bit uint16_t result; + constexpr uint16_t EXP_BIAS_DIFF = 15u - 3u; + + // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ - if (bits >= 0b11111'0000000000u) + // all exponent bits are 1s + if (bits >= (0x1Fu << 10u)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - if (bits >= 0b10011'1110000000u) + // max FP6 (28) + half of least significand (2) = 30 + if (bits >= (((EXP_BIAS_DIFF + 7u) << 10u) | (0x7u << 7u))) throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); #endif - if (bits >= 0b01101'0000000000u) { // FP6 normal number - remainder = bits << 8u; - bits -= (0b01100u << 10u); // update exponent - result = sign | (bits >> 8u); - } else if (bits >= 0b01010'0000000001u) { // FP6 subnormal number + // FP6 normal number (E>=001) + if (bits >= ((EXP_BIAS_DIFF + 1u) << 10u)) { + remainder = bits << (1u + 5u + 2u); + bits -= (EXP_BIAS_DIFF << 10u); // update exponent + result = sign | (bits >> (10u - 2u)); + } + // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) + else if (bits > ((EXP_BIAS_DIFF - 2u) << 10u)) { uint16_t exp = bits >> 10u; uint16_t man = bits & 0x3FFu; - uint16_t shift = 0b01111u - 0b011u + 1u + 8u - exp; - man |= 0x400u; // set implicit 1 to mantissa - remainder = man << (16u - shift); - man >>= shift; - result = sign | man; - } else { // FP6 underflow - result = sign; - } - - // round to nearest even - if ((remainder > 0x8000u) || ((remainder == 0x8000u) && ((result & 1u) == 1u))) { - result += 1; - } - - return result; -} -__device__ __host__ static uint8_t fp32_to_fp6_v1(float a) { -#ifndef __CUDA_ARCH__ - if (std::isnan(a) | std::isinf(a)) - throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - if (std::abs(a) >= 30.0f) // 2^4 * (1 + 0.5 + 0.25 + 0.125) - throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); -#endif - - a *= 0x1p-124; - uint32_t bits; - std::memcpy(&bits, &a, sizeof(a)); + // to make subnormal FP6 from normal FP16 + // step 1: add implicit 1 to mantissa + man |= 0x400u; - uint8_t sign = bits >> 31u << 5u; - uint8_t exp_and_man = (bits >> 21u) & 0x1Fu; - uint8_t result = sign | exp_and_man; + // step 2: shift mantissa right so that exponent value is equal to + // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) + uint16_t shift = EXP_BIAS_DIFF + 1u - exp; + remainder = man << (1u + 5u + 2u + shift); + result = sign | (man >> (shift + (10u - 2u))); // implicit E=000 + } + // FP6 underflow. 
E=000, M=00 + else { + result = sign; + } // round to nearest even - uint32_t remainder = bits << 11u; - if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && ((result & 1u) == 1u))) { + if ((remainder > 0x8000u) || ((remainder == 0x8000u) && (result & 1u))) { result += 1; } - return result; } // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -__device__ __host__ static uint8_t fp32_to_fp6_v2(const float a) { +__device__ __host__ static uint8_t fp32_to_fp6(const float a) { uint32_t bits; std::memcpy(&bits, &a, sizeof(a)); @@ -85,22 +99,21 @@ __device__ __host__ static uint8_t fp32_to_fp6_v2(const float a) { // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ // all exponent bits are 1s - if (bits >= (255u << 23u)) + if (bits >= (0xFFu << 23u)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - - // FP6 overflow when FP32 value is more than (or equal to) half way above max FP6 value - // max FP6 is E=111, M=11. add extra 1 to M to get half way above it. - if (bits >= (((EXP_BIAS_DIFF + 7u) << 23u) | (0b111 << 20u))) + // max FP6 (28) + half of least significand (2) = 30 + if (bits >= (((EXP_BIAS_DIFF + 7u) << 23u) | (0x7u << 20u))) throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); #endif - // min FP6 subnormal number is 2^(-2) * 2^(-2) - - if (bits >= ((EXP_BIAS_DIFF + 1u) << 23u)) { // FP6 normal number (E>=001) + // FP6 normal number (E>=001) + if (bits >= ((EXP_BIAS_DIFF + 1u) << 23u)) { remainder = bits << (1u + 8u + 2u); - bits -= (EXP_BIAS_DIFF << 23u); // update exponent - result = sign | (bits >> 21u); - } else if (bits > ((EXP_BIAS_DIFF - 2u) << 23u)) { // FP6 subnormal number + bits -= (EXP_BIAS_DIFF << 23u); // update exponent + result = sign | (bits >> (23u - 2u)); + } + // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) + else if (bits > ((EXP_BIAS_DIFF - 2u) << 23u)) { uint32_t exp = bits >> 23u; uint32_t man = bits & 0x7F'FFFFu; @@ -109,24 +122,24 @@ __device__ __host__ static uint8_t fp32_to_fp6_v2(const float a) { man |= 0x80'0000u; // step 2: shift mantissa right so that exponent value is equal to - // FP6 subnormal exponent value, which is -2 - uint32_t shift = 127u - 2u - exp; + // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) + uint32_t shift = EXP_BIAS_DIFF + 1u - exp; remainder = man << (1u + 8u + 2u + shift); - man >>= shift; - result = sign | (man >> 21u); // implicit E=000 - } else { // FP6 underflow - result = sign; // implicit E=000 and M=00 + result = sign | (man >> (shift + (23u - 2u))); // implicit E=000 + } + // FP6 underflow. E=000, M=00 + else { + result = sign; } // round to nearest even - if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && ((result & 1u) == 1u))) { + if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 1u))) { result += 1; } - return result; } -#define fp32_to_fp6 fp32_to_fp6_v1 +#define fp32_to_fp6 fp32_to_fp6 // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { @@ -134,19 +147,16 @@ __device__ __host__ static float fp6_to_fp32(const uint8_t a) { // this also handles subnormal numbers correctly. 
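Decoding goes the other way and needs no branching at all: drop the six bits into the top of an FP32 word so sign, exponent and mantissa land in their slots, then multiply by 2^(127-3) to undo the bias difference; FP6 subnormals simply become FP32 subnormals before the scale, so they come out right for free. A small host-only sketch of the same trick, assuming plain C++; decode_fp6 is an invented name for this example.

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Sketch of the decode idea used above: reposition the 6 bits, then rescale.
    static float decode_fp6(uint8_t a) {
        uint32_t bits = a;                              // SE EEMM in the low 6 bits
        uint32_t sign = (bits >> 5u) << 31u;            // FP32 sign bit
        uint32_t exp_and_man = (bits & 0x1Fu) << 21u;   // EEEMM lands in FP32 bits 25..21
        uint32_t out = sign | exp_and_man;
        float result;
        std::memcpy(&result, &out, sizeof(result));
        return result * 0x1p124f;                       // 2^(127-3) corrects the bias difference
    }

    int main() {
        // 0b001100 -> 1.0, 0b001101 -> 1.25, 0b011111 -> 28 (max),
        // 0b000001 -> 0.0625 (min subnormal), 0b101100 -> -1.0
        const uint8_t codes[] = {0b001100, 0b001101, 0b011111, 0b000001, 0b101100};
        for (uint8_t c : codes) std::printf("0x%02X -> %g\n", c, decode_fp6(c));
    }

Multiplying by 0x1p124 is exact here, because every FP6 value scaled down by 2^124 is still exactly representable in FP32, so no extra rounding sneaks in on the decode side.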
// FP6: SE EEMM // FP32: S000 00EE EMM0 0000 0000 0000 0000 0000 - uint32_t bits = a; + uint32_t bits = a; // bit extension uint32_t sign = bits >> 5u << 31u; uint32_t exp_and_man = (bits & 0x1Fu) << 21u; uint32_t result_bits = sign | exp_and_man; - // the result will be off by the difference in exponent bias - // FP6: Ebias = 3 - // FP32: Ebias = 127 - // correction = 2^(127-3) - // we can correct this by direct FP32 multiplication, which also handles subnormal numbers correctly. + // the result will be off by the difference in exponent bias (3 in FP6 and 127 in FP32) + // we can correct this by direct FP32 multiplication, which also handles subnormal numbers. float result; std::memcpy(&result, &result_bits, sizeof(result)); - return result * 0x1p124; + return result * 0x1p124; // 2^(127-3) } #include @@ -187,18 +197,21 @@ at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { return fp6_tensor; } +__device__ __host__ static void _fp16_to_fp6_packed(const __half *fp16_ptr, uint8_t *fp6_ptr) { + uint8_t val0 = fp16_to_fp6(fp16_ptr[0]); + uint8_t val1 = fp16_to_fp6(fp16_ptr[1]); + uint8_t val2 = fp16_to_fp6(fp16_ptr[2]); + uint8_t val3 = fp16_to_fp6(fp16_ptr[3]); + + fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 +} + __global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[idx]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[idx + 1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[idx + 2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[idx + 3]); - - fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } + if (idx < n) + _fp16_to_fp6_packed(fp16_ptr + idx, fp6_ptr + idx / 4 * 3); } at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { @@ -220,17 +233,8 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { if (fp16_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[i]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[i + 1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[i + 2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[i + 3]); - - int j = i / 4 * 3; - fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 - } + for (int i = 0; i < n; i += 4) + _fp16_to_fp6_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); @@ -240,18 +244,21 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { return fp6_tensor; } +__device__ __host__ static void _fp32_to_fp6_packed(const float *fp32_ptr, uint8_t *fp6_ptr) { + uint8_t val0 = fp32_to_fp6(fp32_ptr[0]); + uint8_t val1 = fp32_to_fp6(fp32_ptr[1]); + uint8_t val2 = fp32_to_fp6(fp32_ptr[2]); + uint8_t val3 = fp32_to_fp6(fp32_ptr[3]); + + fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 +} + __global__ void fp32_to_fp6_packed_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) { - uint8_t val0 = fp32_to_fp6(fp32_ptr[idx]); 
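The helper being factored out here implements the packed byte layout used by every *_packed op in this series: four 6-bit codes v0..v3 occupy three bytes, and the nibble diagrams in the comments ("0000 0011", "1111 2222", "2233 3333") name which value each pair of bits belongs to. A standalone round-trip sketch of that layout, assuming plain C++; pack4 and unpack4 are invented names for this example.

    #include <cstdint>
    #include <cstdio>

    // Pack four 6-bit codes into three bytes: 0000 0011 | 1111 2222 | 2233 3333
    static void pack4(const uint8_t v[4], uint8_t out[3]) {
        out[0] = (uint8_t)((v[0] << 2) | (v[1] >> 4));
        out[1] = (uint8_t)((v[1] << 4) | (v[2] >> 2));
        out[2] = (uint8_t)((v[2] << 6) | v[3]);
    }

    // Recover the four 6-bit codes from three packed bytes.
    static void unpack4(const uint8_t b[3], uint8_t v[4]) {
        v[0] = b[0] >> 2;
        v[1] = (uint8_t)(((b[0] & 0x3u) << 4) | (b[1] >> 4));
        v[2] = (uint8_t)(((b[1] & 0xFu) << 2) | (b[2] >> 6));
        v[3] = b[2] & 0x3Fu;
    }

    int main() {
        const uint8_t in[4] = {0b001100, 0b101101, 0b011111, 0b000001};
        uint8_t packed[3], out[4];
        pack4(in, packed);
        unpack4(packed, out);
        for (int i = 0; i < 4; i++)
            std::printf("v%d: 0x%02X -> 0x%02X\n", i, in[i], out[i]);  // round trip is lossless
    }

Packing 4 values into 3 bytes is also why the packed ops require the last dimension to be a multiple of 4 (or of 3 on the decode side), and why the kernels translate indices with idx / 4 * 3 and idx / 3 * 4.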
- uint8_t val1 = fp32_to_fp6(fp32_ptr[idx + 1]); - uint8_t val2 = fp32_to_fp6(fp32_ptr[idx + 2]); - uint8_t val3 = fp32_to_fp6(fp32_ptr[idx + 3]); - - fp6_ptr[idx / 4 * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[idx / 4 * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[idx / 4 * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } + if (idx < n) + _fp32_to_fp6_packed(fp32_ptr + idx, fp6_ptr + idx / 4 * 3); } at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { @@ -273,17 +280,8 @@ at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { if (fp32_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) { - uint8_t val0 = fp32_to_fp6(fp32_ptr[i]); - uint8_t val1 = fp32_to_fp6(fp32_ptr[i + 1]); - uint8_t val2 = fp32_to_fp6(fp32_ptr[i + 2]); - uint8_t val3 = fp32_to_fp6(fp32_ptr[i + 3]); - - int j = i / 4 * 3; - fp6_ptr[j] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[j + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[j + 2] = (val2 << 6) | (val3); // 2233 3333 - } + for (int i = 0; i < n; i += 4) + _fp32_to_fp6_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); @@ -324,19 +322,21 @@ at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { return fp32_tensor; } +__device__ __host__ static void _fp6_packed_to_fp32(const uint8_t *fp6_ptr, float *fp32_ptr) { + uint8_t bits0 = fp6_ptr[0]; // 0000 0011 + uint8_t bits1 = fp6_ptr[1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[2]; // 2233 3333 + + fp32_ptr[0] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[3] = fp6_to_fp32(bits2 & 0x3Fu); +} + __global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - if (idx < n) { - uint8_t bits0 = fp6_ptr[idx]; // 0000 0011 - uint8_t bits1 = fp6_ptr[idx + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[idx + 2]; // 2233 3333 - - int j = idx / 3 * 4; - fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); - } + if (idx < n) + _fp6_packed_to_fp32(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); } at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { @@ -358,17 +358,8 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { if (fp6_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 3) { - uint8_t bits0 = fp6_ptr[i]; // 0000 0011 - uint8_t bits1 = fp6_ptr[i + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[i + 2]; // 2233 3333 - - int j = i / 3 * 4; - fp32_ptr[j] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[j + 1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[j + 2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[j + 3] = fp6_to_fp32(bits2 & 0x3Fu); - } + for (int i = 0; i < n; i += 3) + _fp6_packed_to_fp32(fp6_ptr + i, fp32_ptr + i / 3 * 4); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 3 - 1) / (block_size * 3); @@ -380,18 +371,18 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_packed", 
&fp32_to_fp6_packed); + m.impl("torchao::fp16_to_fp6_packed", &_fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &_fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &_fp6_packed_to_fp32); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); + m.impl("torchao::fp16_to_fp6_packed", &_fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &_fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &_fp6_packed_to_fp32); } } diff --git a/torchao/ops.py b/torchao/ops.py index dea92041f7..b2959d2e6b 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -82,6 +82,12 @@ def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. """ + try: + from qtorch.quant import float_quantize + except ImportError as e: + raise RuntimeError("Please install qtorch to use this function") from e + + fp16_tensor = float_quantize(fp16_tensor.float(), 3, 2, rounding="nearest").half() return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) From 063588254cf0f6515f62a1f1d8894c0856b4ff3c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 20:25:43 +0800 Subject: [PATCH 21/80] some rename. add some tests --- test/test_ops.py | 39 ++++++++++++++++++++++++++++++++ torchao/csrc/cuda/fp6_llm/fp6.cu | 29 ++++++++++++------------ 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 31cb8fb970..e810198048 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -136,5 +136,44 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): assert relative_error.mean() < 1e-2 +class TestFp6(TestCase): + def _skip_cpu(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA not available. 
We don't compile for CPU-only build") + + @parameterized.expand( + [ + (0.0, 0b000000), # simple values + (1.0, 0b001100), # normal numbers + (1.25, 0b001101), + (28.0, 0b011111), # max + (0.1875, 0b00011), # subnormal number + (0.0625, 0b000001), # min + (29.0, 0b011111), # rounding + (26.0, 0b011110), # round to nearest even + (0.03, 0b000000), # underflow + ] + ) + def test_to_fp6_correctness(self, input, output): + self._skip_cpu() + configs = [ + (torch.half, torchao.ops.fp16_to_fp6_unpacked), + # (torch.float, torchao.ops.fp32_to_fp6_unpacked), + ] + for dtype, func in configs: + x = torch.tensor(input, dtype=dtype) + assert func(x).item() == output + assert func(-x).item() == (output | 0b100000) + assert func(x.cuda()).item() == output + assert func(-x.cuda()).item() == (output | 0b100000) + + @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) + def test_fp16_to_fp6_exception(self, input): + self._skip_cpu() + x = torch.tensor(input).half() + with self.assertRaises(Exception): + torchao.ops.fp16_to_fp6_unpacked(x) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 464bfd7871..1b85c3db48 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -85,7 +85,7 @@ __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { } // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -__device__ __host__ static uint8_t fp32_to_fp6(const float a) { +__device__ __host__ static uint8_t fp32_to_fp6_bits(const float a) { uint32_t bits; std::memcpy(&bits, &a, sizeof(a)); @@ -139,7 +139,8 @@ __device__ __host__ static uint8_t fp32_to_fp6(const float a) { return result; } -#define fp32_to_fp6 fp32_to_fp6 +// #define fp32_to_fp6 fp32_to_fp6_ref +#define fp32_to_fp6 fp32_to_fp6_bits // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { @@ -197,7 +198,7 @@ at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { return fp6_tensor; } -__device__ __host__ static void _fp16_to_fp6_packed(const __half *fp16_ptr, uint8_t *fp6_ptr) { +__device__ __host__ static void fp16_4_to_fp6_4_packed(const __half *fp16_ptr, uint8_t *fp6_ptr) { uint8_t val0 = fp16_to_fp6(fp16_ptr[0]); uint8_t val1 = fp16_to_fp6(fp16_ptr[1]); uint8_t val2 = fp16_to_fp6(fp16_ptr[2]); @@ -211,7 +212,7 @@ __device__ __host__ static void _fp16_to_fp6_packed(const __half *fp16_ptr, uint __global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; if (idx < n) - _fp16_to_fp6_packed(fp16_ptr + idx, fp6_ptr + idx / 4 * 3); + fp16_4_to_fp6_4_packed(fp16_ptr + idx, fp6_ptr + idx / 4 * 3); } at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { @@ -234,7 +235,7 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { if (fp16_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i += 4) - _fp16_to_fp6_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); + fp16_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); @@ -244,7 +245,7 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { return fp6_tensor; } -__device__ __host__ static void _fp32_to_fp6_packed(const float *fp32_ptr, uint8_t *fp6_ptr) { +__device__ __host__ static void fp32_4_to_fp6_4_packed(const float *fp32_ptr, uint8_t *fp6_ptr) { uint8_t val0 = 
fp32_to_fp6(fp32_ptr[0]); uint8_t val1 = fp32_to_fp6(fp32_ptr[1]); uint8_t val2 = fp32_to_fp6(fp32_ptr[2]); @@ -258,7 +259,7 @@ __device__ __host__ static void _fp32_to_fp6_packed(const float *fp32_ptr, uint8 __global__ void fp32_to_fp6_packed_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; if (idx < n) - _fp32_to_fp6_packed(fp32_ptr + idx, fp6_ptr + idx / 4 * 3); + fp32_4_to_fp6_4_packed(fp32_ptr + idx, fp6_ptr + idx / 4 * 3); } at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { @@ -281,7 +282,7 @@ at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { if (fp32_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i += 4) - _fp32_to_fp6_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); + fp32_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); @@ -371,18 +372,18 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &_fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_packed", &_fp32_to_fp6_packed); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &_fp6_packed_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &_fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_packed", &_fp32_to_fp6_packed); + m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &_fp6_packed_to_fp32); + m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } } From 4240692cb8702235de9c646a8ee705ecd4698c8f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 20:28:55 +0800 Subject: [PATCH 22/80] add fp32->fp6 unpacked --- test/test_ops.py | 2 +- torchao/csrc/cuda/fp6_llm/fp6.cu | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index e810198048..7690686220 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -158,7 +158,7 @@ def test_to_fp6_correctness(self, input, output): self._skip_cpu() configs = [ (torch.half, torchao.ops.fp16_to_fp6_unpacked), - # (torch.float, torchao.ops.fp32_to_fp6_unpacked), + (torch.float, torchao.ops.fp32_to_fp6_unpacked), ] for dtype, func in configs: x = torch.tensor(input, dtype=dtype) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 1b85c3db48..396e4b8ec2 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -245,6 +245,38 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { return fp6_tensor; } +__global__ void fp32_to_fp6_unpacked_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) + fp6_ptr[tid] = fp32_to_fp6(fp32_ptr[tid]); +} + +// this is useful for debugging +at::Tensor fp32_to_fp6_unpacked(at::Tensor fp32_tensor) { + TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); + 
TORCH_CHECK(fp32_tensor.is_contiguous()); + TORCH_CHECK(fp32_tensor.is_cpu() || fp32_tensor.is_cuda()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp32_tensor.sizes(), options); + + const float *fp32_ptr = fp32_tensor.data_ptr(); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp32_tensor.numel(); + + if (fp32_tensor.is_cpu()) { + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp6_ptr[i] = fp16_to_fp6(fp32_ptr[i]); + } else { + constexpr int block_size = 256; + int grid_size = (n + block_size - 1) / block_size; + fp32_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + } + + return fp6_tensor; +} + __device__ __host__ static void fp32_4_to_fp6_4_packed(const float *fp32_ptr, uint8_t *fp6_ptr) { uint8_t val0 = fp32_to_fp6(fp32_ptr[0]); uint8_t val1 = fp32_to_fp6(fp32_ptr[1]); @@ -373,6 +405,7 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_unpacked", &fp32_to_fp6_unpacked); m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); @@ -381,6 +414,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); + m.impl("torchao::fp32_to_fp6_unpacked", &fp32_to_fp6_unpacked); m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); From 7eb6fa8c002c06d6e033e10ef93d34e8b214cdec Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 18 May 2024 20:33:05 +0800 Subject: [PATCH 23/80] fix --- torchao/csrc/cuda/fp6_llm/fp6.cu | 2 +- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 ++ torchao/ops.py | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 396e4b8ec2..aca6c0bc2c 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -267,7 +267,7 @@ at::Tensor fp32_to_fp6_unpacked(at::Tensor fp32_tensor) { if (fp32_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i++) - fp6_ptr[i] = fp16_to_fp6(fp32_ptr[i]); + fp6_ptr[i] = fp32_to_fp6(fp32_ptr[i]); } else { constexpr int block_size = 256; int grid_size = (n + block_size - 1) / block_size; diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index b7acdc5779..b7f1584bff 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -8,8 +8,10 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); + m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); + m.def("fp32_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("fp32_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); diff --git 
a/torchao/ops.py b/torchao/ops.py index b2959d2e6b..677e931343 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -69,6 +69,10 @@ def fp32_to_fp6_packed(fp32_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp32_to_fp6_packed.default(fp32_tensor.view(-1, last_dim)).view(*leading_dims, -1) +def fp32_to_fp6_unpacked(fp32_tensor: Tensor) -> Tensor: + return torch.ops.torchao.fp32_to_fp6_unpacked.default(fp32_tensor) + + def fp6_unpacked_to_fp32(fp6_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp6_unpacked_to_fp32.default(fp6_tensor) From 26669b62de733dfc027b10eea613a442417fd6f2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 10:20:37 +0800 Subject: [PATCH 24/80] use template. add BF16 --- torchao/csrc/cuda/fp6_llm/fp6.cu | 116 ++++++++++++------------------- 1 file changed, 45 insertions(+), 71 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index aca6c0bc2c..1965bfb694 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -29,48 +30,57 @@ __device__ __host__ static uint8_t fp32_to_fp6_ref(float a) { return result; } +__device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } + // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { - uint16_t bits; - std::memcpy(&bits, &a, sizeof(a)); +template +__device__ __host__ static uint8_t bits_to_fp6(T bits) { + // sanity checks. will be removed in template specialization. +#ifndef __CUDA_ARCH__ + if (N_EXP_BITS < 3) + throw std::invalid_argument("Number of exponent bits must be >= 3."); + if (N_MAN_BITS < 3) + throw std::invalid_argument("Number of mantissa bits must be >= 3."); +#endif - uint16_t remainder = 0u; - uint16_t sign = bits >> 15u << 5u; - bits &= 0x7FFFu; // clear sign bit - uint16_t result; + constexpr uint32_t N_EXP_MAN_BITS = N_EXP_BITS + N_MAN_BITS; + T remainder = 0u; + T sign = bits >> N_EXP_MAN_BITS << 5u; + bits &= ones_mask(N_EXP_MAN_BITS); // clear sign bit + T result; - constexpr uint16_t EXP_BIAS_DIFF = 15u - 3u; + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP_BITS - 1u) - 3u; // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ // all exponent bits are 1s - if (bits >= (0x1Fu << 10u)) + if (bits >= (ones_mask(N_EXP_BITS) << N_MAN_BITS)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - // max FP6 (28) + half of least significand (2) = 30 - if (bits >= (((EXP_BIAS_DIFF + 7u) << 10u) | (0x7u << 7u))) + // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN_BITS >= 3) + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN_BITS) | (0x7u << (N_MAN_BITS - 3u)))) throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); #endif // FP6 normal number (E>=001) - if (bits >= ((EXP_BIAS_DIFF + 1u) << 10u)) { - remainder = bits << (1u + 5u + 2u); - bits -= (EXP_BIAS_DIFF << 10u); // update exponent - result = sign | (bits >> (10u - 2u)); + if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN_BITS)) { + remainder = bits << (1u + N_EXP_BITS + 2u); + bits -= (EXP_BIAS_DIFF << N_MAN_BITS); // update exponent + result = sign | (bits >> (N_MAN_BITS - 2u)); } // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) - else if (bits > ((EXP_BIAS_DIFF - 2u) << 10u)) { - uint16_t exp = bits >> 10u; - uint16_t man = bits & 0x3FFu; + else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN_BITS)) { + T exp = bits >> N_MAN_BITS; + T man = bits & ones_mask(N_MAN_BITS); // to make subnormal FP6 from normal FP16 // step 1: add implicit 1 to mantissa - man |= 0x400u; + man |= (1u << N_MAN_BITS); // step 2: shift mantissa right so that exponent value is equal to // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) - uint16_t shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + 5u + 2u + shift); - result = sign | (man >> (shift + (10u - 2u))); // implicit E=000 + T shift = EXP_BIAS_DIFF + 1u - exp; + remainder = man << (1u + N_EXP_BITS + 2u + shift); + result = sign | (man >> (shift + (N_MAN_BITS - 2u))); // implicit E=000 } // FP6 underflow. E=000, M=00 else { @@ -78,65 +88,29 @@ __device__ __host__ static uint8_t fp16_to_fp6(const __half a) { } // round to nearest even - if ((remainder > 0x8000u) || ((remainder == 0x8000u) && (result & 1u))) { + constexpr T HALF_REMAINDER = 1u << N_EXP_MAN_BITS; + if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) { result += 1; } return result; } -// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" __device__ __host__ static uint8_t fp32_to_fp6_bits(const float a) { uint32_t bits; std::memcpy(&bits, &a, sizeof(a)); + return bits_to_fp6(bits); +} - uint32_t remainder = 0u; - uint32_t sign = bits >> 31u << 5u; - bits &= 0x7FFF'FFFFu; // clear sign bit - uint32_t result; - - constexpr uint32_t EXP_BIAS_DIFF = 127u - 3u; - - // only checks for invalid values on CPU, since we can't throw exception in CUDA -#ifndef __CUDA_ARCH__ - // all exponent bits are 1s - if (bits >= (0xFFu << 23u)) - throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - // max FP6 (28) + half of least significand (2) = 30 - if (bits >= (((EXP_BIAS_DIFF + 7u) << 23u) | (0x7u << 20u))) - throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); -#endif - - // FP6 normal number (E>=001) - if (bits >= ((EXP_BIAS_DIFF + 1u) << 23u)) { - remainder = bits << (1u + 8u + 2u); - bits -= (EXP_BIAS_DIFF << 23u); // update exponent - result = sign | (bits >> (23u - 2u)); - } - // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) - else if (bits > ((EXP_BIAS_DIFF - 2u) << 23u)) { - uint32_t exp = bits >> 23u; - uint32_t man = bits & 0x7F'FFFFu; - - // to make subnormal FP6 from normal FP16 - // step 1: add implicit 1 to mantissa - man |= 0x80'0000u; - - // step 2: shift mantissa right so that exponent value is equal to - // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) - uint32_t shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + 8u + 2u + shift); - result = sign | (man >> (shift + (23u - 2u))); // implicit E=000 - } - // FP6 underflow. 
E=000, M=00 - else { - result = sign; - } +__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { + uint16_t bits; + std::memcpy(&bits, &a, sizeof(a)); + return bits_to_fp6(bits); +} - // round to nearest even - if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 1u))) { - result += 1; - } - return result; +__device__ __host__ static uint8_t bf16_to_fp6(const __nv_bfloat16 a) { + uint16_t bits; + std::memcpy(&bits, &a, sizeof(a)); + return bits_to_fp6(bits); } // #define fp32_to_fp6 fp32_to_fp6_ref From e09b61f4da3283adb266663e1856b61b3cd90f14 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 12:18:20 +0800 Subject: [PATCH 25/80] use template --- torchao/csrc/cuda/fp6_llm/fp6.cu | 202 +++++++++++++------------------ 1 file changed, 87 insertions(+), 115 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 1965bfb694..60cc1414c5 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -1,10 +1,9 @@ -#include -#include -#include -#include +#include +#include #include // reference implementation. this doesn't have a lot of bit manipulation, so it's less error-prone +// this is not exposed to PyTorch __device__ __host__ static uint8_t fp32_to_fp6_ref(float a) { #ifndef __CUDA_ARCH__ if (std::isnan(a) | std::isinf(a)) @@ -30,57 +29,71 @@ __device__ __host__ static uint8_t fp32_to_fp6_ref(float a) { return result; } +// we need to do this because C++17 does not allow using struct as template non-type parameter +// use the upper 16 bits for num exponent, lower 16 bits for num mantissa +static constexpr uint32_t encode_fp_spec(uint32_t n_exp_bits, uint32_t n_man_bits) { + return (n_exp_bits << 16u) | n_man_bits; +} + +static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); +static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); +static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); + +// NOTE: only works for len < 32 __device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -template +template __device__ __host__ static uint8_t bits_to_fp6(T bits) { + constexpr uint32_t N_EXP = FP_SPEC >> 16u; + constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); + constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; + // sanity checks. will be removed in template specialization. 
#ifndef __CUDA_ARCH__ - if (N_EXP_BITS < 3) + if (N_EXP < 3) throw std::invalid_argument("Number of exponent bits must be >= 3."); - if (N_MAN_BITS < 3) + if (N_MAN < 3) throw std::invalid_argument("Number of mantissa bits must be >= 3."); #endif - constexpr uint32_t N_EXP_MAN_BITS = N_EXP_BITS + N_MAN_BITS; T remainder = 0u; - T sign = bits >> N_EXP_MAN_BITS << 5u; - bits &= ones_mask(N_EXP_MAN_BITS); // clear sign bit + T sign = bits >> N_EXP_MAN << 5u; + bits &= ones_mask(N_EXP_MAN); // clear sign bit T result; - constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP_BITS - 1u) - 3u; + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ // all exponent bits are 1s - if (bits >= (ones_mask(N_EXP_BITS) << N_MAN_BITS)) + if (bits >= (ones_mask(N_EXP) << N_MAN)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN_BITS >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN_BITS) | (0x7u << (N_MAN_BITS - 3u)))) + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN- 3u)))) throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); #endif // FP6 normal number (E>=001) - if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN_BITS)) { - remainder = bits << (1u + N_EXP_BITS + 2u); - bits -= (EXP_BIAS_DIFF << N_MAN_BITS); // update exponent - result = sign | (bits >> (N_MAN_BITS - 2u)); + if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { + remainder = bits << (1u + N_EXP + 2u); + bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent + result = sign | (bits >> (N_MAN - 2u)); } // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) - else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN_BITS)) { - T exp = bits >> N_MAN_BITS; - T man = bits & ones_mask(N_MAN_BITS); + else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) { + T exp = bits >> N_MAN; + T man = bits & ones_mask(N_MAN); // to make subnormal FP6 from normal FP16 // step 1: add implicit 1 to mantissa - man |= (1u << N_MAN_BITS); + man |= (1u << N_MAN); // step 2: shift mantissa right so that exponent value is equal to // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) T shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + N_EXP_BITS + 2u + shift); - result = sign | (man >> (shift + (N_MAN_BITS - 2u))); // implicit E=000 + remainder = man << (1u + N_EXP + 2u + shift); + result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 } // FP6 underflow. 
E=000, M=00 else { @@ -88,33 +101,38 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { } // round to nearest even - constexpr T HALF_REMAINDER = 1u << N_EXP_MAN_BITS; + constexpr T HALF_REMAINDER = 1u << N_EXP_MAN; if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) { result += 1; } return result; } -__device__ __host__ static uint8_t fp32_to_fp6_bits(const float a) { - uint32_t bits; - std::memcpy(&bits, &a, sizeof(a)); - return bits_to_fp6(bits); -} +template +__device__ __host__ static void bits_4_to_fp6_4_packed(const T *bits_ptr, uint8_t *fp6_ptr) { + uint8_t val0 = bits_to_fp6(bits_ptr[0]); + uint8_t val1 = bits_to_fp6(bits_ptr[1]); + uint8_t val2 = bits_to_fp6(bits_ptr[2]); + uint8_t val3 = bits_to_fp6(bits_ptr[3]); -__device__ __host__ static uint8_t fp16_to_fp6(const __half a) { - uint16_t bits; - std::memcpy(&bits, &a, sizeof(a)); - return bits_to_fp6(bits); + fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 } -__device__ __host__ static uint8_t bf16_to_fp6(const __nv_bfloat16 a) { - uint16_t bits; - std::memcpy(&bits, &a, sizeof(a)); - return bits_to_fp6(bits); +template +__global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) + fp6_ptr[idx] = bits_to_fp6(bits_ptr[idx]); } -// #define fp32_to_fp6 fp32_to_fp6_ref -#define fp32_to_fp6 fp32_to_fp6_bits +template +__global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + if (idx < n) + bits_4_to_fp6_4_packed(bits_ptr + idx, fp6_ptr + idx / 4 * 3); +} // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { @@ -134,18 +152,29 @@ __device__ __host__ static float fp6_to_fp32(const uint8_t a) { return result * 0x1p124; // 2^(127-3) } +__device__ __host__ static void fp6_4_packed_to_fp32_4(const uint8_t *fp6_ptr, float *fp32_ptr) { + uint8_t bits0 = fp6_ptr[0]; // 0000 0011 + uint8_t bits1 = fp6_ptr[1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[2]; // 2233 3333 + + fp32_ptr[0] = fp6_to_fp32(bits0 >> 2); + fp32_ptr[1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp32_ptr[2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp32_ptr[3] = fp6_to_fp32(bits2 & 0x3Fu); +} + +__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; + if (idx < n) + fp6_4_packed_to_fp32_4(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); +} + #include #include #include namespace torchao { -__global__ void fp16_to_fp6_unpacked_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < n) - fp6_ptr[tid] = fp16_to_fp6(fp16_ptr[tid]); -} - // this is useful for debugging at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); @@ -155,40 +184,23 @@ at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + const uint16_t *fp16_ptr = reinterpret_cast(fp16_tensor.data_ptr()); uint8_t *fp6_ptr = 
fp6_tensor.data_ptr(); int n = fp16_tensor.numel(); if (fp16_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i++) - fp6_ptr[i] = fp16_to_fp6(fp16_ptr[i]); + fp6_ptr[i] = bits_to_fp6(fp16_ptr[i]); } else { constexpr int block_size = 256; int grid_size = (n + block_size - 1) / block_size; - fp16_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); } return fp6_tensor; } -__device__ __host__ static void fp16_4_to_fp6_4_packed(const __half *fp16_ptr, uint8_t *fp6_ptr) { - uint8_t val0 = fp16_to_fp6(fp16_ptr[0]); - uint8_t val1 = fp16_to_fp6(fp16_ptr[1]); - uint8_t val2 = fp16_to_fp6(fp16_ptr[2]); - uint8_t val3 = fp16_to_fp6(fp16_ptr[3]); - - fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 -} - -__global__ void fp16_to_fp6_packed_kernel(const __half *fp16_ptr, uint8_t *fp6_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) - fp16_4_to_fp6_4_packed(fp16_ptr + idx, fp6_ptr + idx / 4 * 3); -} - at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); TORCH_CHECK(fp16_tensor.is_contiguous()); @@ -202,29 +214,23 @@ at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - const __half *fp16_ptr = reinterpret_cast<__half*>(fp16_tensor.data_ptr()); + const uint16_t *fp16_ptr = reinterpret_cast(fp16_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp16_tensor.numel(); if (fp16_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i += 4) - fp16_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); + bits_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp16_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); } return fp6_tensor; } -__global__ void fp32_to_fp6_unpacked_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < n) - fp6_ptr[tid] = fp32_to_fp6(fp32_ptr[tid]); -} - // this is useful for debugging at::Tensor fp32_to_fp6_unpacked(at::Tensor fp32_tensor) { TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); @@ -234,40 +240,23 @@ at::Tensor fp32_to_fp6_unpacked(at::Tensor fp32_tensor) { at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); at::Tensor fp6_tensor = at::empty(fp32_tensor.sizes(), options); - const float *fp32_ptr = fp32_tensor.data_ptr(); + const uint32_t *fp32_ptr = reinterpret_cast(fp32_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp32_tensor.numel(); if (fp32_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i++) - fp6_ptr[i] = fp32_to_fp6(fp32_ptr[i]); + fp6_ptr[i] = bits_to_fp6(fp32_ptr[i]); } else { constexpr int block_size = 256; int grid_size = (n + block_size - 1) / block_size; - fp32_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); } return fp6_tensor; } -__device__ __host__ static void fp32_4_to_fp6_4_packed(const float *fp32_ptr, uint8_t *fp6_ptr) { - uint8_t val0 = fp32_to_fp6(fp32_ptr[0]); - 
uint8_t val1 = fp32_to_fp6(fp32_ptr[1]); - uint8_t val2 = fp32_to_fp6(fp32_ptr[2]); - uint8_t val3 = fp32_to_fp6(fp32_ptr[3]); - - fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 -} - -__global__ void fp32_to_fp6_packed_kernel(const float *fp32_ptr, uint8_t *fp6_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) - fp32_4_to_fp6_4_packed(fp32_ptr + idx, fp6_ptr + idx / 4 * 3); -} - at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); TORCH_CHECK(fp32_tensor.is_contiguous()); @@ -281,18 +270,18 @@ at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - const float *fp32_ptr = fp32_tensor.data_ptr(); + const uint32_t *fp32_ptr = reinterpret_cast(fp32_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); int n = fp32_tensor.numel(); if (fp32_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i += 4) - fp32_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); + bits_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp32_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); + bits_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); } return fp6_tensor; @@ -329,23 +318,6 @@ at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { return fp32_tensor; } -__device__ __host__ static void _fp6_packed_to_fp32(const uint8_t *fp6_ptr, float *fp32_ptr) { - uint8_t bits0 = fp6_ptr[0]; // 0000 0011 - uint8_t bits1 = fp6_ptr[1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[2]; // 2233 3333 - - fp32_ptr[0] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[3] = fp6_to_fp32(bits2 & 0x3Fu); -} - -__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - if (idx < n) - _fp6_packed_to_fp32(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); -} - at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); @@ -366,7 +338,7 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { if (fp6_tensor.is_cpu()) { #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i += 3) - _fp6_packed_to_fp32(fp6_ptr + i, fp32_ptr + i / 3 * 4); + fp6_4_packed_to_fp32_4(fp6_ptr + i, fp32_ptr + i / 3 * 4); } else { constexpr int block_size = 256; int grid_size = (n + block_size * 3 - 1) / (block_size * 3); From 887bac266d93ea56f9da6c6602dae6c781a81511 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 12:41:33 +0800 Subject: [PATCH 26/80] simplify API. 
add BF16 support via templates --- test/test_ops.py | 16 +-- torchao/csrc/cuda/fp6_llm/fp6.cu | 224 ++++++++++++++++++------------- torchao/csrc/fp6_llm/fp6_llm.cpp | 6 +- torchao/ops.py | 38 +++--- 4 files changed, 157 insertions(+), 127 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 7690686220..4fcf6dd77a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -156,23 +156,19 @@ def _skip_cpu(self): ) def test_to_fp6_correctness(self, input, output): self._skip_cpu() - configs = [ - (torch.half, torchao.ops.fp16_to_fp6_unpacked), - (torch.float, torchao.ops.fp32_to_fp6_unpacked), - ] - for dtype, func in configs: + for dtype in (torch.float32, torch.float16, torch.bfloat16): x = torch.tensor(input, dtype=dtype) - assert func(x).item() == output - assert func(-x).item() == (output | 0b100000) - assert func(x.cuda()).item() == output - assert func(-x.cuda()).item() == (output | 0b100000) + assert torchao.ops.to_fp6_unpacked(x).item() == output + assert torchao.ops.to_fp6_unpacked(-x).item() == (output | 0b100000) + assert torchao.ops.to_fp6_unpacked(x.cuda()).item() == output + assert torchao.ops.to_fp6_unpacked(-x.cuda()).item() == (output | 0b100000) @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) def test_fp16_to_fp6_exception(self, input): self._skip_cpu() x = torch.tensor(input).half() with self.assertRaises(Exception): - torchao.ops.fp16_to_fp6_unpacked(x) + torchao.ops.to_fp6_unpacked(x) if __name__ == "__main__": diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 60cc1414c5..a93516320d 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -120,20 +120,6 @@ __device__ __host__ static void bits_4_to_fp6_4_packed(const T *bits_ptr, uint8_ fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 } -template -__global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) - fp6_ptr[idx] = bits_to_fp6(bits_ptr[idx]); -} - -template -__global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) - bits_4_to_fp6_4_packed(bits_ptr + idx, fp6_ptr + idx / 4 * 3); -} - // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. 
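As a side note on the decode route: the comment above describes shifting the FP6 fields into place and then rescaling. A minimal host-only sketch of that idea follows; the helper name and the main() driver are ad-hoc additions for illustration, and the exact shifts are reconstructed from the encode direction used elsewhere in the series rather than copied from fp6_to_fp32() itself.

#include <cstdint>
#include <cstdio>
#include <cstring>

// FP6 here is 1 sign + 3 exponent (bias 3) + 2 mantissa bits, stored in the low 6 bits of a byte.
static float fp6_to_fp32_sketch(uint8_t a) {
    uint32_t bits = a;
    uint32_t sign = (bits >> 5u) << 31u;           // FP6 sign -> FP32 sign bit
    uint32_t exp_and_man = (bits & 0x1Fu) << 21u;  // E and M land in the low exponent / high mantissa bits of FP32
    uint32_t fp32_bits = sign | exp_and_man;
    float result;
    std::memcpy(&result, &fp32_bits, sizeof(result));
    return result * 0x1p124f;                      // 2^(127-3): undo the exponent-bias difference
}

int main() {
    std::printf("%g\n", fp6_to_fp32_sketch(0b011111));  // 28, max FP6
    std::printf("%g\n", fp6_to_fp32_sketch(0b000001));  // 0.0625, min FP6 subnormal
    std::printf("%g\n", fp6_to_fp32_sketch(0b100011));  // -0.1875
}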
@@ -176,112 +162,172 @@ __global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_pt namespace torchao { // this is useful for debugging -at::Tensor fp16_to_fp6_unpacked(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp16_tensor.sizes(), options); - - const uint16_t *fp16_ptr = reinterpret_cast(fp16_tensor.data_ptr()); +at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cpu()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - if (fp16_tensor.is_cpu()) { + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp6_ptr[i] = bits_to_fp6(fp32_ptr[i]); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + #pragma omp parallel for num_threads(4) for (int i = 0; i < n; i++) fp6_ptr[i] = bits_to_fp6(fp16_ptr[i]); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i++) + fp6_ptr[i] = bits_to_fp6(bf16_ptr[i]); + } else { - constexpr int block_size = 256; - int grid_size = (n + block_size - 1) / block_size; - bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); } return fp6_tensor; } -at::Tensor fp16_to_fp6_packed(at::Tensor fp16_tensor) { - TORCH_CHECK(fp16_tensor.dtype() == torch::kFloat16); - TORCH_CHECK(fp16_tensor.is_contiguous()); - TORCH_CHECK(fp16_tensor.is_cpu() || fp16_tensor.is_cuda()); - TORCH_CHECK(fp16_tensor.ndimension() == 2); - - int M = fp16_tensor.size(0); - int N = fp16_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); +template +__global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) + fp6_ptr[idx] = bits_to_fp6(bits_ptr[idx]); +} - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp16_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); +// this is useful for debugging +at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cuda()); - const uint16_t *fp16_ptr = reinterpret_cast(fp16_tensor.data_ptr()); + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp16_tensor.numel(); - if (fp16_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) - bits_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + constexpr int block_size = 256; + int 
grid_size = (n + block_size - 1) / block_size; + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + bits_to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); + } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); } return fp6_tensor; } -// this is useful for debugging -at::Tensor fp32_to_fp6_unpacked(at::Tensor fp32_tensor) { - TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); - TORCH_CHECK(fp32_tensor.is_contiguous()); - TORCH_CHECK(fp32_tensor.is_cpu() || fp32_tensor.is_cuda()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp32_tensor.sizes(), options); - - const uint32_t *fp32_ptr = reinterpret_cast(fp32_tensor.data_ptr()); +at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cpu()); + TORCH_CHECK(fp_tensor.ndimension() == 2); + + int M = fp_tensor.size(0); + int N = fp_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp32_tensor.numel(); - if (fp32_tensor.is_cpu()) { + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i++) - fp6_ptr[i] = bits_to_fp6(fp32_ptr[i]); + for (int i = 0; i < n; i += 4) + bits_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) + bits_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + + #pragma omp parallel for num_threads(4) + for (int i = 0; i < n; i += 4) + bits_4_to_fp6_4_packed(bf16_ptr + i, fp6_ptr + i / 4 * 3); + } else { - constexpr int block_size = 256; - int grid_size = (n + block_size - 1) / block_size; - bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); } return fp6_tensor; } -at::Tensor fp32_to_fp6_packed(at::Tensor fp32_tensor) { - TORCH_CHECK(fp32_tensor.dtype() == torch::kFloat32); - TORCH_CHECK(fp32_tensor.is_contiguous()); - TORCH_CHECK(fp32_tensor.is_cpu() || fp32_tensor.is_cuda()); - TORCH_CHECK(fp32_tensor.ndimension() == 2); +template +__global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // times 4 since each thread will handle 4 values + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; + if (idx < n) + 
bits_4_to_fp6_4_packed(bits_ptr + idx, fp6_ptr + idx / 4 * 3); +} + +at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cuda()); + TORCH_CHECK(fp_tensor.ndimension() == 2); - int M = fp32_tensor.size(0); - int N = fp32_tensor.size(1); + int M = fp_tensor.size(0); + int N = fp_tensor.size(1); TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp32_tensor.device()); + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - - const uint32_t *fp32_ptr = reinterpret_cast(fp32_tensor.data_ptr()); uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp32_tensor.numel(); - if (fp32_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) - for (int i = 0; i < n; i += 4) - bits_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); - } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + // times 4 since each thread will handle 4 values + constexpr int block_size = 256; + int grid_size = (n + (block_size * 4) - 1) / (block_size * 4); + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); bits_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + bits_to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); } return fp6_tensor; @@ -349,19 +395,15 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { } TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_unpacked", &fp32_to_fp6_unpacked); - m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); + m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cpu); + m.impl("torchao::to_fp6_packed", &to_fp6_packed_cpu); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked); - m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed); - m.impl("torchao::fp32_to_fp6_unpacked", &fp32_to_fp6_unpacked); - m.impl("torchao::fp32_to_fp6_packed", &fp32_to_fp6_packed); + m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cuda); + m.impl("torchao::to_fp6_packed", &to_fp6_packed_cuda); m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index b7f1584bff..8e4c20f08a 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -9,10 +9,8 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); - m.def("fp16_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); 
- m.def("fp16_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); - m.def("fp32_to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); - m.def("fp32_to_fp6_packed(Tensor fp16_tensor) -> Tensor"); + m.def("to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); + m.def("to_fp6_packed(Tensor fp16_tensor) -> Tensor"); m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 677e931343..89f2a910da 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -42,35 +42,29 @@ def _(fp6_weight): return torch.empty_like(fp6_weight) -def fp16_to_fp6_unpacked(fp16_tensor: Tensor) -> Tensor: - return torch.ops.torchao.fp16_to_fp6_unpacked.default(fp16_tensor) +def to_fp6_unpacked(fp_tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_fp6_unpacked.default(fp_tensor) -@torch.library.impl_abstract("torchao::fp16_to_fp6_unpacked") -def _(fp16_tensor): - return torch.empty_like(fp16_tensor, dtype=torch.uint8) +@torch.library.impl_abstract("torchao::to_fp6_unpacked") +def _(fp_tensor): + return torch.empty_like(fp_tensor, dtype=torch.uint8) -def fp16_to_fp6_packed(fp16_tensor: Tensor) -> Tensor: - *leading_dims, last_dim = fp16_tensor.shape - return torch.ops.torchao.fp16_to_fp6_packed.default(fp16_tensor.view(-1, last_dim)).view(*leading_dims, -1) +def to_fp6_packed(fp_tensor: Tensor) -> Tensor: + *leading_dims, last_dim = fp_tensor.shape + return torch.ops.torchao.to_fp6_packed.default(fp_tensor.view(-1, last_dim)).view(*leading_dims, -1) -@torch.library.impl_abstract("torchao::fp16_to_fp6_packed") -def _(fp16_tensor): - torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") - *leading_dims, last_dim = fp16_tensor.shape +@torch.library.impl_abstract("torchao::to_fp6_packed") +def _(fp_tensor): + torch._check( + fp_tensor.dtype in (torch.float32, torch.float16, torch.bfloat16), + lambda: f"weight must be FP32, FP16, or BF16, got {fp_tensor.dtype}", + ) + *leading_dims, last_dim = fp_tensor.shape torch._check(last_dim % 4 == 0, lambda: f"last dimension must be a multiple of 4, got {last_dim}") - return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp16_tensor.device, dtype=torch.uint8) - - -def fp32_to_fp6_packed(fp32_tensor: Tensor) -> Tensor: - *leading_dims, last_dim = fp32_tensor.shape - return torch.ops.torchao.fp32_to_fp6_packed.default(fp32_tensor.view(-1, last_dim)).view(*leading_dims, -1) - - -def fp32_to_fp6_unpacked(fp32_tensor: Tensor) -> Tensor: - return torch.ops.torchao.fp32_to_fp6_unpacked.default(fp32_tensor) + return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp_tensor.device, dtype=torch.uint8) def fp6_unpacked_to_fp32(fp6_tensor: Tensor) -> Tensor: From 4b5c99f98584dfc2155a3bc354188ff3216eacee Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 12:43:29 +0800 Subject: [PATCH 27/80] typo --- torchao/csrc/cuda/fp6_llm/fp6.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index a93516320d..48b3c1807f 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -69,8 +69,8 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { // all exponent bits are 1s if (bits >= (ones_mask(N_EXP) << N_MAN)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN_BITS >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << 
N_MAN) | (0x7u << (N_MAN- 3u)))) + // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); #endif From 39f9dce9c81440c068dd9722d98db4aceee8e8ab Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 15:49:12 +0800 Subject: [PATCH 28/80] enable OpenMP via compile flags --- setup.py | 2 ++ torchao/csrc/cuda/fp6_llm/fp6.cu | 16 ++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 5d1f32da2b..0cac46ab96 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ def get_extensions(): extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] + extra_link_args.append("-fopenmp") extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", @@ -54,6 +55,7 @@ def get_extensions(): ], "nvcc": [ "-O3" if not debug_mode else "-O0", + "-Xcompiler", "-fopenmp", ] } if debug_mode: diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 48b3c1807f..7d7d0a17b2 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -176,21 +176,21 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i++) fp6_ptr[i] = bits_to_fp6(fp32_ptr[i]); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i++) fp6_ptr[i] = bits_to_fp6(fp16_ptr[i]); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i++) fp6_ptr[i] = bits_to_fp6(bf16_ptr[i]); @@ -261,21 +261,21 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i += 4) bits_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i += 4) bits_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i += 4) bits_4_to_fp6_4_packed(bf16_ptr + i, fp6_ptr + i / 4 * 3); @@ -352,7 +352,7 @@ at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { int n = fp6_tensor.numel(); if (fp6_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i++) fp32_ptr[i] = fp6_to_fp32(fp6_ptr[i]); } else { @@ -382,7 +382,7 @@ at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { int n = fp6_tensor.numel(); if (fp6_tensor.is_cpu()) { - #pragma omp parallel for num_threads(4) + #pragma omp parallel for for (int i = 0; i < n; i += 3) fp6_4_packed_to_fp32_4(fp6_ptr + i, fp32_ptr + i / 3 * 4); } else { From b681ae147130f9c3850c5d2dfa3b9f12cb98de94 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: 
Sun, 19 May 2024 21:21:32 +0800 Subject: [PATCH 29/80] add memory access optimized version (though it is not faster..) --- torchao/csrc/cuda/fp6_llm/fp6.cu | 53 +++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 7d7d0a17b2..fa0aa72eef 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -286,12 +286,57 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { return fp6_tensor; } -template +template __global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // naive version // times 4 since each thread will handle 4 values const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; if (idx < n) bits_4_to_fp6_4_packed(bits_ptr + idx, fp6_ptr + idx / 4 * 3); + return; + + // more optimized version. coalesced memory write (speedup is minimal) + // const int tid = threadIdx.x; + // const int input_offset = (blockIdx.x * blockDim.x) * 4; + // const int output_offset = (blockIdx.x * blockDim.x) * 3; + + // bits_ptr += input_offset; + // fp6_ptr += output_offset; + + // __shared__ uint8_t shmem[BLOCK_SIZE * 3]; + + // if (input_offset + tid * 4 < n) { + // uint8_t val0, val1, val2, val3; + // if (std::is_same_v) { + // uint4 values = reinterpret_cast(bits_ptr)[tid * 4]; + // val0 = bits_to_fp6(values.x); + // val1 = bits_to_fp6(values.y); + // val2 = bits_to_fp6(values.z); + // val3 = bits_to_fp6(values.w); + // } else if (std::is_same_v) { + // ushort4 values = reinterpret_cast(bits_ptr)[tid * 4]; + // val0 = bits_to_fp6(values.x); + // val1 = bits_to_fp6(values.y); + // val2 = bits_to_fp6(values.z); + // val3 = bits_to_fp6(values.w); + // } else { + // val0 = bits_to_fp6(bits_ptr[tid * 4]); + // val1 = bits_to_fp6(bits_ptr[tid * 4 + 1]); + // val2 = bits_to_fp6(bits_ptr[tid * 4 + 2]); + // val3 = bits_to_fp6(bits_ptr[tid * 4 + 3]); + // } + // shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + // shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + // shmem[tid * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + // } + // __syncthreads(); + + // // TODO: write in larger word size + // for (int i = 0; i < 3; i++) { + // if (output_offset + BLOCK_SIZE * i + tid < n / 4 * 3) { + // fp6_ptr[BLOCK_SIZE * i + tid] = shmem[BLOCK_SIZE * i + tid]; + // } + // } } at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { @@ -316,15 +361,15 @@ at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); + bits_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); + bits_to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); From 7c5fcd30a234a67c87ceb9e4c9abd101f318555d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 22:48:19 +0800 Subject: [PATCH 30/80] use fp32 mul impl for CUDA --- torchao/csrc/cuda/fp6_llm/fp6.cu | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git 
a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index fa0aa72eef..10f2463286 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -1,10 +1,11 @@ +#include +#include #include #include #include // reference implementation. this doesn't have a lot of bit manipulation, so it's less error-prone -// this is not exposed to PyTorch -__device__ __host__ static uint8_t fp32_to_fp6_ref(float a) { +__device__ __host__ static uint8_t fp32_to_fp6_value(float a) { #ifndef __CUDA_ARCH__ if (std::isnan(a) | std::isinf(a)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); @@ -45,6 +46,27 @@ __device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return ( // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" template __device__ __host__ static uint8_t bits_to_fp6(T bits) { + // on CUDA, dtype conversion kernels are memory-bound. thus, using fp32_to_fp6_value() + // does not impact the speed. fp32_to_fp6_value() also won't cause warp divergence. + // on CPU, for FP32->FP6, bit manipulation is 20% faster than fp32_to_fp6_value(). +#ifdef __CUDA_ARCH__ + if (std::is_same_v && (FP_SPEC == FP32_SPEC)) { + float a; + std::memcpy(&a, &bits, sizeof(bits)); + return fp32_to_fp6_value(a); + } + if (std::is_same_v && (FP_SPEC == FP16_SPEC)) { + __half a; + std::memcpy(&a, &bits, sizeof(bits)); + return fp32_to_fp6_value(__half2float(a)); + } + if (std::is_same_v && (FP_SPEC == BF16_SPEC)) { + __nv_bfloat16 a; + std::memcpy(&a, &bits, sizeof(bits)); + return fp32_to_fp6_value(__bfloat162float(a)); + } +#endif + constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; @@ -92,7 +114,7 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { // step 2: shift mantissa right so that exponent value is equal to // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) T shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + N_EXP + 2u + shift); + remainder = man << (1u + N_EXP + 2u + shift); // THIS IS WRONG, need to change result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 } // FP6 underflow. 
E=000, M=00 From 82e4e60eb0903ce741766c45a8f3968d556853f4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 22:51:19 +0800 Subject: [PATCH 31/80] add test case --- test/test_ops.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4fcf6dd77a..61e56c9c41 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -149,8 +149,10 @@ def _skip_cpu(self): (28.0, 0b011111), # max (0.1875, 0b00011), # subnormal number (0.0625, 0b000001), # min - (29.0, 0b011111), # rounding - (26.0, 0b011110), # round to nearest even + (29.0, 0b011111), # normal round down + (26.0, 0b011110), # normal round to nearest even + (0.1251, 0b000010), # subnormal round down + (0.03128, 0b000001), # subnormal round up (0.03, 0b000000), # underflow ] ) @@ -163,12 +165,12 @@ def test_to_fp6_correctness(self, input, output): assert torchao.ops.to_fp6_unpacked(x.cuda()).item() == output assert torchao.ops.to_fp6_unpacked(-x.cuda()).item() == (output | 0b100000) - @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) - def test_fp16_to_fp6_exception(self, input): - self._skip_cpu() - x = torch.tensor(input).half() - with self.assertRaises(Exception): - torchao.ops.to_fp6_unpacked(x) + # @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) + # def test_fp16_to_fp6_exception(self, input): + # self._skip_cpu() + # x = torch.tensor(input).half() + # with self.assertRaises(Exception): + # torchao.ops.to_fp6_unpacked(x) if __name__ == "__main__": From fb18c736f0e2d698f821d347dcc1cad46fbe666e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 19 May 2024 23:13:40 +0800 Subject: [PATCH 32/80] typo. remove OpenMP since we cannot throw exception --- setup.py | 4 ++-- test/test_ops.py | 14 +++++++------- torchao/csrc/cuda/fp6_llm/fp6.cu | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 0cac46ab96..cbf24e98eb 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_extensions(): extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] - extra_link_args.append("-fopenmp") + # extra_link_args.append("-fopenmp") extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", @@ -55,7 +55,7 @@ def get_extensions(): ], "nvcc": [ "-O3" if not debug_mode else "-O0", - "-Xcompiler", "-fopenmp", + # "-Xcompiler", "-fopenmp", ] } if debug_mode: diff --git a/test/test_ops.py b/test/test_ops.py index 61e56c9c41..05beaf89d8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -152,7 +152,7 @@ def _skip_cpu(self): (29.0, 0b011111), # normal round down (26.0, 0b011110), # normal round to nearest even (0.1251, 0b000010), # subnormal round down - (0.03128, 0b000001), # subnormal round up + (0.0313, 0b000001), # subnormal round up (0.03, 0b000000), # underflow ] ) @@ -165,12 +165,12 @@ def test_to_fp6_correctness(self, input, output): assert torchao.ops.to_fp6_unpacked(x.cuda()).item() == output assert torchao.ops.to_fp6_unpacked(-x.cuda()).item() == (output | 0b100000) - # @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) - # def test_fp16_to_fp6_exception(self, input): - # self._skip_cpu() - # x = torch.tensor(input).half() - # with self.assertRaises(Exception): - # torchao.ops.to_fp6_unpacked(x) + @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) + def test_to_fp6_exception(self, input): + self._skip_cpu() + x = torch.tensor(input) + with self.assertRaises(Exception): + torchao.ops.to_fp6_unpacked(x) if __name__ == "__main__": diff 
--git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 10f2463286..6f82ed1c6f 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -55,7 +55,7 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { std::memcpy(&a, &bits, sizeof(bits)); return fp32_to_fp6_value(a); } - if (std::is_same_v && (FP_SPEC == FP16_SPEC)) { + if (std::is_same_v && (FP_SPEC == FP16_SPEC)) { __half a; std::memcpy(&a, &bits, sizeof(bits)); return fp32_to_fp6_value(__half2float(a)); From a3c5e3671695d32a8cca58ede38c4e6288ffe24f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 07:43:37 +0800 Subject: [PATCH 33/80] fix rounding for subnormal --- test/test_ops.py | 2 +- torchao/csrc/cuda/fp6_llm/fp6.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 05beaf89d8..7124cd45e2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -152,7 +152,7 @@ def _skip_cpu(self): (29.0, 0b011111), # normal round down (26.0, 0b011110), # normal round to nearest even (0.1251, 0b000010), # subnormal round down - (0.0313, 0b000001), # subnormal round up + (0.0314, 0b000001), # subnormal round up (0.03, 0b000000), # underflow ] ) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 6f82ed1c6f..ef8619e1e1 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -114,7 +114,7 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { // step 2: shift mantissa right so that exponent value is equal to // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) T shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + N_EXP + 2u + shift); // THIS IS WRONG, need to change + remainder = man << (1u + N_EXP + 2u - shift); result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 } // FP6 underflow. E=000, M=00 From 27781e5f42927a71c931e43a2ccf4533edd6a430 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 01:24:23 +0000 Subject: [PATCH 34/80] add to_fp6_value() --- torchao/csrc/cuda/fp6_llm/fp6.cu | 90 +++++++++++++++----------------- 1 file changed, 42 insertions(+), 48 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index ef8619e1e1..e39814feb2 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -1,11 +1,30 @@ +#include +#include +#include #include #include #include #include #include -// reference implementation. this doesn't have a lot of bit manipulation, so it's less error-prone -__device__ __host__ static uint8_t fp32_to_fp6_value(float a) { +// This implementation doesn't have a lot of bit manipulation, so it's less error-prone. +// On CPU, for FP32->FP6, bit manipulation (to_fp6_bits()) is 20% faster than this. +// On CUDA, dtype conversion kernels are memory-bound. Thus, using to_fp6_value() or +// to_fp6_bits() does not matter much. However, to_fp6_bits() has a lot of branching +// based on input value, thus it will cause warp divergence. 
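For reference, the float-arithmetic path that this comment describes can be written as a small host-only sketch for the FP32 case. The function name and the checks in main() are ad hoc; the scaling constant and field extraction follow the fragments of this routine visible in the patches, while the rounding step is reconstructed from the round-to-nearest-even pattern used by to_fp6_bits(). It assumes a finite input inside FP6 range (the real code raises errors for inf/NaN/overflow).

#include <cstdint>
#include <cstdio>
#include <cstring>

// FP32 -> FP6: rescale so the FP6 exponent (bias 3) lines up with the FP32 exponent (bias 127),
// keep 3 exponent + 2 mantissa bits, and round the dropped bits to nearest even.
static uint8_t fp32_to_fp6_sketch(float a) {
    a *= 0x1p-124f;  // 2^-(127-3)
    uint32_t bits;
    std::memcpy(&bits, &a, sizeof(a));

    uint8_t sign = (bits >> 31u) << 5u;
    uint8_t result = sign | ((bits >> 21u) & 0x1Fu);  // 3 exponent + 2 mantissa bits

    uint32_t remainder = bits << 11u;  // the 21 dropped mantissa bits, left-aligned
    if ((remainder > 0x80000000u) || ((remainder == 0x80000000u) && (result & 1u)))
        result += 1;  // round to nearest, ties to even
    return result;
}

int main() {
    // expected values taken from the test cases in this series
    std::printf("%d\n", fp32_to_fp6_sketch(26.0f));   // 0b011110: tie, rounds to even (24.0)
    std::printf("%d\n", fp32_to_fp6_sketch(29.0f));   // 0b011111: rounds down to 28.0
    std::printf("%d\n", fp32_to_fp6_sketch(0.0314f)); // 0b000001: subnormal, rounds up to 0.0625
}

The templated to_fp6_value() added below generalizes this to __half and __nv_bfloat16 inputs.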
+template +__device__ __host__ static uint8_t to_fp6_value(T a) { + float fp32_value; + + if (std::is_same_v) + fp32_value = a; + else if (std::is_same_v) + fp32_value = __half2float(a); + else if (std::is_same_v) + fp32_value = __bfloat162float(a); + else if (std::is_same_v || std::is_same_v) + fp32_value = static_cast(a); + #ifndef __CUDA_ARCH__ if (std::isnan(a) | std::isinf(a)) throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); @@ -45,28 +64,7 @@ __device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return ( // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" template -__device__ __host__ static uint8_t bits_to_fp6(T bits) { - // on CUDA, dtype conversion kernels are memory-bound. thus, using fp32_to_fp6_value() - // does not impact the speed. fp32_to_fp6_value() also won't cause warp divergence. - // on CPU, for FP32->FP6, bit manipulation is 20% faster than fp32_to_fp6_value(). -#ifdef __CUDA_ARCH__ - if (std::is_same_v && (FP_SPEC == FP32_SPEC)) { - float a; - std::memcpy(&a, &bits, sizeof(bits)); - return fp32_to_fp6_value(a); - } - if (std::is_same_v && (FP_SPEC == FP16_SPEC)) { - __half a; - std::memcpy(&a, &bits, sizeof(bits)); - return fp32_to_fp6_value(__half2float(a)); - } - if (std::is_same_v && (FP_SPEC == BF16_SPEC)) { - __nv_bfloat16 a; - std::memcpy(&a, &bits, sizeof(bits)); - return fp32_to_fp6_value(__bfloat162float(a)); - } -#endif - +__device__ __host__ static uint8_t to_fp6_bits(T bits) { constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; @@ -132,10 +130,10 @@ __device__ __host__ static uint8_t bits_to_fp6(T bits) { template __device__ __host__ static void bits_4_to_fp6_4_packed(const T *bits_ptr, uint8_t *fp6_ptr) { - uint8_t val0 = bits_to_fp6(bits_ptr[0]); - uint8_t val1 = bits_to_fp6(bits_ptr[1]); - uint8_t val2 = bits_to_fp6(bits_ptr[2]); - uint8_t val3 = bits_to_fp6(bits_ptr[3]); + uint8_t val0 = to_fp6_bits(bits_ptr[0]); + uint8_t val1 = to_fp6_bits(bits_ptr[1]); + uint8_t val2 = to_fp6_bits(bits_ptr[2]); + uint8_t val3 = to_fp6_bits(bits_ptr[3]); fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 @@ -177,10 +175,6 @@ __global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_pt fp6_4_packed_to_fp32_4(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); } -#include -#include -#include - namespace torchao { // this is useful for debugging @@ -200,21 +194,21 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { #pragma omp parallel for for (int i = 0; i < n; i++) - fp6_ptr[i] = bits_to_fp6(fp32_ptr[i]); + fp6_ptr[i] = to_fp6_bits(fp32_ptr[i]); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); #pragma omp parallel for for (int i = 0; i < n; i++) - fp6_ptr[i] = bits_to_fp6(fp16_ptr[i]); + fp6_ptr[i] = to_fp6_bits(fp16_ptr[i]); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); #pragma omp parallel for for (int i = 0; i < n; i++) - fp6_ptr[i] = bits_to_fp6(bf16_ptr[i]); + fp6_ptr[i] = to_fp6_bits(bf16_ptr[i]); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -227,7 +221,7 @@ template __global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) - fp6_ptr[idx] = 
bits_to_fp6(bits_ptr[idx]); + fp6_ptr[idx] = to_fp6_bits(bits_ptr[idx]); } // this is useful for debugging @@ -331,21 +325,21 @@ __global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, i // uint8_t val0, val1, val2, val3; // if (std::is_same_v) { // uint4 values = reinterpret_cast(bits_ptr)[tid * 4]; - // val0 = bits_to_fp6(values.x); - // val1 = bits_to_fp6(values.y); - // val2 = bits_to_fp6(values.z); - // val3 = bits_to_fp6(values.w); + // val0 = to_fp6_bits(values.x); + // val1 = to_fp6_bits(values.y); + // val2 = to_fp6_bits(values.z); + // val3 = to_fp6_bits(values.w); // } else if (std::is_same_v) { // ushort4 values = reinterpret_cast(bits_ptr)[tid * 4]; - // val0 = bits_to_fp6(values.x); - // val1 = bits_to_fp6(values.y); - // val2 = bits_to_fp6(values.z); - // val3 = bits_to_fp6(values.w); + // val0 = to_fp6_bits(values.x); + // val1 = to_fp6_bits(values.y); + // val2 = to_fp6_bits(values.z); + // val3 = to_fp6_bits(values.w); // } else { - // val0 = bits_to_fp6(bits_ptr[tid * 4]); - // val1 = bits_to_fp6(bits_ptr[tid * 4 + 1]); - // val2 = bits_to_fp6(bits_ptr[tid * 4 + 2]); - // val3 = bits_to_fp6(bits_ptr[tid * 4 + 3]); + // val0 = to_fp6_bits(bits_ptr[tid * 4]); + // val1 = to_fp6_bits(bits_ptr[tid * 4 + 1]); + // val2 = to_fp6_bits(bits_ptr[tid * 4 + 2]); + // val3 = to_fp6_bits(bits_ptr[tid * 4 + 3]); // } // shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 // shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 From 7d9dd34a08dad5efbc72b1e144cc32a9e88c9c89 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 01:59:54 +0000 Subject: [PATCH 35/80] simplify to_fp6_unpacked_cuda --- torchao/csrc/cuda/fp6_llm/fp6.cu | 33 ++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index e39814feb2..53d6acbf23 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -6,6 +6,7 @@ #include #include #include +#include // This implementation doesn't have a lot of bit manipulation, so it's less error-prone. // On CPU, for FP32->FP6, bit manipulation (to_fp6_bits()) is 20% faster than this. @@ -16,14 +17,18 @@ template __device__ __host__ static uint8_t to_fp6_value(T a) { float fp32_value; - if (std::is_same_v) + // need to use if constexpr so that the branches are pruned at compile-time. + // without it, expression in each branch must be valid regardless of template type T. + if constexpr (std::is_same_v) fp32_value = a; - else if (std::is_same_v) + else if constexpr (std::is_same_v) fp32_value = __half2float(a); - else if (std::is_same_v) + else if constexpr (std::is_same_v) fp32_value = __bfloat162float(a); - else if (std::is_same_v || std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) fp32_value = static_cast(a); + else + assert(false); #ifndef __CUDA_ARCH__ if (std::isnan(a) | std::isinf(a)) @@ -32,9 +37,9 @@ __device__ __host__ static uint8_t to_fp6_value(T a) { throw std::invalid_argument("FP6 overflow. 
FP6 cannot represent +/-inf."); #endif - a *= 0x1p-124; // 2^(127-3) + fp32_value *= 0x1p-124; // 2^(127-3) uint32_t bits; - std::memcpy(&bits, &a, sizeof(a)); + std::memcpy(&bits, &fp32_value, sizeof(fp32_value)); uint8_t sign = bits >> 31u << 5u; uint8_t exp_and_man = (bits >> 21u) & 0x1Fu; @@ -217,11 +222,11 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { return fp6_tensor; } -template +template __global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) - fp6_ptr[idx] = to_fp6_bits(bits_ptr[idx]); + fp6_ptr[idx] = to_fp6_value(bits_ptr[idx]); } // this is useful for debugging @@ -240,16 +245,16 @@ at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { int grid_size = (n + block_size - 1) / block_size; if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + const float *fp32_ptr = fp_tensor.data_ptr(); + bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + const at::Half *fp16_ptr = fp_tensor.data_ptr(); + bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); + const at::BFloat16 *bf16_ptr = fp_tensor.data_ptr(); + bits_to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); From 7fb8c8ba250f1888f74fd2397b5a814b86aed04c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 03:54:55 +0000 Subject: [PATCH 36/80] simplify to_fp6_packed_cuda --- torchao/csrc/cuda/fp6_llm/fp6.cu | 147 +++++++++++++++++-------------- 1 file changed, 81 insertions(+), 66 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 53d6acbf23..00f68e0352 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -1,12 +1,17 @@ -#include #include +#include #include + #include #include + #include #include #include -#include + + +// need to do this trick so that static_assert(false) only evaluates at template instantiation. +template constexpr std::false_type always_false{}; // This implementation doesn't have a lot of bit manipulation, so it's less error-prone. // On CPU, for FP32->FP6, bit manipulation (to_fp6_bits()) is 20% faster than this. @@ -25,10 +30,10 @@ __device__ __host__ static uint8_t to_fp6_value(T a) { fp32_value = __half2float(a); else if constexpr (std::is_same_v) fp32_value = __bfloat162float(a); - else if constexpr (std::is_same_v || std::is_same_v) + else if constexpr (std::is_same_v || std::is_same_v) fp32_value = static_cast(a); else - assert(false); + static_assert(always_false, "Only float, __half, __nv_bfloat16, c10::Half, and c10::BFloat16 are suppored"); #ifndef __CUDA_ARCH__ if (std::isnan(a) | std::isinf(a)) @@ -223,10 +228,13 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { } template -__global__ void bits_to_fp6_unpacked_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { +__global__ void to_fp6_unpacked_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // NOTE: we are writing 32 uint8 (32 bytes) to global memory. 
vector load can be used + // to improve memory throughput. using uchar4, we can issue 128-byte global memory write. if (idx < n) - fp6_ptr[idx] = to_fp6_value(bits_ptr[idx]); + fp6_ptr[idx] = to_fp6_value(fp_ptr[idx]); } // this is useful for debugging @@ -246,15 +254,15 @@ at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const float *fp32_ptr = fp_tensor.data_ptr(); - bits_to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); + to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const at::Half *fp16_ptr = fp_tensor.data_ptr(); - bits_to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); + to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const at::BFloat16 *bf16_ptr = fp_tensor.data_ptr(); - bits_to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); + to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -307,57 +315,64 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { return fp6_tensor; } -template -__global__ void bits_to_fp6_packed_kernel(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - // naive version - // times 4 since each thread will handle 4 values - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - if (idx < n) - bits_4_to_fp6_4_packed(bits_ptr + idx, fp6_ptr + idx / 4 * 3); - return; - - // more optimized version. coalesced memory write (speedup is minimal) - // const int tid = threadIdx.x; - // const int input_offset = (blockIdx.x * blockDim.x) * 4; - // const int output_offset = (blockIdx.x * blockDim.x) * 3; - - // bits_ptr += input_offset; - // fp6_ptr += output_offset; - - // __shared__ uint8_t shmem[BLOCK_SIZE * 3]; - - // if (input_offset + tid * 4 < n) { - // uint8_t val0, val1, val2, val3; - // if (std::is_same_v) { - // uint4 values = reinterpret_cast(bits_ptr)[tid * 4]; - // val0 = to_fp6_bits(values.x); - // val1 = to_fp6_bits(values.y); - // val2 = to_fp6_bits(values.z); - // val3 = to_fp6_bits(values.w); - // } else if (std::is_same_v) { - // ushort4 values = reinterpret_cast(bits_ptr)[tid * 4]; - // val0 = to_fp6_bits(values.x); - // val1 = to_fp6_bits(values.y); - // val2 = to_fp6_bits(values.z); - // val3 = to_fp6_bits(values.w); - // } else { - // val0 = to_fp6_bits(bits_ptr[tid * 4]); - // val1 = to_fp6_bits(bits_ptr[tid * 4 + 1]); - // val2 = to_fp6_bits(bits_ptr[tid * 4 + 2]); - // val3 = to_fp6_bits(bits_ptr[tid * 4 + 3]); - // } - // shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - // shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - // shmem[tid * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - // } - // __syncthreads(); - - // // TODO: write in larger word size - // for (int i = 0; i < 3; i++) { - // if (output_offset + BLOCK_SIZE * i + tid < n / 4 * 3) { - // fp6_ptr[BLOCK_SIZE * i + tid] = shmem[BLOCK_SIZE * i + tid]; - // } - // } +// define our own vector types since NVIDIA doesn't provide them. 
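Before the vector types and the packed kernel are introduced below, here is a plain-C++ round-trip of the 4-values-in-3-bytes layout. pack4_fp6/unpack4_fp6 are ad-hoc names for this sketch; the shifts and masks are the same ones used by bits_4_to_fp6_4_packed() and fp6_4_packed_to_fp32_4() earlier in the series.

#include <cassert>
#include <cstdint>

// pack 4 six-bit values (each in the low 6 bits of a byte) into 3 bytes
static void pack4_fp6(const uint8_t val[4], uint8_t out[3]) {
    out[0] = (val[0] << 2) | (val[1] >> 4);  // 0000 0011
    out[1] = (val[1] << 4) | (val[2] >> 2);  // 1111 2222
    out[2] = (val[2] << 6) | (val[3]);       // 2233 3333
}

// recover the 4 six-bit values from 3 packed bytes
static void unpack4_fp6(const uint8_t in[3], uint8_t val[4]) {
    val[0] = in[0] >> 2;
    val[1] = ((in[0] & 0x3u) << 4) | (in[1] >> 4);
    val[2] = ((in[1] & 0xFu) << 2) | (in[2] >> 6);
    val[3] = in[2] & 0x3Fu;
}

int main() {
    const uint8_t val[4] = {0b011111, 0b000001, 0b100011, 0b011110};
    uint8_t packed[3], roundtrip[4];
    pack4_fp6(val, packed);
    unpack4_fp6(packed, roundtrip);
    for (int i = 0; i < 4; i++) assert(roundtrip[i] == val[i]);
    return 0;
}

The aligned vector-of-4 typedefs and to_fp6_packed_kernel that follow produce the same three-byte groups, but with vectorized loads and coalesced writes through shared memory.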
+typedef struct __align__(8) { __half x, y, z, w; } fp16_vec4; +typedef struct __align__(8) { __nv_bfloat16 x, y, z, w; } bf16_vec4; + +template +__global__ void to_fp6_packed_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) { + const int tid = threadIdx.x; + const int input_offset = (blockIdx.x * blockDim.x) * 4; + const int output_offset = (blockIdx.x * blockDim.x) * 3; + + fp_ptr += input_offset; + fp6_ptr += output_offset; + + __shared__ uint8_t shmem[BLOCK_SIZE * 3]; + + if (input_offset + tid * 4 < n) { + uint8_t val0, val1, val2, val3; + + // vector load for coalesced memory read + if constexpr (std::is_same_v) { + float4 values = reinterpret_cast(fp_ptr)[tid]; + val0 = to_fp6_value(values.x); + val1 = to_fp6_value(values.y); + val2 = to_fp6_value(values.z); + val3 = to_fp6_value(values.w); + } else if constexpr (std::is_same_v || std::is_same_v) { + fp16_vec4 values = reinterpret_cast(fp_ptr)[tid]; + val0 = to_fp6_value(values.x); + val1 = to_fp6_value(values.y); + val2 = to_fp6_value(values.z); + val3 = to_fp6_value(values.w); + } else if constexpr (std::is_same_v || std::is_same_v) { + bf16_vec4 values = reinterpret_cast(fp_ptr)[tid]; + val0 = to_fp6_value(values.x); + val1 = to_fp6_value(values.y); + val2 = to_fp6_value(values.z); + val3 = to_fp6_value(values.w); + } else { + // fallback. no coalesced memory access. (assert false instead?) + val0 = to_fp6_value(fp_ptr[tid * 4]); + val1 = to_fp6_value(fp_ptr[tid * 4 + 1]); + val2 = to_fp6_value(fp_ptr[tid * 4 + 2]); + val3 = to_fp6_value(fp_ptr[tid * 4 + 3]); + } + + shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + shmem[tid * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } + __syncthreads(); + + // coalesced memory write + // TODO: write in larger word size + for (int i = 0; i < 3; i++) { + if (output_offset + BLOCK_SIZE * i + tid < n / 4 * 3) { + fp6_ptr[BLOCK_SIZE * i + tid] = shmem[BLOCK_SIZE * i + tid]; + } + } } at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { @@ -381,16 +396,16 @@ at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { int grid_size = (n + (block_size * 4) - 1) / (block_size * 4); if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); + const float *fp32_ptr = fp_tensor.data_ptr(); + to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); + const at::Half *fp16_ptr = fp_tensor.data_ptr(); + to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - bits_to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); + const at::BFloat16 *bf16_ptr = fp_tensor.data_ptr(); + to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); From 965838c44d59c505bdea785da68b4b426fec852c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 05:17:24 +0000 Subject: [PATCH 37/80] clean up CPU impl --- torchao/csrc/cuda/fp6_llm/fp6.cu | 90 +++++++++++++------------------- 1 file changed, 37 insertions(+), 53 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 00f68e0352..227db8081c 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -61,10 +61,7 
@@ __device__ __host__ static uint8_t to_fp6_value(T a) { // we need to do this because C++17 does not allow using struct as template non-type parameter // use the upper 16 bits for num exponent, lower 16 bits for num mantissa -static constexpr uint32_t encode_fp_spec(uint32_t n_exp_bits, uint32_t n_man_bits) { - return (n_exp_bits << 16u) | n_man_bits; -} - +static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) { return (n_exp << 16u) | n_man; } static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); @@ -79,13 +76,10 @@ __device__ __host__ static uint8_t to_fp6_bits(T bits) { constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; - // sanity checks. will be removed in template specialization. -#ifndef __CUDA_ARCH__ - if (N_EXP < 3) - throw std::invalid_argument("Number of exponent bits must be >= 3."); - if (N_MAN < 3) - throw std::invalid_argument("Number of mantissa bits must be >= 3."); -#endif + // sanity checks. will be removed in template instantiation. + // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. + static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); + static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); T remainder = 0u; T sign = bits >> N_EXP_MAN << 5u; @@ -138,18 +132,6 @@ __device__ __host__ static uint8_t to_fp6_bits(T bits) { return result; } -template -__device__ __host__ static void bits_4_to_fp6_4_packed(const T *bits_ptr, uint8_t *fp6_ptr) { - uint8_t val0 = to_fp6_bits(bits_ptr[0]); - uint8_t val1 = to_fp6_bits(bits_ptr[1]); - uint8_t val2 = to_fp6_bits(bits_ptr[2]); - uint8_t val3 = to_fp6_bits(bits_ptr[3]); - - fp6_ptr[0] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[2] = (val2 << 6) | (val3); // 2233 3333 -} - // assume the lower 6 bits contain the data __device__ __host__ static float fp6_to_fp32(const uint8_t a) { // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. 
@@ -179,14 +161,14 @@ __device__ __host__ static void fp6_4_packed_to_fp32_4(const uint8_t *fp6_ptr, f fp32_ptr[3] = fp6_to_fp32(bits2 & 0x3Fu); } -__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - if (idx < n) - fp6_4_packed_to_fp32_4(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); -} - namespace torchao { +template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n; i++) + fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); +} + // this is useful for debugging at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { TORCH_CHECK(fp_tensor.is_contiguous()); @@ -201,24 +183,15 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i++) - fp6_ptr[i] = to_fp6_bits(fp32_ptr[i]); + to_fp6_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i++) - fp6_ptr[i] = to_fp6_bits(fp16_ptr[i]); + to_fp6_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i++) - fp6_ptr[i] = to_fp6_bits(bf16_ptr[i]); + to_fp6_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -271,6 +244,20 @@ at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { return fp6_tensor; } +template void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n / 4; i++) { + uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); + uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); + uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); + uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); + + fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } +} + at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { TORCH_CHECK(fp_tensor.is_contiguous()); TORCH_CHECK(fp_tensor.is_cpu()); @@ -289,24 +276,15 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i += 4) - bits_4_to_fp6_4_packed(fp32_ptr + i, fp6_ptr + i / 4 * 3); + to_fp6_packed_cpu_impl(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i += 4) - bits_4_to_fp6_4_packed(fp16_ptr + i, fp6_ptr + i / 4 * 3); + to_fp6_packed_cpu_impl(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - - #pragma omp parallel for - for (int i = 0; i < n; i += 4) - bits_4_to_fp6_4_packed(bf16_ptr + i, fp6_ptr + i / 4 * 3); + to_fp6_packed_cpu_impl(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -445,6 +423,12 @@ at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { return fp32_tensor; } +__global__ void fp6_packed_to_fp32_kernel(const 
uint8_t *fp6_ptr, float *fp32_ptr, int n) { + const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; + if (idx < n) + fp6_4_packed_to_fp32_4(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); +} + at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); From a64421e313b2d8331a74315c54c8fb06ff2db662 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 05:55:20 +0000 Subject: [PATCH 38/80] add FP6->FP16/BF16 --- torchao/csrc/cuda/fp6_llm/fp6.cu | 199 +++++++++++++++++++++++-------- torchao/csrc/fp6_llm/fp6_llm.cpp | 4 +- torchao/ops.py | 8 +- 3 files changed, 153 insertions(+), 58 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 227db8081c..1fdb23c99a 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -132,8 +132,10 @@ __device__ __host__ static uint8_t to_fp6_bits(T bits) { return result; } -// assume the lower 6 bits contain the data -__device__ __host__ static float fp6_to_fp32(const uint8_t a) { +// assume the lower 6 bits contain the data. +// NOTE: probably not efficient for FP6->FP16 and FP6->BF16 on CPU since FP32->FP16/BF16 is slow. +template +__device__ __host__ static T from_fp6(uint8_t a) { // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. // this also handles subnormal numbers correctly. // FP6: SE EEMM @@ -147,18 +149,8 @@ __device__ __host__ static float fp6_to_fp32(const uint8_t a) { // we can correct this by direct FP32 multiplication, which also handles subnormal numbers. float result; std::memcpy(&result, &result_bits, sizeof(result)); - return result * 0x1p124; // 2^(127-3) -} - -__device__ __host__ static void fp6_4_packed_to_fp32_4(const uint8_t *fp6_ptr, float *fp32_ptr) { - uint8_t bits0 = fp6_ptr[0]; // 0000 0011 - uint8_t bits1 = fp6_ptr[1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[2]; // 2233 3333 - - fp32_ptr[0] = fp6_to_fp32(bits0 >> 2); - fp32_ptr[1] = fp6_to_fp32(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp32_ptr[2] = fp6_to_fp32(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp32_ptr[3] = fp6_to_fp32(bits2 & 0x3Fu); + result *= 0x1p124; // 2^(127-3) + return static_cast(result); } namespace torchao { @@ -223,7 +215,7 @@ at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { auto dtype = fp_tensor.dtype(); constexpr int block_size = 256; - int grid_size = (n + block_size - 1) / block_size; + const int grid_size = (n + block_size - 1) / block_size; if (dtype == torch::kFloat32) { const float *fp32_ptr = fp_tensor.data_ptr(); @@ -371,7 +363,7 @@ at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { // times 4 since each thread will handle 4 values constexpr int block_size = 256; - int grid_size = (n + (block_size * 4) - 1) / (block_size * 4); + const int grid_size = (n + (block_size * 4) - 1) / (block_size * 4); if (dtype == torch::kFloat32) { const float *fp32_ptr = fp_tensor.data_ptr(); @@ -392,85 +384,188 @@ at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { return fp6_tensor; } -__global__ void fp6_unpacked_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { +template +void from_fp6_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n; i++) + fp_ptr[i] = from_fp6(fp6_ptr[i]); +} + +at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + 
TORCH_CHECK(fp6_tensor.is_cpu()); + + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (dtype == torch::kFloat32) { + from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kFloat16) { + from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kBFloat16) { + from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); + } + + return fp_tensor; +} + +template +__global__ void from_fp6_unpacked_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) { + // TODO: use vector load for reading from global memory const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) - fp32_ptr[idx] = fp6_to_fp32(fp6_ptr[idx]); + fp_ptr[idx] = from_fp6(fp6_ptr[idx]); } -at::Tensor fp6_unpacked_to_fp32(at::Tensor fp6_tensor) { +at::Tensor from_fp6_unpacked_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); + TORCH_CHECK(fp6_tensor.is_cuda()); - at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); - at::Tensor fp32_tensor = at::empty(fp6_tensor.sizes(), options); + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - float *fp32_ptr = fp32_tensor.data_ptr(); int n = fp6_tensor.numel(); - if (fp6_tensor.is_cpu()) { - #pragma omp parallel for - for (int i = 0; i < n; i++) - fp32_ptr[i] = fp6_to_fp32(fp6_ptr[i]); + constexpr int block_size = 256; + const int grid_size = (n + block_size - 1) / block_size; + + if (dtype == torch::kFloat32) { + from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kFloat16) { + from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kBFloat16) { + from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 4 - 1) / (block_size * 4); - fp6_unpacked_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); } - return fp32_tensor; + return fp_tensor; } -__global__ void fp6_packed_to_fp32_kernel(const uint8_t *fp6_ptr, float *fp32_ptr, int n) { - const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 3; - if (idx < n) - fp6_4_packed_to_fp32_4(fp6_ptr + idx, fp32_ptr + idx / 3 * 4); +template +void from_fp6_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n / 3; i++) { + uint8_t bits0 = fp6_ptr[i * 3]; // 0000 0011 + uint8_t bits1 = fp6_ptr[i * 3 + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[i * 3 + 2]; // 2233 3333 + + fp_ptr[i * 4] = from_fp6(bits0 >> 2); + fp_ptr[i * 4 + 1] = from_fp6(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp_ptr[i * 4 + 2] = from_fp6(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp_ptr[i * 4 + 3] = from_fp6(bits2 & 0x3Fu); + } } -at::Tensor fp6_packed_to_fp32(at::Tensor fp6_tensor) { +at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { 
TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu() || fp6_tensor.is_cuda()); + TORCH_CHECK(fp6_tensor.is_cpu()); TORCH_CHECK(fp6_tensor.ndimension() == 2); int M = fp6_tensor.size(0); int N = fp6_tensor.size(1); TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); - at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat32).device(fp6_tensor.device()); - at::Tensor fp32_tensor = at::empty({M, N / 3 * 4}, options); + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - float *fp32_ptr = fp32_tensor.data_ptr(); int n = fp6_tensor.numel(); - if (fp6_tensor.is_cpu()) { - #pragma omp parallel for - for (int i = 0; i < n; i += 3) - fp6_4_packed_to_fp32_4(fp6_ptr + i, fp32_ptr + i / 3 * 4); + if (dtype == torch::kFloat32) { + from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kFloat16) { + from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kBFloat16) { + from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); + } + + return fp_tensor; +} + +template +__global__ void from_fp6_packed_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n / 3) { + // TODO: use vector load for reading from global memory + uint8_t bits0 = fp6_ptr[idx * 3]; // 0000 0011 + uint8_t bits1 = fp6_ptr[idx * 3 + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[idx * 3 + 2]; // 2233 3333 + + fp_ptr[idx * 4] = from_fp6(bits0 >> 2); + fp_ptr[idx * 4 + 1] = from_fp6(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp_ptr[idx * 4 + 2] = from_fp6(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp_ptr[idx * 4 + 3] = from_fp6(bits2 & 0x3Fu); + } +} + +at::Tensor from_fp6_packed_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cuda()); + + int M = fp6_tensor.size(0); + int N = fp6_tensor.size(1); + TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + // times 3 because each thread read 3 bytes (which represent 4 FP6 values) + constexpr int block_size = 256; + const int grid_size = (n + block_size * 3 - 1) / (block_size * 3); + + if (dtype == torch::kFloat32) { + from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kFloat16) { + from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + + } else if (dtype == torch::kBFloat16) { + from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); + } else { - constexpr int block_size = 256; - int grid_size = (n + block_size * 3 - 1) / (block_size * 3); - fp6_packed_to_fp32_kernel<<>>(fp6_ptr, fp32_ptr, n); + throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); } - return fp32_tensor; + return fp_tensor; } TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cpu); m.impl("torchao::to_fp6_packed", &to_fp6_packed_cpu); - 
m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); + m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cpu); + m.impl("torchao::from_fp6_packed", &from_fp6_packed_cpu); } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cuda); m.impl("torchao::to_fp6_packed", &to_fp6_packed_cuda); - m.impl("torchao::fp6_unpacked_to_fp32", &fp6_unpacked_to_fp32); - m.impl("torchao::fp6_packed_to_fp32", &fp6_packed_to_fp32); + m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cuda); + m.impl("torchao::from_fp6_packed", &from_fp6_packed_cuda); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 8e4c20f08a..32923eb185 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -11,6 +11,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); m.def("to_fp6_packed(Tensor fp16_tensor) -> Tensor"); - m.def("fp6_unpacked_to_fp32(Tensor fp6_tensor) -> Tensor"); - m.def("fp6_packed_to_fp32(Tensor fp6_tensor) -> Tensor"); + m.def("from_fp6_unpacked(Tensor fp6_tensor, ScalarType dtype) -> Tensor"); + m.def("from_fp6_packed(Tensor fp6_tensor, ScalarType dtype) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 89f2a910da..c96ed12a35 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -67,13 +67,13 @@ def _(fp_tensor): return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp_tensor.device, dtype=torch.uint8) -def fp6_unpacked_to_fp32(fp6_tensor: Tensor) -> Tensor: - return torch.ops.torchao.fp6_unpacked_to_fp32.default(fp6_tensor) +def from_fp6_unpacked(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_fp6_unpacked.default(fp6_tensor, dtype) -def fp6_packed_to_fp32(fp6_tensor: Tensor) -> Tensor: +def from_fp6_packed(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: *leading_dims, last_dim = fp6_tensor.shape - return torch.ops.torchao.fp6_packed_to_fp32.default(fp6_tensor.view(-1, last_dim)).view(*leading_dims, -1) + return torch.ops.torchao.from_fp6_packed.default(fp6_tensor.view(-1, last_dim), dtype).view(*leading_dims, -1) def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: From a4b7c7a5c8dda18bd3a1a271c7cdb600413e2926 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 05:59:22 +0000 Subject: [PATCH 39/80] add dim check --- torchao/csrc/cuda/fp6_llm/fp6.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 1fdb23c99a..e44cbe96ee 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -523,6 +523,7 @@ at::Tensor from_fp6_packed_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); TORCH_CHECK(fp6_tensor.is_cuda()); + TORCH_CHECK(fp6_tensor.ndimension() == 2); int M = fp6_tensor.size(0); int N = fp6_tensor.size(1); From 3e4c1c1fbdd6b230344624fb1b9dacb13d28174d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 06:39:23 +0000 Subject: [PATCH 40/80] add qtorch to dev req --- dev-requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6dadb274aa..156e8766d2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,3 +12,6 @@ pandas # Custom CUDA Extensions ninja + +# for FP6-LLM (can be removed once we remove fp16_to_fp6_original()) +qtorch 
From 632af9320922fdaba4647933a9a1bff5b23ab879 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 07:45:38 +0000 Subject: [PATCH 41/80] handle exception with OpenMP --- setup.py | 4 ++-- torchao/csrc/cuda/fp6_llm/fp6.cu | 38 +++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index cbf24e98eb..0cac46ab96 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_extensions(): extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] - # extra_link_args.append("-fopenmp") + extra_link_args.append("-fopenmp") extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", @@ -55,7 +55,7 @@ def get_extensions(): ], "nvcc": [ "-O3" if not debug_mode else "-O0", - # "-Xcompiler", "-fopenmp", + "-Xcompiler", "-fopenmp", ] } if debug_mode: diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index e44cbe96ee..31730d4a63 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -10,6 +10,16 @@ #include +class fp6_nan_inf : public std::invalid_argument { +public: + fp6_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6.") { } +}; + +class fp6_overflow : public std::invalid_argument { +public: + fp6_overflow() : std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf. Make sure input < 30.0") { } +}; + // need to do this trick so that static_assert(false) only evaluates at template instantiation. template constexpr std::false_type always_false{}; @@ -36,10 +46,8 @@ __device__ __host__ static uint8_t to_fp6_value(T a) { static_assert(always_false, "Only float, __half, __nv_bfloat16, c10::Half, and c10::BFloat16 are suppored"); #ifndef __CUDA_ARCH__ - if (std::isnan(a) | std::isinf(a)) - throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); - if (std::abs(a) >= 30.0f) - throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); + if (std::isnan(fp32_value) | std::isinf(fp32_value)) throw fp6_nan_inf(); + if (std::abs(fp32_value) >= 30.0f) throw fp6_overflow(); #endif fp32_value *= 0x1p-124; // 2^(127-3) @@ -91,11 +99,10 @@ __device__ __host__ static uint8_t to_fp6_bits(T bits) { // only checks for invalid values on CPU, since we can't throw exception in CUDA #ifndef __CUDA_ARCH__ // all exponent bits are 1s - if (bits >= (ones_mask(N_EXP) << N_MAN)) - throw std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6."); + if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf(); + // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) - throw std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf."); + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw fp6_overflow(); #endif // FP6 normal number (E>=001) @@ -156,9 +163,20 @@ __device__ __host__ static T from_fp6(uint8_t a) { namespace torchao { template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // exception within OpenMP parallel region must be caught. + // set a flag when exception occurs, then re-raise it. 
+ bool found_nan_inf = false; + bool found_overflow = false; + #pragma omp parallel for - for (int i = 0; i < n; i++) - fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); + for (int i = 0; i < n; i++) { + try { fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); } + catch (fp6_nan_inf &e) { found_nan_inf = true; } + catch (fp6_overflow &e) { found_overflow = true; } + } + + if (found_nan_inf) throw fp6_nan_inf(); + if (found_overflow) throw fp6_overflow(); } // this is useful for debugging From 6c6fe83e5bf6467177776ae423ceb70a287fc113 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 08:21:02 +0000 Subject: [PATCH 42/80] handle exception in OpenMP --- torchao/csrc/cuda/fp6_llm/fp6.cu | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu index 31730d4a63..e90e4f2820 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6.cu @@ -171,8 +171,8 @@ template void to_fp6_unpacked_cpu_impl(const T *b #pragma omp parallel for for (int i = 0; i < n; i++) { try { fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); } - catch (fp6_nan_inf &e) { found_nan_inf = true; } - catch (fp6_overflow &e) { found_overflow = true; } + catch (fp6_nan_inf) { found_nan_inf = true; } + catch (fp6_overflow) { found_overflow = true; } } if (found_nan_inf) throw fp6_nan_inf(); @@ -255,17 +255,29 @@ at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { } template void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // exception within OpenMP parallel region must be caught. + // set a flag when exception occurs, then re-raise it. + bool found_nan_inf = false; + bool found_overflow = false; + #pragma omp parallel for for (int i = 0; i < n / 4; i++) { - uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); - uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); - uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); - uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); - - fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + try { + uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); + uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); + uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); + uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); + + fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } + catch (fp6_nan_inf) { found_nan_inf = true; } + catch (fp6_overflow) { found_overflow = true; } } + + if (found_nan_inf) throw fp6_nan_inf(); + if (found_overflow) throw fp6_overflow(); } at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { From cb08b37b14136f9877a745bbc4ee3be5cbcaf259 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 08:47:11 +0000 Subject: [PATCH 43/80] add tests --- test/test_ops.py | 95 +++++++++++++++++++++++++++++++++++++++--------- torchao/ops.py | 28 +++++++++++++- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 7124cd45e2..6f419a5ee8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -141,36 +141,97 @@ def _skip_cpu(self): if not torch.cuda.is_available(): self.skipTest("CUDA not available. 
We don't compile for CPU-only build") + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_to_fp6_unpacked(self, device, dtype): + self._skip_cpu() + inputs = torch.randn(128, 128, device=device, dtype=dtype) + + # smoke test + torchao.ops.to_fp6_unpacked(inputs) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.to_fp6_unpacked, (inputs,), test_utils=test_utils) + + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_to_fp6_packed(self, device, dtype): + self._skip_cpu() + inputs = torch.randn(128, 128, device=device, dtype=dtype) + + # smoke test + torchao.ops.to_fp6_packed(inputs) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.to_fp6_packed, (inputs,), test_utils=test_utils) + + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_from_fp6_unpacked(self, device, dtype): + self._skip_cpu() + inputs = torch.randint(256, size=(128, 128 // 4 * 3), device=device, dtype=torch.uint8) + + # smoke test + torchao.ops.from_fp6_unpacked(inputs, dtype) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.from_fp6_unpacked, (inputs, dtype), test_utils=test_utils) + + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_from_fp6_packed(self, device, dtype): + self._skip_cpu() + inputs = torch.randint(256, size=(128, 128 // 4 * 3), device=device, dtype=torch.uint8) + + # smoke test + torchao.ops.from_fp6_packed(inputs, dtype) + + # comprehensive testing + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] + opcheck(torch.ops.torchao.from_fp6_packed, (inputs, dtype), test_utils=test_utils) + + def test_to_fp6_unpacked_shape(self): + for shape in [(), (0,), (10,), (20, 20)]: + x = torch.randn(shape) + result = torchao.ops.to_fp6_unpacked(x) + assert result.shape == shape + + def test_to_fp6_packed_shape(self): + for shape in [(4,), (20, 20)]: + x = torch.randn(shape) + result = torchao.ops.to_fp6_packed(x) + assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) + @parameterized.expand( [ - (0.0, 0b000000), # simple values - (1.0, 0b001100), # normal numbers - (1.25, 0b001101), - (28.0, 0b011111), # max - (0.1875, 0b00011), # subnormal number + (0.0, 0b000000), # exact values + (1.0, 0b001100), # normal numbers + (1.25, 0b001101), + (28.0, 0b011111), # max + (0.1875, 0b000011), # subnormal number (0.0625, 0b000001), # min - (29.0, 0b011111), # normal round down - (26.0, 0b011110), # normal round to nearest even + (29.0, 0b011111), # normal round down + (26.0, 0b011110), # normal round to nearest even (0.1251, 0b000010), # subnormal round down (0.0314, 0b000001), # subnormal round up - (0.03, 0b000000), # underflow + (0.03, 0b000000), # underflow ] ) - def test_to_fp6_correctness(self, input, output): + def test_to_fp6_unpacked_correctness(self, input, output): self._skip_cpu() - for dtype in (torch.float32, torch.float16, torch.bfloat16): - x = 
torch.tensor(input, dtype=dtype) - assert torchao.ops.to_fp6_unpacked(x).item() == output - assert torchao.ops.to_fp6_unpacked(-x).item() == (output | 0b100000) - assert torchao.ops.to_fp6_unpacked(x.cuda()).item() == output - assert torchao.ops.to_fp6_unpacked(-x.cuda()).item() == (output | 0b100000) - - @parameterized.expand([30.0, 100.0, float("inf"), float("nan")]) + for device in ("cpu", "cuda"): + for dtype in (torch.float32, torch.float16, torch.bfloat16): + x = torch.tensor(input, device=device, dtype=dtype) + assert torchao.ops.to_fp6_unpacked(x).item() == output + assert torchao.ops.to_fp6_unpacked(-x).item() == (output | 0b100000) + + @parameterized.expand([30.0, -100.0, float("inf"), float("nan")]) def test_to_fp6_exception(self, input): self._skip_cpu() x = torch.tensor(input) with self.assertRaises(Exception): torchao.ops.to_fp6_unpacked(x) + with self.assertRaises(Exception): + torchao.ops.to_fp6_packed(x) if __name__ == "__main__": diff --git a/torchao/ops.py b/torchao/ops.py index c96ed12a35..7ba448a346 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -48,6 +48,10 @@ def to_fp6_unpacked(fp_tensor: Tensor) -> Tensor: @torch.library.impl_abstract("torchao::to_fp6_unpacked") def _(fp_tensor): + torch._check( + fp_tensor.dtype in (torch.float32, torch.float16, torch.bfloat16), + lambda: f"inputs must be FP32, FP16, or BF16, got {fp_tensor.dtype}", + ) return torch.empty_like(fp_tensor, dtype=torch.uint8) @@ -60,22 +64,42 @@ def to_fp6_packed(fp_tensor: Tensor) -> Tensor: def _(fp_tensor): torch._check( fp_tensor.dtype in (torch.float32, torch.float16, torch.bfloat16), - lambda: f"weight must be FP32, FP16, or BF16, got {fp_tensor.dtype}", + lambda: f"inputs must be FP32, FP16, or BF16, got {fp_tensor.dtype}", ) *leading_dims, last_dim = fp_tensor.shape torch._check(last_dim % 4 == 0, lambda: f"last dimension must be a multiple of 4, got {last_dim}") - return torch.empty(*leading_dims, last_dim * 3 / 4, device=fp_tensor.device, dtype=torch.uint8) + return torch.empty(*leading_dims, last_dim * 3 // 4, device=fp_tensor.device, dtype=torch.uint8) def from_fp6_unpacked(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: return torch.ops.torchao.from_fp6_unpacked.default(fp6_tensor, dtype) +@torch.library.impl_abstract("torchao::from_fp6_unpacked") +def _(fp6_tensor, dtype): + torch._check( + dtype in (torch.float32, torch.float16, torch.bfloat16), + lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", + ) + return torch.empty_like(fp6_tensor, device=fp6_tensor.device, dtype=dtype) + + def from_fp6_packed(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: *leading_dims, last_dim = fp6_tensor.shape return torch.ops.torchao.from_fp6_packed.default(fp6_tensor.view(-1, last_dim), dtype).view(*leading_dims, -1) +@torch.library.impl_abstract("torchao::from_fp6_packed") +def _(fp6_tensor, dtype): + torch._check( + dtype in (torch.float32, torch.float16, torch.bfloat16), + lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", + ) + *leading_dims, last_dim = fp6_tensor.shape + torch._check(last_dim % 3 == 0, lambda: f"last dimension must be a multiple of 3, got {last_dim}") + return torch.empty(*leading_dims, last_dim * 4 // 3, device=fp6_tensor.device, dtype=dtype) + + def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor (containing only FP6 values) into FP6 tensor. 
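The correctness values used in these tests follow directly from the FP6 layout (1 sign, 3 exponent, 2 mantissa bits, exponent bias 3). As a reference, here is a minimal Python sketch of the same decode trick used by from_fp6() above: shift the 6 bits into their FP32 positions, then fix the exponent-bias difference with a multiply by 2^(127-3). The helper name fp6_bits_to_float is illustrative only, not part of the torchao API.

import struct

def fp6_bits_to_float(a: int) -> float:
    sign = (a >> 5) << 31            # move the sign bit to its FP32 position
    exp_and_man = (a & 0x1F) << 21   # place EEEMM under the FP32 exponent/mantissa fields
    f = struct.unpack("<f", struct.pack("<I", sign | exp_and_man))[0]
    return f * 2.0 ** 124            # 2^(127-3); also handles FP6 subnormals correctly

assert fp6_bits_to_float(0b001100) == 1.0
assert fp6_bits_to_float(0b011111) == 28.0     # max FP6
assert fp6_bits_to_float(0b000011) == 0.1875   # FP6 subnormal
assert fp6_bits_to_float(0b101100) == -1.0     # sign bit set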
From 9f94030a76ddc5bbe14d57ac566b346c07705961 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 09:14:09 +0000 Subject: [PATCH 44/80] more tests --- test/test_ops.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 6f419a5ee8..251237cc4f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -224,6 +224,20 @@ def test_to_fp6_unpacked_correctness(self, input, output): assert torchao.ops.to_fp6_unpacked(x).item() == output assert torchao.ops.to_fp6_unpacked(-x).item() == (output | 0b100000) + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_to_fp6_packed_correctness(self, device, dtype): + x = torch.randn(128, 128, device=device, dtype=dtype) + results_unpacked = torchao.ops.to_fp6_unpacked(x) + results_packed = torchao.ops.to_fp6_packed(x) + + val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) + assert (results_packed == expected_packed).all() + @parameterized.expand([30.0, -100.0, float("inf"), float("nan")]) def test_to_fp6_exception(self, input): self._skip_cpu() @@ -233,6 +247,45 @@ def test_to_fp6_exception(self, input): with self.assertRaises(Exception): torchao.ops.to_fp6_packed(x) + @parameterized.expand( + [ + (0b000000, 0.0), + (0b001100, 1.0), + (0b011111, 28.0), + (0b000001, 0.0625), + (0b001110, 1.5), + (0b000011, 0.1875), + ] + ) + def test_from_fp6_unpacked_correctness(self, input, output): + self._skip_cpu() + for device in ("cpu", "cuda"): + for dtype in (torch.float32, torch.float16, torch.bfloat16): + x = torch.tensor(input, device=device, dtype=torch.uint8) + result = torchao.ops.from_fp6_unpacked(x, dtype) + assert result.dtype == dtype + assert result.item() == output + + x = torch.tensor(input | 0b100000, device=device, dtype=torch.uint8) + result = torchao.ops.from_fp6_unpacked(x, dtype) + assert result.dtype == dtype + assert result.item() == -output + + @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) + def test_from_fp6_packed_correctness(self, device, dtype): + x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) + results = torchao.ops.from_fp6_packed(x, dtype=dtype) + + bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) + x_unpacked0 = bits0 >> 2 + x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) + x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) + x_unpacked3 = bits2 & 0x3F + + x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) + expected = torchao.ops.from_fp6_unpacked(x_unpacked, dtype) + assert (results == expected).all() + if __name__ == "__main__": unittest.main() From 7b7e8235378b83fcaada8724b1c0713e716a8093 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 09:16:03 +0000 Subject: [PATCH 45/80] simplify test --- test/test_ops.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 251237cc4f..56fa10ac28 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -66,14 +66,7 @@ def test_prepack_fp6_weight(self): def test_fp16_to_fp6(self): OC = 256 IC = 256 - - # in this fp6, we use 3 bits for exponent 
and 2 bits for mantissa - # also, we don't have nan/inf - fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11 - fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number) fp16_weight = torch.randn((OC, IC), dtype=torch.float16) - fp16_weight.clip_(-fp6_absmax, fp6_absmax) - fp16_weight[fp16_weight.abs() < fp6_absmin] = 0 # smoke test torchao.ops.fp16_to_fp6_original(fp16_weight) From 7c1ff7d3a2b6a7ceee544b59cfbd5791e4223960 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 09:16:17 +0000 Subject: [PATCH 46/80] rename --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 56fa10ac28..81db4a7a06 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -63,7 +63,7 @@ def test_prepack_fp6_weight(self): opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp16_to_fp6(self): + def test_fp16_to_fp6_original(self): OC = 256 IC = 256 fp16_weight = torch.randn((OC, IC), dtype=torch.float16) From 0472b064c8dc93919df4a91e210f6b3e2ed320b0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 09:21:58 +0000 Subject: [PATCH 47/80] add back checks --- test/test_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 81db4a7a06..966e400e4d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -68,6 +68,10 @@ def test_fp16_to_fp6_original(self): IC = 256 fp16_weight = torch.randn((OC, IC), dtype=torch.float16) + # the original FP16->FP6 kernel checks for overflow/underflow + fp16_weight.clip_(-28.0, 28.0) + fp16_weight[fp16_weight.abs() < 0.0625] = 0.0 + # smoke test torchao.ops.fp16_to_fp6_original(fp16_weight) From 0bda927dd06126661f475c044d2dabfd57e46801 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 09:31:55 +0000 Subject: [PATCH 48/80] update docs --- torchao/ops.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/torchao/ops.py b/torchao/ops.py index 7ba448a346..3da8351619 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -43,6 +43,11 @@ def _(fp6_weight): def to_fp6_unpacked(fp_tensor: Tensor) -> Tensor: + """ + Convert FP32/FP16/BF16 tensor to FP6. Each FP6 value is stored in the lower 6 bits of an uint8, + thus 2 bits are wasted. This is useful for debugging, since you can access the bits of FP6 + directly via tensor indexing. + """ return torch.ops.torchao.to_fp6_unpacked.default(fp_tensor) @@ -56,6 +61,10 @@ def _(fp_tensor): def to_fp6_packed(fp_tensor: Tensor) -> Tensor: + """ + Convert FP32/FP16/BF16 tensor to FP6. Every 4 FP32/FP16/BF16 values are packed into 3 uint8 + (4 x 6 bits = 3 x 8 bits). The last dimension must be a multiple of 4. + """ *leading_dims, last_dim = fp_tensor.shape return torch.ops.torchao.to_fp6_packed.default(fp_tensor.view(-1, last_dim)).view(*leading_dims, -1) @@ -72,11 +81,15 @@ def _(fp_tensor): def from_fp6_unpacked(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: + """ + Inverse of to_fp6_unpacked(). 
+ """ return torch.ops.torchao.from_fp6_unpacked.default(fp6_tensor, dtype) @torch.library.impl_abstract("torchao::from_fp6_unpacked") def _(fp6_tensor, dtype): + torch._check(fp6_tensor.dtype == torch.uint8, lambda: f"inputs must be uint8, got {fp6_tensor.dtype}") torch._check( dtype in (torch.float32, torch.float16, torch.bfloat16), lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", @@ -85,12 +98,16 @@ def _(fp6_tensor, dtype): def from_fp6_packed(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: + """ + Inverse of to_fp6_packed(). The last dimension must be a multiple of 3. + """ *leading_dims, last_dim = fp6_tensor.shape return torch.ops.torchao.from_fp6_packed.default(fp6_tensor.view(-1, last_dim), dtype).view(*leading_dims, -1) @torch.library.impl_abstract("torchao::from_fp6_packed") def _(fp6_tensor, dtype): + torch._check(fp6_tensor.dtype == torch.uint8, lambda: f"inputs must be uint8, got {fp6_tensor.dtype}") torch._check( dtype in (torch.float32, torch.float16, torch.bfloat16), lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", @@ -102,7 +119,7 @@ def _(fp6_tensor, dtype): def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ - Pack FP16 tensor (containing only FP6 values) into FP6 tensor. + Pack FP16 tensor to FP6 tensor. qtorch is required to use this function. """ try: from qtorch.quant import float_quantize From a21837ca2dbf50cf703feecd2e2843a1cb827936 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 20:09:59 +0800 Subject: [PATCH 49/80] add pure pytorch impl --- torchao/ops.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/torchao/ops.py b/torchao/ops.py index 3da8351619..9d099ffea2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -187,3 +187,32 @@ def _(fp6_tensor, fp16_scale): torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") return fp16_scale.new_empty((OC, _IC * 16 // 3)) + + +def to_fp6_pt(tensor: torch.Tensor, unpacked: bool = False) -> Tensor: + tensor = tensor.float() + tensor = tensor * 2.0 ** (-124) + bits = tensor.view(torch.int32) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + remainder = bits & 0x1F_FFFF + do_round_up = torch.logical_or( + remainder > 0x10_0000, + torch.logical_and(remainder == 0x10_0000, result & 1) + ) + result = torch.where(do_round_up, result + 1, result) + result = result.to(torch.uint8) + + if unpacked: + return result + + # pre-allocate output tensor is faster than using torch.stack() + outputs = torch.empty(tensor.shape[:-1] + (tensor.shape[-1] // 4, 3), device=tensor.device, dtype=torch.uint8) + val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) + outputs[..., 0] = (val0 << 2) | (val1 >> 4) # 0000 0011 + outputs[..., 1] = (val1 << 4) | (val2 >> 2) # 1111 2222 + outputs[..., 2] = (val2 << 6) | (val3); # 2233 3333 + return outputs.flatten(-2) From 6101869f0f1e2c1c29c17f0cb5296fa0a3677e5f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 20 May 2024 20:37:03 +0800 Subject: [PATCH 50/80] add benchmark --- benchmarks/benchmark_fp6_conversion.py | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 benchmarks/benchmark_fp6_conversion.py diff --git a/benchmarks/benchmark_fp6_conversion.py b/benchmarks/benchmark_fp6_conversion.py new file mode 100644 index 0000000000..64c9dd534c --- /dev/null +++ b/benchmarks/benchmark_fp6_conversion.py @@ -0,0 +1,44 @@ +from functools import partial + +import torch +import torchao +import pandas as pd 
+from torch.utils.benchmark import Timer + + +def benchmark(f, weight): + measurement = Timer( + stmt="f(weight)", + globals={"f": f, "weight": weight}, + ).blocked_autorange() + return measurement.median * 1000 + + +if __name__ == "__main__": + M = 8192 + N = 8192 + + fp32_weight = torch.randn(M, N) + fp32_weight_cuda = fp32_weight.cuda() + fp16_weight = fp32_weight.half() + fp16_weight_cuda = fp16_weight.cuda() + + functions = [ + ("original (FP6 packed)", torchao.ops.fp16_to_fp6_original), + # ("custom C++/CUDA (FP6 unpacked)", torchao.ops.to_fp6_unpacked), + ("custom C++/CUDA (FP6 packed)", torchao.ops.to_fp6_packed), + # ("PyTorch + torch.compile (FP6 unpacked)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=True)), + ("PyTorch + torch.compile (FP6 packed)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=False)), + ] + + results = [] + for name, f in functions: + results.append([name, "CPU", "FP32->FP6", benchmark(f, fp32_weight)]) + results.append([name, "CPU", "FP16->FP6", benchmark(f, fp16_weight)]) + if name != "original (FP6 packed)": + results.append([name, "CUDA", "FP32->FP6", benchmark(f, fp32_weight_cuda)]) + results.append([name, "CUDA", "FP16->FP6", benchmark(f, fp16_weight_cuda)]) + + df = pd.DataFrame(results, columns=["op", "device", "dtype", "time (m/s)"]) + df["op"] = df["op"].str.removesuffix(" (FP6 packed)") + print(df.to_markdown(index=False)) From df7932b89ca277aee59d83afd294f8ef5872094e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 02:32:37 +0000 Subject: [PATCH 51/80] update benchmark script --- benchmarks/benchmark_fp6_conversion.py | 29 ++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_fp6_conversion.py b/benchmarks/benchmark_fp6_conversion.py index 64c9dd534c..4cd5cfe621 100644 --- a/benchmarks/benchmark_fp6_conversion.py +++ b/benchmarks/benchmark_fp6_conversion.py @@ -24,21 +24,24 @@ def benchmark(f, weight): fp16_weight_cuda = fp16_weight.cuda() functions = [ - ("original (FP6 packed)", torchao.ops.fp16_to_fp6_original), - # ("custom C++/CUDA (FP6 unpacked)", torchao.ops.to_fp6_unpacked), - ("custom C++/CUDA (FP6 packed)", torchao.ops.to_fp6_packed), - # ("PyTorch + torch.compile (FP6 unpacked)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=True)), - ("PyTorch + torch.compile (FP6 packed)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=False)), + ("original", torchao.ops.fp16_to_fp6_original), + ("C++/CUDA extension", torchao.ops.to_fp6_packed), + ("PyTorch + torch.compile (default)", torch.compile(torchao.ops.to_fp6_pt)), + ("PyTorch + torch.compile (max-autotune)", torch.compile(torchao.ops.to_fp6_pt, mode="max-autotune")), + + # ("C++/CUDA extension (no bit-packing)", torchao.ops.to_fp6_unpacked), + # ("PyTorch + torch.compile (no bit-packing)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=True)), ] results = [] for name, f in functions: - results.append([name, "CPU", "FP32->FP6", benchmark(f, fp32_weight)]) - results.append([name, "CPU", "FP16->FP6", benchmark(f, fp16_weight)]) - if name != "original (FP6 packed)": - results.append([name, "CUDA", "FP32->FP6", benchmark(f, fp32_weight_cuda)]) - results.append([name, "CUDA", "FP16->FP6", benchmark(f, fp16_weight_cuda)]) - - df = pd.DataFrame(results, columns=["op", "device", "dtype", "time (m/s)"]) - df["op"] = df["op"].str.removesuffix(" (FP6 packed)") + results.append(["CPU", "FP32->FP6", name, benchmark(f, fp32_weight)]) + results.append(["CPU", "FP16->FP6", name, benchmark(f, 
fp16_weight)]) + + if name != "original": + results.append(["CUDA", "FP32->FP6", name, benchmark(f, fp32_weight_cuda)]) + results.append(["CUDA", "FP16->FP6", name, benchmark(f, fp16_weight_cuda)]) + + df = pd.DataFrame(results, columns=["device", "dtype", "op", "time (m/s)"]) + df = df.sort_values(["device", "dtype"]) print(df.to_markdown(index=False)) From bdbd907d7ad2dcce0cf0203be26e1390dc80e777 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 05:55:42 +0000 Subject: [PATCH 52/80] add triton kernel --- torchao/ops.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/torchao/ops.py b/torchao/ops.py index 9d099ffea2..5ef92b4807 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,5 +1,11 @@ import torch from torch import Tensor +from torch.utils._triton import has_triton + +if has_triton(): + import triton + from triton import language as tl + def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: """ @@ -189,7 +195,57 @@ def _(fp6_tensor, fp16_scale): return fp16_scale.new_empty((OC, _IC * 16 // 3)) +if has_triton(): + @triton.jit + def _to_fp6_triton(x: tl.tensor): + x = x.to(tl.float32) + x = x * 2.0 ** (-124) + bits = x.to(tl.int32, bitcast=True) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + remainder = bits & 0x1F_FFFF + do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & (result & 1)) + result = tl.where(do_round_up, result + 1, result) + return result.to(tl.uint8) + + @triton.jit + def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n + + # strided memory read. there will be uncoalesced memory access + val0 = _to_fp6_triton(tl.load(in_ptr + offsets * 4, mask)) + val1 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 1, mask)) + val2 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 2, mask)) + val3 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 3, mask)) + + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + # strided memory write. 
there will be uncoalesced memory access + tl.store(out_ptr + offsets * 3, bits0, mask) + tl.store(out_ptr + offsets * 3 + 1, bits1, mask) + tl.store(out_ptr + offsets * 3 + 2, bits2, mask) + +else: + _to_fp6_triton_kernel = None + + def to_fp6_pt(tensor: torch.Tensor, unpacked: bool = False) -> Tensor: + if tensor.device.type == "cuda" and _to_fp6_triton_kernel is not None: + out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,) + output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8) + + n = tensor.numel() // 4 + grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) + _to_fp6_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) + + return output + tensor = tensor.float() tensor = tensor * 2.0 ** (-124) bits = tensor.view(torch.int32) From 42bf7716d3ed861b911dee0c84ee24992f23b5b5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 05:57:15 +0000 Subject: [PATCH 53/80] remove CUDA kernel --- setup.py | 2 - torchao/csrc/cuda/fp6_llm/fp6.cu | 602 ------------------------------- torchao/csrc/fp6_llm/fp6_llm.cpp | 5 - torchao/ops.py | 75 ---- 4 files changed, 684 deletions(-) delete mode 100644 torchao/csrc/cuda/fp6_llm/fp6.cu diff --git a/setup.py b/setup.py index 0cac46ab96..5d1f32da2b 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,6 @@ def get_extensions(): extension = CUDAExtension if use_cuda else CppExtension extra_link_args = [] - extra_link_args.append("-fopenmp") extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", @@ -55,7 +54,6 @@ def get_extensions(): ], "nvcc": [ "-O3" if not debug_mode else "-O0", - "-Xcompiler", "-fopenmp", ] } if debug_mode: diff --git a/torchao/csrc/cuda/fp6_llm/fp6.cu b/torchao/csrc/cuda/fp6_llm/fp6.cu deleted file mode 100644 index e90e4f2820..0000000000 --- a/torchao/csrc/cuda/fp6_llm/fp6.cu +++ /dev/null @@ -1,602 +0,0 @@ -#include -#include -#include - -#include -#include - -#include -#include -#include - - -class fp6_nan_inf : public std::invalid_argument { -public: - fp6_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6.") { } -}; - -class fp6_overflow : public std::invalid_argument { -public: - fp6_overflow() : std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf. Make sure input < 30.0") { } -}; - -// need to do this trick so that static_assert(false) only evaluates at template instantiation. -template constexpr std::false_type always_false{}; - -// This implementation doesn't have a lot of bit manipulation, so it's less error-prone. -// On CPU, for FP32->FP6, bit manipulation (to_fp6_bits()) is 20% faster than this. -// On CUDA, dtype conversion kernels are memory-bound. Thus, using to_fp6_value() or -// to_fp6_bits() does not matter much. However, to_fp6_bits() has a lot of branching -// based on input value, thus it will cause warp divergence. -template -__device__ __host__ static uint8_t to_fp6_value(T a) { - float fp32_value; - - // need to use if constexpr so that the branches are pruned at compile-time. - // without it, expression in each branch must be valid regardless of template type T. 
- if constexpr (std::is_same_v) - fp32_value = a; - else if constexpr (std::is_same_v) - fp32_value = __half2float(a); - else if constexpr (std::is_same_v) - fp32_value = __bfloat162float(a); - else if constexpr (std::is_same_v || std::is_same_v) - fp32_value = static_cast(a); - else - static_assert(always_false, "Only float, __half, __nv_bfloat16, c10::Half, and c10::BFloat16 are suppored"); - -#ifndef __CUDA_ARCH__ - if (std::isnan(fp32_value) | std::isinf(fp32_value)) throw fp6_nan_inf(); - if (std::abs(fp32_value) >= 30.0f) throw fp6_overflow(); -#endif - - fp32_value *= 0x1p-124; // 2^(127-3) - uint32_t bits; - std::memcpy(&bits, &fp32_value, sizeof(fp32_value)); - - uint8_t sign = bits >> 31u << 5u; - uint8_t exp_and_man = (bits >> 21u) & 0x1Fu; - uint8_t result = sign | exp_and_man; - - // round to nearest even - uint32_t remainder = bits << 11u; - if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 1u))) { - result += 1; - } - - return result; -} - -// we need to do this because C++17 does not allow using struct as template non-type parameter -// use the upper 16 bits for num exponent, lower 16 bits for num mantissa -static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) { return (n_exp << 16u) | n_man; } -static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); -static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); -static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); - -// NOTE: only works for len < 32 -__device__ __host__ static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } - -// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" -template -__device__ __host__ static uint8_t to_fp6_bits(T bits) { - constexpr uint32_t N_EXP = FP_SPEC >> 16u; - constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); - constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; - - // sanity checks. will be removed in template instantiation. - // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. 
- static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); - static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); - - T remainder = 0u; - T sign = bits >> N_EXP_MAN << 5u; - bits &= ones_mask(N_EXP_MAN); // clear sign bit - T result; - - constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; - - // only checks for invalid values on CPU, since we can't throw exception in CUDA -#ifndef __CUDA_ARCH__ - // all exponent bits are 1s - if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf(); - - // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw fp6_overflow(); -#endif - - // FP6 normal number (E>=001) - if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { - remainder = bits << (1u + N_EXP + 2u); - bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent - result = sign | (bits >> (N_MAN - 2u)); - } - // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) - else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) { - T exp = bits >> N_MAN; - T man = bits & ones_mask(N_MAN); - - // to make subnormal FP6 from normal FP16 - // step 1: add implicit 1 to mantissa - man |= (1u << N_MAN); - - // step 2: shift mantissa right so that exponent value is equal to - // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) - T shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + N_EXP + 2u - shift); - result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 - } - // FP6 underflow. E=000, M=00 - else { - result = sign; - } - - // round to nearest even - constexpr T HALF_REMAINDER = 1u << N_EXP_MAN; - if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) { - result += 1; - } - return result; -} - -// assume the lower 6 bits contain the data. -// NOTE: probably not efficient for FP6->FP16 and FP6->BF16 on CPU since FP32->FP16/BF16 is slow. -template -__device__ __host__ static T from_fp6(uint8_t a) { - // we shift the bits so that sign, exponent, and mantissa bits are in their correct positions in FP32. - // this also handles subnormal numbers correctly. - // FP6: SE EEMM - // FP32: S000 00EE EMM0 0000 0000 0000 0000 0000 - uint32_t bits = a; // bit extension - uint32_t sign = bits >> 5u << 31u; - uint32_t exp_and_man = (bits & 0x1Fu) << 21u; - uint32_t result_bits = sign | exp_and_man; - - // the result will be off by the difference in exponent bias (3 in FP6 and 127 in FP32) - // we can correct this by direct FP32 multiplication, which also handles subnormal numbers. - float result; - std::memcpy(&result, &result_bits, sizeof(result)); - result *= 0x1p124; // 2^(127-3) - return static_cast(result); -} - -namespace torchao { - -template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - // exception within OpenMP parallel region must be caught. - // set a flag when exception occurs, then re-raise it. 
- bool found_nan_inf = false; - bool found_overflow = false; - -#pragma omp parallel for - for (int i = 0; i < n; i++) { - try { fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); } - catch (fp6_nan_inf) { found_nan_inf = true; } - catch (fp6_overflow) { found_overflow = true; } - } - - if (found_nan_inf) throw fp6_nan_inf(); - if (found_overflow) throw fp6_overflow(); -} - -// this is useful for debugging -at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cpu()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -template -__global__ void to_fp6_unpacked_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - - // NOTE: we are writing 32 uint8 (32 bytes) to global memory. vector load can be used - // to improve memory throughput. using uchar4, we can issue 128-byte global memory write. - if (idx < n) - fp6_ptr[idx] = to_fp6_value(fp_ptr[idx]); -} - -// this is useful for debugging -at::Tensor to_fp6_unpacked_cuda(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cuda()); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - constexpr int block_size = 256; - const int grid_size = (n + block_size - 1) / block_size; - - if (dtype == torch::kFloat32) { - const float *fp32_ptr = fp_tensor.data_ptr(); - to_fp6_unpacked_kernel<<>>(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const at::Half *fp16_ptr = fp_tensor.data_ptr(); - to_fp6_unpacked_kernel<<>>(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const at::BFloat16 *bf16_ptr = fp_tensor.data_ptr(); - to_fp6_unpacked_kernel<<>>(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -template void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { - // exception within OpenMP parallel region must be caught. - // set a flag when exception occurs, then re-raise it. 
- bool found_nan_inf = false; - bool found_overflow = false; - -#pragma omp parallel for - for (int i = 0; i < n / 4; i++) { - try { - uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); - uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); - uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); - uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); - - fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } - catch (fp6_nan_inf) { found_nan_inf = true; } - catch (fp6_overflow) { found_overflow = true; } - } - - if (found_nan_inf) throw fp6_nan_inf(); - if (found_overflow) throw fp6_overflow(); -} - -at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cpu()); - TORCH_CHECK(fp_tensor.ndimension() == 2); - - int M = fp_tensor.size(0); - int N = fp_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - if (dtype == torch::kFloat32) { - const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -// define our own vector types since NVIDIA doesn't provide them. -typedef struct __align__(8) { __half x, y, z, w; } fp16_vec4; -typedef struct __align__(8) { __nv_bfloat16 x, y, z, w; } bf16_vec4; - -template -__global__ void to_fp6_packed_kernel(const T *fp_ptr, uint8_t *fp6_ptr, int n) { - const int tid = threadIdx.x; - const int input_offset = (blockIdx.x * blockDim.x) * 4; - const int output_offset = (blockIdx.x * blockDim.x) * 3; - - fp_ptr += input_offset; - fp6_ptr += output_offset; - - __shared__ uint8_t shmem[BLOCK_SIZE * 3]; - - if (input_offset + tid * 4 < n) { - uint8_t val0, val1, val2, val3; - - // vector load for coalesced memory read - if constexpr (std::is_same_v) { - float4 values = reinterpret_cast(fp_ptr)[tid]; - val0 = to_fp6_value(values.x); - val1 = to_fp6_value(values.y); - val2 = to_fp6_value(values.z); - val3 = to_fp6_value(values.w); - } else if constexpr (std::is_same_v || std::is_same_v) { - fp16_vec4 values = reinterpret_cast(fp_ptr)[tid]; - val0 = to_fp6_value(values.x); - val1 = to_fp6_value(values.y); - val2 = to_fp6_value(values.z); - val3 = to_fp6_value(values.w); - } else if constexpr (std::is_same_v || std::is_same_v) { - bf16_vec4 values = reinterpret_cast(fp_ptr)[tid]; - val0 = to_fp6_value(values.x); - val1 = to_fp6_value(values.y); - val2 = to_fp6_value(values.z); - val3 = to_fp6_value(values.w); - } else { - // fallback. no coalesced memory access. (assert false instead?) 
- val0 = to_fp6_value(fp_ptr[tid * 4]); - val1 = to_fp6_value(fp_ptr[tid * 4 + 1]); - val2 = to_fp6_value(fp_ptr[tid * 4 + 2]); - val3 = to_fp6_value(fp_ptr[tid * 4 + 3]); - } - - shmem[tid * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 - shmem[tid * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 - shmem[tid * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 - } - __syncthreads(); - - // coalesced memory write - // TODO: write in larger word size - for (int i = 0; i < 3; i++) { - if (output_offset + BLOCK_SIZE * i + tid < n / 4 * 3) { - fp6_ptr[BLOCK_SIZE * i + tid] = shmem[BLOCK_SIZE * i + tid]; - } - } -} - -at::Tensor to_fp6_packed_cuda(at::Tensor fp_tensor) { - TORCH_CHECK(fp_tensor.is_contiguous()); - TORCH_CHECK(fp_tensor.is_cuda()); - TORCH_CHECK(fp_tensor.ndimension() == 2); - - int M = fp_tensor.size(0); - int N = fp_tensor.size(1); - TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); - at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); - uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - - int n = fp_tensor.numel(); - auto dtype = fp_tensor.dtype(); - - // times 4 since each thread will handle 4 values - constexpr int block_size = 256; - const int grid_size = (n + (block_size * 4) - 1) / (block_size * 4); - - if (dtype == torch::kFloat32) { - const float *fp32_ptr = fp_tensor.data_ptr(); - to_fp6_packed_kernel<<>>(fp32_ptr, fp6_ptr, n); - - } else if (dtype == torch::kFloat16) { - const at::Half *fp16_ptr = fp_tensor.data_ptr(); - to_fp6_packed_kernel<<>>(fp16_ptr, fp6_ptr, n); - - } else if (dtype == torch::kBFloat16) { - const at::BFloat16 *bf16_ptr = fp_tensor.data_ptr(); - to_fp6_packed_kernel<<>>(bf16_ptr, fp6_ptr, n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); - } - - return fp6_tensor; -} - -template -void from_fp6_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { -#pragma omp parallel for - for (int i = 0; i < n; i++) - fp_ptr[i] = from_fp6(fp6_ptr[i]); -} - -at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu()); - - at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (dtype == torch::kFloat32) { - from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kFloat16) { - from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kBFloat16) { - from_fp6_unpacked_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); - } - - return fp_tensor; -} - -template -__global__ void from_fp6_unpacked_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) { - // TODO: use vector load for reading from global memory - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n) - fp_ptr[idx] = from_fp6(fp6_ptr[idx]); -} - -at::Tensor from_fp6_unpacked_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cuda()); - - at::TensorOptions options = 
at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - constexpr int block_size = 256; - const int grid_size = (n + block_size - 1) / block_size; - - if (dtype == torch::kFloat32) { - from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kFloat16) { - from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kBFloat16) { - from_fp6_unpacked_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); - } - - return fp_tensor; -} - -template -void from_fp6_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { -#pragma omp parallel for - for (int i = 0; i < n / 3; i++) { - uint8_t bits0 = fp6_ptr[i * 3]; // 0000 0011 - uint8_t bits1 = fp6_ptr[i * 3 + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[i * 3 + 2]; // 2233 3333 - - fp_ptr[i * 4] = from_fp6(bits0 >> 2); - fp_ptr[i * 4 + 1] = from_fp6(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp_ptr[i * 4 + 2] = from_fp6(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp_ptr[i * 4 + 3] = from_fp6(bits2 & 0x3Fu); - } -} - -at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cpu()); - TORCH_CHECK(fp6_tensor.ndimension() == 2); - - int M = fp6_tensor.size(0); - int N = fp6_tensor.size(1); - TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); - - const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - if (dtype == torch::kFloat32) { - from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kFloat16) { - from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kBFloat16) { - from_fp6_packed_cpu_impl(fp6_ptr, fp_tensor.data_ptr(), n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); - } - - return fp_tensor; -} - -template -__global__ void from_fp6_packed_kernel(const uint8_t *fp6_ptr, T *fp_ptr, int n) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < n / 3) { - // TODO: use vector load for reading from global memory - uint8_t bits0 = fp6_ptr[idx * 3]; // 0000 0011 - uint8_t bits1 = fp6_ptr[idx * 3 + 1]; // 1111 2222 - uint8_t bits2 = fp6_ptr[idx * 3 + 2]; // 2233 3333 - - fp_ptr[idx * 4] = from_fp6(bits0 >> 2); - fp_ptr[idx * 4 + 1] = from_fp6(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp_ptr[idx * 4 + 2] = from_fp6(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp_ptr[idx * 4 + 3] = from_fp6(bits2 & 0x3Fu); - } -} - -at::Tensor from_fp6_packed_cuda(at::Tensor fp6_tensor, c10::ScalarType dtype) { - TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); - TORCH_CHECK(fp6_tensor.is_contiguous()); - TORCH_CHECK(fp6_tensor.is_cuda()); - TORCH_CHECK(fp6_tensor.ndimension() == 2); - - int M = fp6_tensor.size(0); - int N = fp6_tensor.size(1); - TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); - - at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); - at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); - - 
const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); - int n = fp6_tensor.numel(); - - // times 3 because each thread read 3 bytes (which represent 4 FP6 values) - constexpr int block_size = 256; - const int grid_size = (n + block_size * 3 - 1) / (block_size * 3); - - if (dtype == torch::kFloat32) { - from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kFloat16) { - from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else if (dtype == torch::kBFloat16) { - from_fp6_packed_kernel<<>>(fp6_ptr, fp_tensor.data_ptr(), n); - - } else { - throw std::invalid_argument("Only FP32, FP16, and BF16 outputs are accepted."); - } - - return fp_tensor; -} - -TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cpu); - m.impl("torchao::to_fp6_packed", &to_fp6_packed_cpu); - m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cpu); - m.impl("torchao::from_fp6_packed", &from_fp6_packed_cpu); -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::to_fp6_unpacked", &to_fp6_unpacked_cuda); - m.impl("torchao::to_fp6_packed", &to_fp6_packed_cuda); - m.impl("torchao::from_fp6_unpacked", &from_fp6_unpacked_cuda); - m.impl("torchao::from_fp6_packed", &from_fp6_packed_cuda); -} - -} diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 32923eb185..a35caf0893 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -8,9 +8,4 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); - - m.def("to_fp6_unpacked(Tensor fp16_tensor) -> Tensor"); - m.def("to_fp6_packed(Tensor fp16_tensor) -> Tensor"); - m.def("from_fp6_unpacked(Tensor fp6_tensor, ScalarType dtype) -> Tensor"); - m.def("from_fp6_packed(Tensor fp6_tensor, ScalarType dtype) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 5ef92b4807..f4d88e097e 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -48,81 +48,6 @@ def _(fp6_weight): return torch.empty_like(fp6_weight) -def to_fp6_unpacked(fp_tensor: Tensor) -> Tensor: - """ - Convert FP32/FP16/BF16 tensor to FP6. Each FP6 value is stored in the lower 6 bits of an uint8, - thus 2 bits are wasted. This is useful for debugging, since you can access the bits of FP6 - directly via tensor indexing. - """ - return torch.ops.torchao.to_fp6_unpacked.default(fp_tensor) - - -@torch.library.impl_abstract("torchao::to_fp6_unpacked") -def _(fp_tensor): - torch._check( - fp_tensor.dtype in (torch.float32, torch.float16, torch.bfloat16), - lambda: f"inputs must be FP32, FP16, or BF16, got {fp_tensor.dtype}", - ) - return torch.empty_like(fp_tensor, dtype=torch.uint8) - - -def to_fp6_packed(fp_tensor: Tensor) -> Tensor: - """ - Convert FP32/FP16/BF16 tensor to FP6. Every 4 FP32/FP16/BF16 values are packed into 3 uint8 - (4 x 6 bits = 3 x 8 bits). The last dimension must be a multiple of 4. 
- """ - *leading_dims, last_dim = fp_tensor.shape - return torch.ops.torchao.to_fp6_packed.default(fp_tensor.view(-1, last_dim)).view(*leading_dims, -1) - - -@torch.library.impl_abstract("torchao::to_fp6_packed") -def _(fp_tensor): - torch._check( - fp_tensor.dtype in (torch.float32, torch.float16, torch.bfloat16), - lambda: f"inputs must be FP32, FP16, or BF16, got {fp_tensor.dtype}", - ) - *leading_dims, last_dim = fp_tensor.shape - torch._check(last_dim % 4 == 0, lambda: f"last dimension must be a multiple of 4, got {last_dim}") - return torch.empty(*leading_dims, last_dim * 3 // 4, device=fp_tensor.device, dtype=torch.uint8) - - -def from_fp6_unpacked(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: - """ - Inverse of to_fp6_unpacked(). - """ - return torch.ops.torchao.from_fp6_unpacked.default(fp6_tensor, dtype) - - -@torch.library.impl_abstract("torchao::from_fp6_unpacked") -def _(fp6_tensor, dtype): - torch._check(fp6_tensor.dtype == torch.uint8, lambda: f"inputs must be uint8, got {fp6_tensor.dtype}") - torch._check( - dtype in (torch.float32, torch.float16, torch.bfloat16), - lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", - ) - return torch.empty_like(fp6_tensor, device=fp6_tensor.device, dtype=dtype) - - -def from_fp6_packed(fp6_tensor: Tensor, dtype: torch.dtype) -> Tensor: - """ - Inverse of to_fp6_packed(). The last dimension must be a multiple of 3. - """ - *leading_dims, last_dim = fp6_tensor.shape - return torch.ops.torchao.from_fp6_packed.default(fp6_tensor.view(-1, last_dim), dtype).view(*leading_dims, -1) - - -@torch.library.impl_abstract("torchao::from_fp6_packed") -def _(fp6_tensor, dtype): - torch._check(fp6_tensor.dtype == torch.uint8, lambda: f"inputs must be uint8, got {fp6_tensor.dtype}") - torch._check( - dtype in (torch.float32, torch.float16, torch.bfloat16), - lambda: f"outputs must be FP32, FP16, or BF16, got {dtype}", - ) - *leading_dims, last_dim = fp6_tensor.shape - torch._check(last_dim % 3 == 0, lambda: f"last dimension must be a multiple of 3, got {last_dim}") - return torch.empty(*leading_dims, last_dim * 4 // 3, device=fp6_tensor.device, dtype=dtype) - - def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ Pack FP16 tensor to FP6 tensor. qtorch is required to use this function. 
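As a reference point between these two patches, a minimal pure-Python sketch of the FP6 (E3M2) encoding and of the 4 x 6-bit -> 3 x 8-bit packing that the C++/CUDA ops removed above and the Triton/PyTorch port added below both implement. The function names and standalone structure here are illustrative only (not part of the patch series); the bit layout ("0000 0011", "1111 2222", "2233 3333") and the 2**(3 - 127) exponent-bias trick are taken from the diffs themselves.

import struct

def fp32_to_fp6_bits(x: float) -> int:
    # scale by 2**(3 - 127) so the FP6 exponent/mantissa land in fixed bit positions
    # of the FP32 encoding, then extract them and round to nearest even on the
    # truncated mantissa bits
    bits = struct.unpack("<I", struct.pack("<f", x * 2.0 ** (3 - 127)))[0]
    sign = ((bits >> 31) & 0x1) << 5
    result = sign | ((bits >> 21) & 0x1F)
    remainder = bits & 0x1F_FFFF
    if remainder > 0x10_0000 or (remainder == 0x10_0000 and (result & 1)):
        result += 1
    return result

def pack4(v0: int, v1: int, v2: int, v3: int):
    # pack four 6-bit values into three bytes
    b0 = ((v0 << 2) | (v1 >> 4)) & 0xFF  # 0000 0011
    b1 = ((v1 << 4) | (v2 >> 2)) & 0xFF  # 1111 2222
    b2 = ((v2 << 6) | v3) & 0xFF         # 2233 3333
    return b0, b1, b2

def unpack3(b0: int, b1: int, b2: int):
    # inverse of pack4
    v0 = b0 >> 2
    v1 = ((b0 & 0x3) << 4) | (b1 >> 4)
    v2 = ((b1 & 0xF) << 2) | (b2 >> 6)
    v3 = b2 & 0x3F
    return v0, v1, v2, v3

assert fp32_to_fp6_bits(1.0) == 0b001100 and fp32_to_fp6_bits(28.0) == 0b011111
assert unpack3(*pack4(0b001100, 0b011111, 0b000001, 0b100000)) == (12, 31, 1, 32)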
From f178b01e082460fd1e15aaddcb44f711b8fd2846 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 08:48:38 +0000 Subject: [PATCH 54/80] move to_fp6 to dtypes/ --- torchao/dtypes/fp6.py | 93 +++++++++++++++++++++++++++++++++++++++++++ torchao/ops.py | 84 -------------------------------------- 2 files changed, 93 insertions(+), 84 deletions(-) create mode 100644 torchao/dtypes/fp6.py diff --git a/torchao/dtypes/fp6.py b/torchao/dtypes/fp6.py new file mode 100644 index 0000000000..39bda9d170 --- /dev/null +++ b/torchao/dtypes/fp6.py @@ -0,0 +1,93 @@ +import torch +from torch import Tensor +from torch.utils._triton import has_triton + + +if has_triton(): + import triton + from triton import language as tl + + @triton.jit + def _triton_fp32_to_fp6(x: tl.tensor): + x = x.to(tl.float32) + x = x * 2.0 ** (-127 + 3) + bits = x.to(tl.int32, bitcast=True) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + remainder = bits & 0x1F_FFFF + do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) + result = tl.where(do_round_up, result + 1, result) + return result.to(tl.uint8) + + @triton.jit + def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n + + # strided memory read. there will be uncoalesced memory access + val0 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4, mask)) + val1 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 1, mask)) + val2 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 2, mask)) + val3 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 3, mask)) + + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + # strided memory write. 
there will be uncoalesced memory access + tl.store(out_ptr + offsets * 3, bits0, mask) + tl.store(out_ptr + offsets * 3 + 1, bits1, mask) + tl.store(out_ptr + offsets * 3 + 2, bits2, mask) + + def _to_fp6_triton(tensor: Tensor) -> Tensor: + out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,) + output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8) + + n = tensor.numel() + grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"] * 4),) + _to_fp6_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) + + return output + +else: + _to_fp6_triton = None + + +def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: + tensor = tensor.float() + tensor = tensor * 2.0 ** (-127 + 3) + bits = tensor.view(torch.int32) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + remainder = bits & 0x1F_FFFF + do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) + result = torch.where(do_round_up, result + 1, result) + result = result.to(torch.uint8) + + if no_bit_packing: + return result + + val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) + + +def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: + if not no_bit_packing: + assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" + + # torch.compile() cannot generate fused bit-packing triton kernel, + # thus we write custom triton kernel for this specific case. + if tensor.is_cuda and not no_bit_packing and _to_fp6_triton is not None: + return _to_fp6_triton(tensor) + + else: + return _to_fp6_pt(tensor, no_bit_packing=no_bit_packing) diff --git a/torchao/ops.py b/torchao/ops.py index f4d88e097e..2fed9a3fda 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,10 +1,5 @@ import torch from torch import Tensor -from torch.utils._triton import has_triton - -if has_triton(): - import triton - from triton import language as tl def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: @@ -118,82 +113,3 @@ def _(fp6_tensor, fp16_scale): torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") return fp16_scale.new_empty((OC, _IC * 16 // 3)) - - -if has_triton(): - @triton.jit - def _to_fp6_triton(x: tl.tensor): - x = x.to(tl.float32) - x = x * 2.0 ** (-124) - bits = x.to(tl.int32, bitcast=True) - - sign = ((bits >> 31) & 0x1) << 5 - exp_and_man = (bits >> 21) & 0x1F - result = sign | exp_and_man - - remainder = bits & 0x1F_FFFF - do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & (result & 1)) - result = tl.where(do_round_up, result + 1, result) - return result.to(tl.uint8) - - @triton.jit - def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n - - # strided memory read. there will be uncoalesced memory access - val0 = _to_fp6_triton(tl.load(in_ptr + offsets * 4, mask)) - val1 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 1, mask)) - val2 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 2, mask)) - val3 = _to_fp6_triton(tl.load(in_ptr + offsets * 4 + 3, mask)) - - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - - # strided memory write. 
there will be uncoalesced memory access - tl.store(out_ptr + offsets * 3, bits0, mask) - tl.store(out_ptr + offsets * 3 + 1, bits1, mask) - tl.store(out_ptr + offsets * 3 + 2, bits2, mask) - -else: - _to_fp6_triton_kernel = None - - -def to_fp6_pt(tensor: torch.Tensor, unpacked: bool = False) -> Tensor: - if tensor.device.type == "cuda" and _to_fp6_triton_kernel is not None: - out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,) - output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8) - - n = tensor.numel() // 4 - grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) - _to_fp6_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) - - return output - - tensor = tensor.float() - tensor = tensor * 2.0 ** (-124) - bits = tensor.view(torch.int32) - - sign = ((bits >> 31) & 0x1) << 5 - exp_and_man = (bits >> 21) & 0x1F - result = sign | exp_and_man - - remainder = bits & 0x1F_FFFF - do_round_up = torch.logical_or( - remainder > 0x10_0000, - torch.logical_and(remainder == 0x10_0000, result & 1) - ) - result = torch.where(do_round_up, result + 1, result) - result = result.to(torch.uint8) - - if unpacked: - return result - - # pre-allocate output tensor is faster than using torch.stack() - outputs = torch.empty(tensor.shape[:-1] + (tensor.shape[-1] // 4, 3), device=tensor.device, dtype=torch.uint8) - val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) - outputs[..., 0] = (val0 << 2) | (val1 >> 4) # 0000 0011 - outputs[..., 1] = (val1 << 4) | (val2 >> 2) # 1111 2222 - outputs[..., 2] = (val2 << 6) | (val3); # 2233 3333 - return outputs.flatten(-2) From ee2310ce32c2c847f195b6257ee92724857a82d2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 08:49:00 +0000 Subject: [PATCH 55/80] add to_fp6 import --- torchao/dtypes/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index b14aff9904..cd9b1d62eb 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,8 +1,10 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor +from .fp6 import to_fp6 __all__ = [ "NF4Tensor", "to_nf4", "UInt4Tensor" + "to_fp6", ] From 404f700ab1c5fc67858751e496c51908e69b7121 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 09:13:41 +0000 Subject: [PATCH 56/80] move tests --- test/dtypes/test_fp6.py | 85 ++++++++++++++++++++++++++++++++++++ test/test_ops.py | 97 ----------------------------------------- 2 files changed, 85 insertions(+), 97 deletions(-) create mode 100644 test/dtypes/test_fp6.py diff --git a/test/dtypes/test_fp6.py b/test/dtypes/test_fp6.py new file mode 100644 index 0000000000..8c80755f53 --- /dev/null +++ b/test/dtypes/test_fp6.py @@ -0,0 +1,85 @@ +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torchao.dtypes.fp6 import to_fp6 + + +_DTYPES = [torch.float32, torch.float16, torch.bfloat16] +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestFp6(TestCase): + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + @parametrize( + "input_output", + [ + (0.0, 0b000000), # exact values + (1.0, 0b001100), # normal numbers + (1.25, 0b001101), + (28.0, 0b011111), # max + (0.1875, 0b000011), # subnormal number + (0.0625, 0b000001), # min + (29.0, 0b011111), # normal round down + (26.0, 0b011110), # normal round to nearest even + (0.1251, 0b000010), # subnormal round down + (0.0314, 
0b000001), # subnormal round up + (0.03, 0b000000), # underflow + ], + ) + def test_no_bit_packing_correctness(self, device, dtype, input_output): + input, output = input_output + input = torch.tensor(input, device=device, dtype=dtype) + assert to_fp6(input, no_bit_packing=True).item() == output + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + def test_bit_packing_correctness(self, device, dtype): + x = torch.randn(128, 128, device=device, dtype=dtype) + results_unpacked = to_fp6(x, no_bit_packing=True) + results_packed = to_fp6(x) + + val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) + assert (results_packed == expected_packed).all() + + @parametrize("device", _DEVICES) + @parametrize("shape", [(), (0,), (10,), (20, 20)]) + def test_no_bit_packing_shape(self, device, shape): + x = torch.randn(shape, device=device) + result = to_fp6(x, no_bit_packing=True) + assert result.shape == shape + + @parametrize("device", _DEVICES) + @parametrize("shape", [(4,), (20, 20)]) + def test_bit_packing_shape(self, device, shape): + x = torch.randn(shape, device=device) + result = to_fp6(x) + assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + @parametrize("no_bit_packing", [False, True]) + def test_compile(self, device, dtype, no_bit_packing): + x = torch.randn(20, 20, device=device, dtype=dtype) + to_fp6_compiled = torch.compile(to_fp6) # will hit cache_size_limit if fullgraph=True + + actual = to_fp6_compiled(x, no_bit_packing=no_bit_packing) + expected = to_fp6(x, no_bit_packing=no_bit_packing) + torch.testing.assert_close(actual, expected) + + +instantiate_parametrized_tests(TestFp6) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 966e400e4d..2dfbb72d6b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -138,103 +138,6 @@ def _skip_cpu(self): if not torch.cuda.is_available(): self.skipTest("CUDA not available. 
We don't compile for CPU-only build") - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) - def test_to_fp6_unpacked(self, device, dtype): - self._skip_cpu() - inputs = torch.randn(128, 128, device=device, dtype=dtype) - - # smoke test - torchao.ops.to_fp6_unpacked(inputs) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.to_fp6_unpacked, (inputs,), test_utils=test_utils) - - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) - def test_to_fp6_packed(self, device, dtype): - self._skip_cpu() - inputs = torch.randn(128, 128, device=device, dtype=dtype) - - # smoke test - torchao.ops.to_fp6_packed(inputs) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.to_fp6_packed, (inputs,), test_utils=test_utils) - - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) - def test_from_fp6_unpacked(self, device, dtype): - self._skip_cpu() - inputs = torch.randint(256, size=(128, 128 // 4 * 3), device=device, dtype=torch.uint8) - - # smoke test - torchao.ops.from_fp6_unpacked(inputs, dtype) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.from_fp6_unpacked, (inputs, dtype), test_utils=test_utils) - - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) - def test_from_fp6_packed(self, device, dtype): - self._skip_cpu() - inputs = torch.randint(256, size=(128, 128 // 4 * 3), device=device, dtype=torch.uint8) - - # smoke test - torchao.ops.from_fp6_packed(inputs, dtype) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.from_fp6_packed, (inputs, dtype), test_utils=test_utils) - - def test_to_fp6_unpacked_shape(self): - for shape in [(), (0,), (10,), (20, 20)]: - x = torch.randn(shape) - result = torchao.ops.to_fp6_unpacked(x) - assert result.shape == shape - - def test_to_fp6_packed_shape(self): - for shape in [(4,), (20, 20)]: - x = torch.randn(shape) - result = torchao.ops.to_fp6_packed(x) - assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) - - @parameterized.expand( - [ - (0.0, 0b000000), # exact values - (1.0, 0b001100), # normal numbers - (1.25, 0b001101), - (28.0, 0b011111), # max - (0.1875, 0b000011), # subnormal number - (0.0625, 0b000001), # min - (29.0, 0b011111), # normal round down - (26.0, 0b011110), # normal round to nearest even - (0.1251, 0b000010), # subnormal round down - (0.0314, 0b000001), # subnormal round up - (0.03, 0b000000), # underflow - ] - ) - def test_to_fp6_unpacked_correctness(self, input, output): - self._skip_cpu() - for device in ("cpu", "cuda"): - for dtype in (torch.float32, torch.float16, torch.bfloat16): - x = torch.tensor(input, device=device, dtype=dtype) - assert torchao.ops.to_fp6_unpacked(x).item() == output - assert torchao.ops.to_fp6_unpacked(-x).item() == (output | 0b100000) - - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, 
torch.bfloat16]]) - def test_to_fp6_packed_correctness(self, device, dtype): - x = torch.randn(128, 128, device=device, dtype=dtype) - results_unpacked = torchao.ops.to_fp6_unpacked(x) - results_packed = torchao.ops.to_fp6_packed(x) - - val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) - bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 - bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 - bits2 = (val2 << 6) | (val3); # 2233 3333 - - expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) - assert (results_packed == expected_packed).all() - @parameterized.expand([30.0, -100.0, float("inf"), float("nan")]) def test_to_fp6_exception(self, input): self._skip_cpu() From 48fe45f83bff192cf4083fb6abb78593a2f26ff4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 09:23:14 +0000 Subject: [PATCH 57/80] update benchmark script --- benchmarks/benchmark_fp6_conversion.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/benchmarks/benchmark_fp6_conversion.py b/benchmarks/benchmark_fp6_conversion.py index 4cd5cfe621..a8bee950e8 100644 --- a/benchmarks/benchmark_fp6_conversion.py +++ b/benchmarks/benchmark_fp6_conversion.py @@ -1,15 +1,14 @@ -from functools import partial - import torch import torchao import pandas as pd from torch.utils.benchmark import Timer -def benchmark(f, weight): +def benchmark(f, weight, num_threads = 1): measurement = Timer( stmt="f(weight)", globals={"f": f, "weight": weight}, + num_threads=num_threads, ).blocked_autorange() return measurement.median * 1000 @@ -25,12 +24,7 @@ def benchmark(f, weight): functions = [ ("original", torchao.ops.fp16_to_fp6_original), - ("C++/CUDA extension", torchao.ops.to_fp6_packed), - ("PyTorch + torch.compile (default)", torch.compile(torchao.ops.to_fp6_pt)), - ("PyTorch + torch.compile (max-autotune)", torch.compile(torchao.ops.to_fp6_pt, mode="max-autotune")), - - # ("C++/CUDA extension (no bit-packing)", torchao.ops.to_fp6_unpacked), - # ("PyTorch + torch.compile (no bit-packing)", partial(torch.compile(torchao.ops.to_fp6_pt), unpacked=True)), + ("ours", torch.compile(torchao.dtypes.to_fp6)), ] results = [] @@ -38,6 +32,9 @@ def benchmark(f, weight): results.append(["CPU", "FP32->FP6", name, benchmark(f, fp32_weight)]) results.append(["CPU", "FP16->FP6", name, benchmark(f, fp16_weight)]) + results.append(["CPU", "FP32->FP6", f"{name} (num_threads=4)", benchmark(f, fp32_weight, num_threads=4)]) + results.append(["CPU", "FP16->FP6", f"{name} (num_threads=4)", benchmark(f, fp16_weight, num_threads=4)]) + if name != "original": results.append(["CUDA", "FP32->FP6", name, benchmark(f, fp32_weight_cuda)]) results.append(["CUDA", "FP16->FP6", name, benchmark(f, fp16_weight_cuda)]) From 5126f8f4d53fbba4e53ded1523b59cb9c4cd20ad Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 20:03:14 +0800 Subject: [PATCH 58/80] add from_fp6 --- torchao/dtypes/__init__.py | 3 ++- torchao/dtypes/fp6.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index cd9b1d62eb..a5d444f4a1 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,10 +1,11 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor -from .fp6 import to_fp6 +from .fp6 import to_fp6, from_fp6 __all__ = [ "NF4Tensor", "to_nf4", "UInt4Tensor" "to_fp6", + "from_fp6", ] diff --git a/torchao/dtypes/fp6.py b/torchao/dtypes/fp6.py index 39bda9d170..9408d53a63 100644 --- 
a/torchao/dtypes/fp6.py +++ b/torchao/dtypes/fp6.py @@ -91,3 +91,28 @@ def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: else: return _to_fp6_pt(tensor, no_bit_packing=no_bit_packing) + + +def _pt_fp6_to_fp32(tensor: Tensor) -> Tensor: + bits = tensor.to(torch.int32) # bit extension + sign = bits >> 5 << 31 + exp_and_man = (bits & 0x1F) << 21 + results = sign | exp_and_man + + results = results.view(torch.float32) + return results * 2.0 ** (127 - 3) + + +def from_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: + assert tensor.dtype == torch.uint8 + if no_bit_packing: + return _pt_fp6_to_fp32(tensor) + + assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 4" + + bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) + val0 = _pt_fp6_to_fp32(bits0 >> 2) + val1 = _pt_fp6_to_fp32(((bits0 & 0x3) << 4) | (bits1 >> 4)) + val2 = _pt_fp6_to_fp32(((bits1 & 0xF) << 2) | (bits2 >> 6)) + val3 = _pt_fp6_to_fp32(bits2 & 0x3F) + return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) From da767fa3e5c5172740b2b5b1313a7d1ff4fd4f02 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 20:22:00 +0800 Subject: [PATCH 59/80] migrate test --- test/dtypes/test_fp6.py | 44 ++++++++++++++++++++++++++++----- test/test_ops.py | 54 ----------------------------------------- 2 files changed, 38 insertions(+), 60 deletions(-) diff --git a/test/dtypes/test_fp6.py b/test/dtypes/test_fp6.py index 8c80755f53..4c6d3f9371 100644 --- a/test/dtypes/test_fp6.py +++ b/test/dtypes/test_fp6.py @@ -5,7 +5,7 @@ parametrize, run_tests, ) -from torchao.dtypes.fp6 import to_fp6 +from torchao.dtypes.fp6 import to_fp6, from_fp6 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -32,14 +32,14 @@ class TestFp6(TestCase): (0.03, 0b000000), # underflow ], ) - def test_no_bit_packing_correctness(self, device, dtype, input_output): + def test_to_fp6_no_bit_packing_correctness(self, device, dtype, input_output): input, output = input_output input = torch.tensor(input, device=device, dtype=dtype) assert to_fp6(input, no_bit_packing=True).item() == output @parametrize("device", _DEVICES) @parametrize("dtype", _DTYPES) - def test_bit_packing_correctness(self, device, dtype): + def test_to_fp6_bit_packing_correctness(self, device, dtype): x = torch.randn(128, 128, device=device, dtype=dtype) results_unpacked = to_fp6(x, no_bit_packing=True) results_packed = to_fp6(x) @@ -54,14 +54,14 @@ def test_bit_packing_correctness(self, device, dtype): @parametrize("device", _DEVICES) @parametrize("shape", [(), (0,), (10,), (20, 20)]) - def test_no_bit_packing_shape(self, device, shape): + def test_to_fp6_no_bit_packing_shape(self, device, shape): x = torch.randn(shape, device=device) result = to_fp6(x, no_bit_packing=True) assert result.shape == shape @parametrize("device", _DEVICES) @parametrize("shape", [(4,), (20, 20)]) - def test_bit_packing_shape(self, device, shape): + def test_to_fp6_bit_packing_shape(self, device, shape): x = torch.randn(shape, device=device) result = to_fp6(x) assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) @@ -69,7 +69,7 @@ def test_bit_packing_shape(self, device, shape): @parametrize("device", _DEVICES) @parametrize("dtype", _DTYPES) @parametrize("no_bit_packing", [False, True]) - def test_compile(self, device, dtype, no_bit_packing): + def test_to_fp6_compile(self, device, dtype, no_bit_packing): x = torch.randn(20, 20, device=device, dtype=dtype) to_fp6_compiled = torch.compile(to_fp6) # will hit cache_size_limit if fullgraph=True @@ -77,6 
+77,38 @@ def test_compile(self, device, dtype, no_bit_packing): expected = to_fp6(x, no_bit_packing=no_bit_packing) torch.testing.assert_close(actual, expected) + @parametrize("device", _DEVICES) + @parametrize( + "input_output", + [ + (0b000000, 0.0), + (0b001100, 1.0), + (0b011111, 28.0), + (0b000001, 0.0625), + (0b001110, 1.5), + (0b000011, 0.1875), + ], + ) + def test_from_fp6_no_bit_packing_correctness(self, device, input_output): + input, output = input_output + input = torch.tensor(input, device=device, dtype=torch.uint8) + assert from_fp6(input, no_bit_packing=True).item() == output + + @parametrize("device", _DEVICES) + def test_from_fp6_bit_packing_correctness(self, device): + x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) + actual = from_fp6(x) + + bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) + x_unpacked0 = bits0 >> 2 + x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) + x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) + x_unpacked3 = bits2 & 0x3F + + x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) + expected = from_fp6(x_unpacked, no_bit_packing=True) + torch.testing.assert_close(actual, expected) + instantiate_parametrized_tests(TestFp6) diff --git a/test/test_ops.py b/test/test_ops.py index 2dfbb72d6b..fc69ff9ccb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -133,59 +133,5 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): assert relative_error.mean() < 1e-2 -class TestFp6(TestCase): - def _skip_cpu(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA not available. We don't compile for CPU-only build") - - @parameterized.expand([30.0, -100.0, float("inf"), float("nan")]) - def test_to_fp6_exception(self, input): - self._skip_cpu() - x = torch.tensor(input) - with self.assertRaises(Exception): - torchao.ops.to_fp6_unpacked(x) - with self.assertRaises(Exception): - torchao.ops.to_fp6_packed(x) - - @parameterized.expand( - [ - (0b000000, 0.0), - (0b001100, 1.0), - (0b011111, 28.0), - (0b000001, 0.0625), - (0b001110, 1.5), - (0b000011, 0.1875), - ] - ) - def test_from_fp6_unpacked_correctness(self, input, output): - self._skip_cpu() - for device in ("cpu", "cuda"): - for dtype in (torch.float32, torch.float16, torch.bfloat16): - x = torch.tensor(input, device=device, dtype=torch.uint8) - result = torchao.ops.from_fp6_unpacked(x, dtype) - assert result.dtype == dtype - assert result.item() == output - - x = torch.tensor(input | 0b100000, device=device, dtype=torch.uint8) - result = torchao.ops.from_fp6_unpacked(x, dtype) - assert result.dtype == dtype - assert result.item() == -output - - @parameterized.expand([(device, dtype) for device in ["cpu", "cuda"] for dtype in [torch.float32, torch.float16, torch.bfloat16]]) - def test_from_fp6_packed_correctness(self, device, dtype): - x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) - results = torchao.ops.from_fp6_packed(x, dtype=dtype) - - bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) - x_unpacked0 = bits0 >> 2 - x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) - x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) - x_unpacked3 = bits2 & 0x3F - - x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) - expected = torchao.ops.from_fp6_unpacked(x_unpacked, dtype) - assert (results == expected).all() - - if __name__ == "__main__": unittest.main() From 0b56ecfdd4d5d5a122f97da0cd29e478aa8e6ab0 Mon Sep 17 00:00:00 2001 From: 
Thien Tran Date: Tue, 21 May 2024 20:50:15 +0800 Subject: [PATCH 60/80] add docs --- torchao/dtypes/fp6.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/fp6.py b/torchao/dtypes/fp6.py index 9408d53a63..c88848d657 100644 --- a/torchao/dtypes/fp6.py +++ b/torchao/dtypes/fp6.py @@ -7,6 +7,7 @@ import triton from triton import language as tl + # see _to_fp6_pt() for explanation @triton.jit def _triton_fp32_to_fp6(x: tl.tensor): x = x.to(tl.float32) @@ -33,6 +34,7 @@ def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): val2 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 2, mask)) val3 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 3, mask)) + # bit packing bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 bits2 = (val2 << 6) | (val3); # 2233 3333 @@ -58,6 +60,8 @@ def _to_fp6_triton(tensor: Tensor) -> Tensor: def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: tensor = tensor.float() + + # correct exponent bias. this also handles subnormal numbers correctly tensor = tensor * 2.0 ** (-127 + 3) bits = tensor.view(torch.int32) @@ -65,7 +69,8 @@ def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: exp_and_man = (bits >> 21) & 0x1F result = sign | exp_and_man - remainder = bits & 0x1F_FFFF + # round to nearest even + remainder = bits & 0x1F_FFFF # truncated mantissa bits do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) result = torch.where(do_round_up, result + 1, result) result = result.to(torch.uint8) @@ -73,6 +78,7 @@ def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: if no_bit_packing: return result + # bit packing val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1) bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 @@ -81,6 +87,25 @@ def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: + """Convert input tensor to FP6. This particular FP6 format has 3 exponent bits and 2 mantissa + bits. By default, bit packing is performed: every 4 FP6 values are packed as 3 uint8 values + (4 x 6 bits = 3 x 8 bits). + + Args: + tensor: input tensor. The last dimension must be divisible by 4 (unless `no_bit_packing=False`) + no_bit_packing: whether to not perform bit packing. Setting this to `True` can be useful for + observing the bit patterns and debugging. + + Returns: + an FP6 tensor, stored as uint8 data. If `no_bit_packing=False`, the last dimension of output + tensor is 3/4 of that of input tensor. + + Note: + This FP6 format does not represent +/-inf and NaN. Thus, make sure that input tensor does + not have +/-inf or NaN values, and no values with magnitude >= 28 (largest number in FP6). + + Also see :func:`from_fp6` + """ if not no_bit_packing: assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" @@ -100,15 +125,22 @@ def _pt_fp6_to_fp32(tensor: Tensor) -> Tensor: results = sign | exp_and_man results = results.view(torch.float32) - return results * 2.0 ** (127 - 3) + return results * 2.0 ** (127 - 3) # exponent bias correction def from_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: + """Convert an FP6 tensor (created by :func:`to_fp6`) to FP32. + + Args: + tensor: FP6 tensor, stored as uint8 data. If `no_bit_packing=False`, the last dimension must + be divisible by 3. 
+ no_bit_packing: whether the input does not have bit packing. + """ assert tensor.dtype == torch.uint8 if no_bit_packing: return _pt_fp6_to_fp32(tensor) - assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 4" + assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) val0 = _pt_fp6_to_fp32(bits0 >> 2) From 110e888b991be8744d97c538177daefbcddb73b8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 20:52:25 +0800 Subject: [PATCH 61/80] add docs --- docs/source/api_ref_dtypes.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 4cb797beb4..aff808d7fb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -12,6 +12,8 @@ torchao.dtypes to_nf4 UInt4Tensor + to_fp6 + from_fp6 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring From 3e2643cd841c6aa5915f267f7db78ec70d603bfd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 21 May 2024 21:03:13 +0800 Subject: [PATCH 62/80] add torch.compile test --- test/dtypes/test_fp6.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/dtypes/test_fp6.py b/test/dtypes/test_fp6.py index 4c6d3f9371..5163ce3d43 100644 --- a/test/dtypes/test_fp6.py +++ b/test/dtypes/test_fp6.py @@ -109,6 +109,16 @@ def test_from_fp6_bit_packing_correctness(self, device): expected = from_fp6(x_unpacked, no_bit_packing=True) torch.testing.assert_close(actual, expected) + @parametrize("device", _DEVICES) + @parametrize("no_bit_packing", [False, True]) + def test_from_fp6_compile(self, device, no_bit_packing): + x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) + from_fp6_compiled = torch.compile(from_fp6) + + actual = from_fp6_compiled(x, no_bit_packing=no_bit_packing) + expected = from_fp6(x, no_bit_packing=no_bit_packing) + torch.testing.assert_close(actual, expected) + instantiate_parametrized_tests(TestFp6) From 750fbc6e0431904d2cd87d1ade879b9fddc91cd6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 07:42:08 +0800 Subject: [PATCH 63/80] polish docs --- torchao/dtypes/fp6.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/torchao/dtypes/fp6.py b/torchao/dtypes/fp6.py index c88848d657..9046886dcc 100644 --- a/torchao/dtypes/fp6.py +++ b/torchao/dtypes/fp6.py @@ -3,6 +3,11 @@ from torch.utils._triton import has_triton +# some useful constants +FP6_MAX = 28.0 +FP6_SMALLEST_SUBNORMAL = 0.0625 + + if has_triton(): import triton from triton import language as tl @@ -92,19 +97,20 @@ def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: (4 x 6 bits = 3 x 8 bits). Args: - tensor: input tensor. The last dimension must be divisible by 4 (unless `no_bit_packing=False`) - no_bit_packing: whether to not perform bit packing. Setting this to `True` can be useful for + tensor: Input tensor. The last dimension must be divisible by 4 (unless ``no_bit_packing=False``) + no_bit_packing: Whether to not perform bit packing. Setting this to ``True`` can be useful for observing the bit patterns and debugging. Returns: - an FP6 tensor, stored as uint8 data. If `no_bit_packing=False`, the last dimension of output - tensor is 3/4 of that of input tensor. + :class:`torch.Tensor`: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last + dimension of output tensor is 3/4 of that of input tensor. Note: This FP6 format does not represent +/-inf and NaN. 
Thus, make sure that input tensor does - not have +/-inf or NaN values, and no values with magnitude >= 28 (largest number in FP6). + not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. + All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - Also see :func:`from_fp6` + See also :func:`from_fp6` """ if not no_bit_packing: assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" @@ -132,9 +138,13 @@ def from_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """Convert an FP6 tensor (created by :func:`to_fp6`) to FP32. Args: - tensor: FP6 tensor, stored as uint8 data. If `no_bit_packing=False`, the last dimension must - be divisible by 3. + tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must be + divisible by 3. no_bit_packing: whether the input does not have bit packing. + + Returns: + :class:`torch.Tensor`: FP32 tensor. If ``no_bit_packing=False``, the last dimension of output tensor + is 4/3 of that of input tensor. """ assert tensor.dtype == torch.uint8 if no_bit_packing: From 6a3f0c01d6eff45da9d436ee036a07b9ecae338c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 07:48:30 +0800 Subject: [PATCH 64/80] remove original weight dequant --- test/test_ops.py | 13 ----- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 65 ----------------------- torchao/ops.py | 17 ------ 3 files changed, 95 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index fc69ff9ccb..673413e7a7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -99,19 +99,6 @@ def test_fp16act_fp6weight_linear(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_weight_dequant(self): - OC = 256 - IC = 256 - fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC) - - # smoke test - torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils) - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index 9aa78858fe..b519cbfb0d 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -13,7 +13,6 @@ // limitations under the License. // // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h -// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h #include #include @@ -120,41 +119,6 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, } } -void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { - assert(M%64==0); // Currently, M must be a multiple of 64. - assert(K%64==0); // Currently, K must be a multiple of 64. 
- size_t TotalSizeInByte = M*K*6/8; - // - half* OutPTR = A_16bit_h; - for(size_t i=0; i>2)&0x1f); - unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); - B2 = (B2&0x80) | ((B2>>2)&0x1f); - unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); - B3 = (B3&0x80) | ((B3>>2)&0x1f); - unsigned char B4 = A_6bit_h[i*3+2]<<2; - B4 = (B4&0x80) | ((B4>>2)&0x1f); - half FP1, FP2, FP3, FP4; - unsigned char *PTR1, *PTR2, *PTR3, *PTR4; - PTR1 = reinterpret_cast(&FP1); - PTR2 = reinterpret_cast(&FP2); - PTR3 = reinterpret_cast(&FP3); - PTR4 = reinterpret_cast(&FP4); - PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU - PTR2[0] = 0; PTR2[1] = B2; - PTR3[0] = 0; PTR3[1] = B3; - PTR4[0] = 0; PTR4[1] = B4; - OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); - // - OutPTR +=4; - } -} - - #include #include #include @@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor) return packed_fp6_tensor; } -/* - * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. - * A useful tool to construct input matrices for the FP16 GEMM baseline. - * [Input] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. - * [Output] - * fp16_tensor: half tensor of shape [OC, IC]. - */ -at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) -{ - int OC = fp6_tensor.size(0); - TORCH_CHECK(fp6_tensor.size(1) % 3 == 0); - int IC = fp6_tensor.size(1) / 3 * 16; - TORCH_CHECK(fp16_scale.size(0) == OC); - // - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); - // - auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device()); - at::Tensor fp16_tensor = at::empty({OC, IC}, options); - auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); - // - DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr); - // - return fp16_tensor; -} - TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); - m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); } } diff --git a/torchao/ops.py b/torchao/ops.py index 2fed9a3fda..b74cfa3731 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -96,20 +96,3 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) - - -def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: - return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) - - -@torch.library.impl_abstract("torchao::fp6_weight_dequant") -def _(fp6_tensor, fp16_scale): - torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D") - torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}") - torch._check(fp16_scale.dim() == 1, lambda: f"scale should be a 2d tensor, got {fp16_scale.dim()}D") - torch._check(fp16_scale.dtype is torch.float16, lambda: f"scale must be FP16, got {fp16_scale.dtype}") - - OC, _IC = 
fp6_tensor.shape - torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") - - return fp16_scale.new_empty((OC, _IC * 16 // 3)) From f32d09f82f2ce1cee544630ec4af3ae29c7d03fe Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 07:58:34 +0800 Subject: [PATCH 65/80] remove weight dequant --- torchao/csrc/fp6_llm/fp6_llm.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index a35caf0893..ccde481764 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -7,5 +7,4 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); - m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); } From 8b5b81ee0be9650159b3e125bfd12cab3ff044c3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 10:18:42 +0800 Subject: [PATCH 66/80] improve tests --- test/dtypes/test_fp6.py | 8 ++++---- test/test_ops.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/dtypes/test_fp6.py b/test/dtypes/test_fp6.py index 5163ce3d43..8b0f4879c0 100644 --- a/test/dtypes/test_fp6.py +++ b/test/dtypes/test_fp6.py @@ -71,10 +71,10 @@ def test_to_fp6_bit_packing_shape(self, device, shape): @parametrize("no_bit_packing", [False, True]) def test_to_fp6_compile(self, device, dtype, no_bit_packing): x = torch.randn(20, 20, device=device, dtype=dtype) - to_fp6_compiled = torch.compile(to_fp6) # will hit cache_size_limit if fullgraph=True + expected = to_fp6(x, no_bit_packing=no_bit_packing) + to_fp6_compiled = torch.compile(to_fp6) actual = to_fp6_compiled(x, no_bit_packing=no_bit_packing) - expected = to_fp6(x, no_bit_packing=no_bit_packing) torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) @@ -113,10 +113,10 @@ def test_from_fp6_bit_packing_correctness(self, device): @parametrize("no_bit_packing", [False, True]) def test_from_fp6_compile(self, device, no_bit_packing): x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) - from_fp6_compiled = torch.compile(from_fp6) + expected = from_fp6(x, no_bit_packing=no_bit_packing) + from_fp6_compiled = torch.compile(from_fp6) actual = from_fp6_compiled(x, no_bit_packing=no_bit_packing) - expected = from_fp6(x, no_bit_packing=no_bit_packing) torch.testing.assert_close(actual, expected) diff --git a/test/test_ops.py b/test/test_ops.py index 673413e7a7..74c8e1163d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -112,7 +112,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) - fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() + fp32_weight = torchao.dtypes.from_fp6(fp6_weight.view(torch.uint8)) * fp16_scale[:, None] + fp16_weight = fp32_weight.half().cuda() results_fp16 = act_cuda @ fp16_weight.T error = (results_fp6 - results_fp16).abs() From a3cf93b6471fd03241710e81388f25e1d5cd683d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 21:39:26 +0800 Subject: [PATCH 67/80] update names --- .../{test_fp6.py => test_float6_e3m2.py} | 46 +++++++++---------- test/test_ops.py | 2 +- torchao/dtypes/__init__.py | 6 +-- torchao/dtypes/{fp6.py => float6_e3m2.py} | 46 +++++++++---------- 4 files changed, 50 insertions(+), 50 
deletions(-) rename test/dtypes/{test_fp6.py => test_float6_e3m2.py} (68%) rename torchao/dtypes/{fp6.py => float6_e3m2.py} (75%) diff --git a/test/dtypes/test_fp6.py b/test/dtypes/test_float6_e3m2.py similarity index 68% rename from test/dtypes/test_fp6.py rename to test/dtypes/test_float6_e3m2.py index 8b0f4879c0..2df6224542 100644 --- a/test/dtypes/test_fp6.py +++ b/test/dtypes/test_float6_e3m2.py @@ -5,7 +5,7 @@ parametrize, run_tests, ) -from torchao.dtypes.fp6 import to_fp6, from_fp6 +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 _DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -32,17 +32,17 @@ class TestFp6(TestCase): (0.03, 0b000000), # underflow ], ) - def test_to_fp6_no_bit_packing_correctness(self, device, dtype, input_output): + def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): input, output = input_output input = torch.tensor(input, device=device, dtype=dtype) - assert to_fp6(input, no_bit_packing=True).item() == output + assert to_float6_e3m2(input, no_bit_packing=True).item() == output @parametrize("device", _DEVICES) @parametrize("dtype", _DTYPES) - def test_to_fp6_bit_packing_correctness(self, device, dtype): + def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): x = torch.randn(128, 128, device=device, dtype=dtype) - results_unpacked = to_fp6(x, no_bit_packing=True) - results_packed = to_fp6(x) + results_unpacked = to_float6_e3m2(x, no_bit_packing=True) + results_packed = to_float6_e3m2(x) val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 @@ -54,27 +54,27 @@ def test_to_fp6_bit_packing_correctness(self, device, dtype): @parametrize("device", _DEVICES) @parametrize("shape", [(), (0,), (10,), (20, 20)]) - def test_to_fp6_no_bit_packing_shape(self, device, shape): + def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): x = torch.randn(shape, device=device) - result = to_fp6(x, no_bit_packing=True) + result = to_float6_e3m2(x, no_bit_packing=True) assert result.shape == shape @parametrize("device", _DEVICES) @parametrize("shape", [(4,), (20, 20)]) - def test_to_fp6_bit_packing_shape(self, device, shape): + def test_to_float6_e3m2_bit_packing_shape(self, device, shape): x = torch.randn(shape, device=device) - result = to_fp6(x) + result = to_float6_e3m2(x) assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) @parametrize("device", _DEVICES) @parametrize("dtype", _DTYPES) @parametrize("no_bit_packing", [False, True]) - def test_to_fp6_compile(self, device, dtype, no_bit_packing): + def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): x = torch.randn(20, 20, device=device, dtype=dtype) - expected = to_fp6(x, no_bit_packing=no_bit_packing) + expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) - to_fp6_compiled = torch.compile(to_fp6) - actual = to_fp6_compiled(x, no_bit_packing=no_bit_packing) + to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) + actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) @@ -89,15 +89,15 @@ def test_to_fp6_compile(self, device, dtype, no_bit_packing): (0b000011, 0.1875), ], ) - def test_from_fp6_no_bit_packing_correctness(self, device, input_output): + def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): input, output = input_output input = torch.tensor(input, device=device, dtype=torch.uint8) - assert from_fp6(input, 
no_bit_packing=True).item() == output + assert from_float6_e3m2(input, no_bit_packing=True).item() == output @parametrize("device", _DEVICES) - def test_from_fp6_bit_packing_correctness(self, device): + def test_from_float6_e3m2_bit_packing_correctness(self, device): x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) - actual = from_fp6(x) + actual = from_float6_e3m2(x) bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) x_unpacked0 = bits0 >> 2 @@ -106,17 +106,17 @@ def test_from_fp6_bit_packing_correctness(self, device): x_unpacked3 = bits2 & 0x3F x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) - expected = from_fp6(x_unpacked, no_bit_packing=True) + expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) @parametrize("no_bit_packing", [False, True]) - def test_from_fp6_compile(self, device, no_bit_packing): + def test_from_float6_e3m2_compile(self, device, no_bit_packing): x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) - expected = from_fp6(x, no_bit_packing=no_bit_packing) + expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) - from_fp6_compiled = torch.compile(from_fp6) - actual = from_fp6_compiled(x, no_bit_packing=no_bit_packing) + from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) + actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) torch.testing.assert_close(actual, expected) diff --git a/test/test_ops.py b/test/test_ops.py index 74c8e1163d..40aeee0d64 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -112,7 +112,7 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) - fp32_weight = torchao.dtypes.from_fp6(fp6_weight.view(torch.uint8)) * fp16_scale[:, None] + fp32_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8)) * fp16_scale[:, None] fp16_weight = fp32_weight.half().cuda() results_fp16 = act_cuda @ fp16_weight.T diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index a5d444f4a1..0770068448 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,11 +1,11 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor -from .fp6 import to_fp6, from_fp6 +from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2 __all__ = [ "NF4Tensor", "to_nf4", "UInt4Tensor" - "to_fp6", - "from_fp6", + "to_float6_e3m2", + "from_float6_e3m2", ] diff --git a/torchao/dtypes/fp6.py b/torchao/dtypes/float6_e3m2.py similarity index 75% rename from torchao/dtypes/fp6.py rename to torchao/dtypes/float6_e3m2.py index 9046886dcc..aa15ac3d9d 100644 --- a/torchao/dtypes/fp6.py +++ b/torchao/dtypes/float6_e3m2.py @@ -4,8 +4,8 @@ # some useful constants -FP6_MAX = 28.0 -FP6_SMALLEST_SUBNORMAL = 0.0625 +FLOAT6_E3M2_MAX = 28.0 +FLOAT6_E3M2_SMALLEST_SUBNORMAL = 0.0625 if has_triton(): @@ -14,7 +14,7 @@ # see _to_fp6_pt() for explanation @triton.jit - def _triton_fp32_to_fp6(x: tl.tensor): + def _triton_float32_to_float6_e3m2(x: tl.tensor): x = x.to(tl.float32) x = x * 2.0 ** (-127 + 3) bits = x.to(tl.int32, bitcast=True) @@ -29,15 +29,15 @@ def _triton_fp32_to_fp6(x: tl.tensor): return result.to(tl.uint8) @triton.jit - def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): + def _to_float6_e3m2_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE 
+ tl.arange(0, BLOCK_SIZE) mask = offsets < n # strided memory read. there will be uncoalesced memory access - val0 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4, mask)) - val1 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 1, mask)) - val2 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 2, mask)) - val3 = _triton_fp32_to_fp6(tl.load(in_ptr + offsets * 4 + 3, mask)) + val0 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4, mask)) + val1 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 1, mask)) + val2 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 2, mask)) + val3 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 3, mask)) # bit packing bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 @@ -49,21 +49,21 @@ def _to_fp6_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): tl.store(out_ptr + offsets * 3 + 1, bits1, mask) tl.store(out_ptr + offsets * 3 + 2, bits2, mask) - def _to_fp6_triton(tensor: Tensor) -> Tensor: + def _to_float6_e3m2_triton(tensor: Tensor) -> Tensor: out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,) output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8) n = tensor.numel() grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"] * 4),) - _to_fp6_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) + _to_float6_e3m2_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256) return output else: - _to_fp6_triton = None + _to_float6_e3m2_triton = None -def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: +def _to_float6_e3m2_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: tensor = tensor.float() # correct exponent bias. this also handles subnormal numbers correctly @@ -91,7 +91,7 @@ def _to_fp6_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) -def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: +def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """Convert input tensor to FP6. This particular FP6 format has 3 exponent bits and 2 mantissa bits. By default, bit packing is performed: every 4 FP6 values are packed as 3 uint8 values (4 x 6 bits = 3 x 8 bits). @@ -117,14 +117,14 @@ def to_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: # torch.compile() cannot generate fused bit-packing triton kernel, # thus we write custom triton kernel for this specific case. - if tensor.is_cuda and not no_bit_packing and _to_fp6_triton is not None: - return _to_fp6_triton(tensor) + if tensor.is_cuda and not no_bit_packing and _to_float6_e3m2_triton is not None: + return _to_float6_e3m2_triton(tensor) else: - return _to_fp6_pt(tensor, no_bit_packing=no_bit_packing) + return _to_float6_e3m2_pt(tensor, no_bit_packing=no_bit_packing) -def _pt_fp6_to_fp32(tensor: Tensor) -> Tensor: +def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: bits = tensor.to(torch.int32) # bit extension sign = bits >> 5 << 31 exp_and_man = (bits & 0x1F) << 21 @@ -134,7 +134,7 @@ def _pt_fp6_to_fp32(tensor: Tensor) -> Tensor: return results * 2.0 ** (127 - 3) # exponent bias correction -def from_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: +def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """Convert an FP6 tensor (created by :func:`to_fp6`) to FP32. 
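+
+    The conversion is lossless: every FP6 value is exactly representable in FP32.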
Args: @@ -148,13 +148,13 @@ def from_fp6(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """ assert tensor.dtype == torch.uint8 if no_bit_packing: - return _pt_fp6_to_fp32(tensor) + return _pt_float6_e3m2_to_float32(tensor) assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) - val0 = _pt_fp6_to_fp32(bits0 >> 2) - val1 = _pt_fp6_to_fp32(((bits0 & 0x3) << 4) | (bits1 >> 4)) - val2 = _pt_fp6_to_fp32(((bits1 & 0xF) << 2) | (bits2 >> 6)) - val3 = _pt_fp6_to_fp32(bits2 & 0x3F) + val0 = _pt_float6_e3m2_to_float32(bits0 >> 2) + val1 = _pt_float6_e3m2_to_float32(((bits0 & 0x3) << 4) | (bits1 >> 4)) + val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)) + val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F) return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) From 3c636ffca7956d7d238a905811f5a08d4c17729d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 21:40:28 +0800 Subject: [PATCH 68/80] rename --- benchmarks/benchmark_fp6_conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_fp6_conversion.py b/benchmarks/benchmark_fp6_conversion.py index a8bee950e8..dfbbc0aef5 100644 --- a/benchmarks/benchmark_fp6_conversion.py +++ b/benchmarks/benchmark_fp6_conversion.py @@ -24,7 +24,7 @@ def benchmark(f, weight, num_threads = 1): functions = [ ("original", torchao.ops.fp16_to_fp6_original), - ("ours", torch.compile(torchao.dtypes.to_fp6)), + ("ours", torch.compile(torchao.dtypes.to_float6_e3m2)), ] results = [] From f672c7017267f9b2efdc192d413872c44ed7652d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 May 2024 21:41:55 +0800 Subject: [PATCH 69/80] update names --- docs/source/api_ref_dtypes.rst | 4 ++-- torchao/dtypes/float6_e3m2.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index aff808d7fb..36c3c9b4eb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -12,8 +12,8 @@ torchao.dtypes to_nf4 UInt4Tensor - to_fp6 - from_fp6 + to_float6_e3m2 + from_float6_e3m2 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index aa15ac3d9d..261fd5a808 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -12,7 +12,7 @@ import triton from triton import language as tl - # see _to_fp6_pt() for explanation + # see _to_float6_e3m2_pt() for explanation @triton.jit def _triton_float32_to_float6_e3m2(x: tl.tensor): x = x.to(tl.float32) @@ -110,7 +110,7 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - See also :func:`from_fp6` + See also :func:`from_float6_e3m2` """ if not no_bit_packing: assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" @@ -135,7 +135,7 @@ def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: - """Convert an FP6 tensor (created by :func:`to_fp6`) to FP32. + """Convert an FP6 tensor (created by :func:`to_float6_e3m2`) to FP32. Args: tensor: FP6 tensor, stored as uint8 data. 
If ``no_bit_packing=False``, the last dimension must be From 1a310e368dc00b3261236124994cc53ecd6b8afc Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 23 May 2024 01:18:15 +0000 Subject: [PATCH 70/80] add notes about denormal numbers --- test/dtypes/test_float6_e3m2.py | 4 ++++ torchao/dtypes/float6_e3m2.py | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index 2df6224542..13020c0556 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -33,6 +33,7 @@ class TestFp6(TestCase): ], ) def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): + torch.set_flush_denormal(False) input, output = input_output input = torch.tensor(input, device=device, dtype=dtype) assert to_float6_e3m2(input, no_bit_packing=True).item() == output @@ -70,6 +71,7 @@ def test_to_float6_e3m2_bit_packing_shape(self, device, shape): @parametrize("dtype", _DTYPES) @parametrize("no_bit_packing", [False, True]) def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): + torch.set_flush_denormal(False) x = torch.randn(20, 20, device=device, dtype=dtype) expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) @@ -90,6 +92,7 @@ def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): ], ) def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): + torch.set_flush_denormal(False) input, output = input_output input = torch.tensor(input, device=device, dtype=torch.uint8) assert from_float6_e3m2(input, no_bit_packing=True).item() == output @@ -112,6 +115,7 @@ def test_from_float6_e3m2_bit_packing_correctness(self, device): @parametrize("device", _DEVICES) @parametrize("no_bit_packing", [False, True]) def test_from_float6_e3m2_compile(self, device, no_bit_packing): + torch.set_flush_denormal(False) x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index 261fd5a808..d9b893ded6 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -110,6 +110,10 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). + This implementation requires FP32 denormal numbers to be handled correctly. In PyTorch, you can + use :func:`torch.set_flush_denormal(False)` to disable flushing denormal numbers to zero. Other + code or libraries might set it to ``True`` for performance gain. + See also :func:`from_float6_e3m2` """ if not no_bit_packing: @@ -138,13 +142,18 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """Convert an FP6 tensor (created by :func:`to_float6_e3m2`) to FP32. Args: - tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must be - divisible by 3. + tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must + be divisible by 3. no_bit_packing: whether the input does not have bit packing. Returns: - :class:`torch.Tensor`: FP32 tensor. If ``no_bit_packing=False``, the last dimension of output tensor - is 4/3 of that of input tensor. + :class:`torch.Tensor`: FP32 tensor. 
If ``no_bit_packing=False``, the last dimension of output + tensor is 4/3 of that of input tensor. + + Note: + This implementation requires FP32 denormal numbers to be handled correctly. In PyTorch, you can + use :func:`torch.set_flush_denormal(False)` to disable flushing denormal numbers to zero. Other + code or libraries might set it to ``True`` for performance gain. """ assert tensor.dtype == torch.uint8 if no_bit_packing: From c9ec25521e43a459b520e2616e6dced0185d1b38 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 23 May 2024 01:39:47 +0000 Subject: [PATCH 71/80] update note --- torchao/dtypes/float6_e3m2.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index d9b893ded6..bf8e1f0cc4 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -110,9 +110,10 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - This implementation requires FP32 denormal numbers to be handled correctly. In PyTorch, you can - use :func:`torch.set_flush_denormal(False)` to disable flushing denormal numbers to zero. Other - code or libraries might set it to ``True`` for performance gain. + This implementation requires FP32 denormal numbers to be handled correctly. On CPU, you can use + :func:`torch.set_flush_denormal` to disable flushing denormal numbers to zero. Other code or + libraries might set it to ``True`` for performance gain. On CUDA, this is not necessary since + CUDA always handle denormal numbers correctly. See also :func:`from_float6_e3m2` """ @@ -151,9 +152,10 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: tensor is 4/3 of that of input tensor. Note: - This implementation requires FP32 denormal numbers to be handled correctly. In PyTorch, you can - use :func:`torch.set_flush_denormal(False)` to disable flushing denormal numbers to zero. Other - code or libraries might set it to ``True`` for performance gain. + This implementation requires FP32 denormal numbers to be handled correctly. On CPU, you can use + :func:`torch.set_flush_denormal` to disable flushing denormal numbers to zero. Other code or + libraries might set it to ``True`` for performance gain. On CUDA, this is not necessary since + CUDA always handle denormal numbers correctly. 
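+        (For example, the smallest FP6 magnitude, 0.0625, passes through the FP32 subnormal range
+        during the conversion, so it would be lost if flushing were enabled.)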
""" assert tensor.dtype == torch.uint8 if no_bit_packing: From d24dba8ae6a23414fb9f3aa240773bb509d7cf0b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 13:51:40 +0800 Subject: [PATCH 72/80] fix merge problem --- benchmarks/benchmark_fp6_conversion.py | 44 ----------------------- torchao/ops.py | 48 +++++--------------------- 2 files changed, 9 insertions(+), 83 deletions(-) delete mode 100644 benchmarks/benchmark_fp6_conversion.py diff --git a/benchmarks/benchmark_fp6_conversion.py b/benchmarks/benchmark_fp6_conversion.py deleted file mode 100644 index dfbbc0aef5..0000000000 --- a/benchmarks/benchmark_fp6_conversion.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torchao -import pandas as pd -from torch.utils.benchmark import Timer - - -def benchmark(f, weight, num_threads = 1): - measurement = Timer( - stmt="f(weight)", - globals={"f": f, "weight": weight}, - num_threads=num_threads, - ).blocked_autorange() - return measurement.median * 1000 - - -if __name__ == "__main__": - M = 8192 - N = 8192 - - fp32_weight = torch.randn(M, N) - fp32_weight_cuda = fp32_weight.cuda() - fp16_weight = fp32_weight.half() - fp16_weight_cuda = fp16_weight.cuda() - - functions = [ - ("original", torchao.ops.fp16_to_fp6_original), - ("ours", torch.compile(torchao.dtypes.to_float6_e3m2)), - ] - - results = [] - for name, f in functions: - results.append(["CPU", "FP32->FP6", name, benchmark(f, fp32_weight)]) - results.append(["CPU", "FP16->FP6", name, benchmark(f, fp16_weight)]) - - results.append(["CPU", "FP32->FP6", f"{name} (num_threads=4)", benchmark(f, fp32_weight, num_threads=4)]) - results.append(["CPU", "FP16->FP6", f"{name} (num_threads=4)", benchmark(f, fp16_weight, num_threads=4)]) - - if name != "original": - results.append(["CUDA", "FP32->FP6", name, benchmark(f, fp32_weight_cuda)]) - results.append(["CUDA", "FP16->FP6", name, benchmark(f, fp16_weight_cuda)]) - - df = pd.DataFrame(results, columns=["device", "dtype", "op", "time (m/s)"]) - df = df.sort_values(["device", "dtype"]) - print(df.to_markdown(index=False)) diff --git a/torchao/ops.py b/torchao/ops.py index c2cb080ff4..3cdc7c0f7d 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,26 +1,14 @@ import torch from torch import Tensor +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 -def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: - """ - See https://pytorch.org/vision/main/generated/torchvision.ops.nms.html - """ - return torch.ops.torchao.nms.default(boxes, scores, iou_threshold) - - -# Defines the meta kernel / fake kernel / abstract impl -@torch.library.impl_abstract("torchao::nms") -def _(dets, scores, iou_threshold): - torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") - torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") - torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") - torch._check( - dets.size(0) == scores.size(0), - lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", - ) - ctx = torch._custom_ops.get_ctx() - num_to_keep = ctx.create_unbacked_symint() - return dets.new_empty(num_to_keep, dtype=torch.long) +def register_custom_op(name): + def decorator(func): + if TORCH_VERSION_AFTER_2_4: + return torch.library.register_fake(f"{name}")(func) + else: + return torch.library.impl_abstract(f"{name}")(func) + return decorator def prepack_fp6_weight(fp6_weight: Tensor) -> 
Tensor: @@ -36,7 +24,6 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) -# Defines the meta kernel / fake kernel / abstract impl @register_custom_op("torchao::prepack_fp6_weight") def _(fp6_weight): torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") @@ -56,7 +43,7 @@ def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) -@torch.library.impl_abstract("torchao::fp16_to_fp6") +@register_custom_op("torchao::fp16_to_fp6") def _(fp16_tensor): torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") @@ -96,20 +83,3 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) - - -def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: - return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) - - -@torch.library.impl_abstract("torchao::fp6_weight_dequant") -def _(fp6_tensor, fp16_scale): - torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D") - torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}") - torch._check(fp16_scale.dim() == 1, lambda: f"scale should be a 2d tensor, got {fp16_scale.dim()}D") - torch._check(fp16_scale.dtype is torch.float16, lambda: f"scale must be FP16, got {fp16_scale.dtype}") - - OC, _IC = fp6_tensor.shape - torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") - - return fp16_scale.new_empty((OC, _IC * 16 // 3)) From ce5dac1673fc00a82658f8fe2119f205ac09b58f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 13:53:30 +0800 Subject: [PATCH 73/80] fix merge conflict --- torchao/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/ops.py b/torchao/ops.py index 3cdc7c0f7d..7823ca1312 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -2,6 +2,7 @@ from torch import Tensor from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + def register_custom_op(name): def decorator(func): if TORCH_VERSION_AFTER_2_4: @@ -24,6 +25,7 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) +# Defines the meta kernel / fake kernel / abstract impl @register_custom_op("torchao::prepack_fp6_weight") def _(fp6_weight): torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") From 922446dae89e19187decf54560005d3b08088df6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 14:13:11 +0800 Subject: [PATCH 74/80] add to_fp6 CPU C++ kernel --- setup.py | 3 +- torchao/__init__.py | 13 +- torchao/csrc/fp6_llm/fp6.cpp | 206 +++++++++++++++++++++++++++++++ torchao/csrc/fp6_llm/fp6_llm.cpp | 3 + torchao/dtypes/float6_e3m2.py | 15 ++- torchao/ops.py | 10 +- 6 files changed, 237 insertions(+), 13 deletions(-) create mode 100644 torchao/csrc/fp6_llm/fp6.cpp diff --git a/setup.py b/setup.py index 5d1f32da2b..65ec21e15f 100644 --- a/setup.py +++ b/setup.py @@ -46,11 +46,12 @@ def get_extensions(): use_cuda = torch.cuda.is_available() and CUDA_HOME is not None extension = CUDAExtension if use_cuda else CppExtension - extra_link_args = [] + extra_link_args = ["-fopenmp"] 
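+    # -fopenmp is needed at link time as well as compile time (see extra_compile_args below)
+    # because the FP6 CPU conversion kernels in torchao/csrc/fp6_llm use "#pragma omp parallel for".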
extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", + "-fopenmp", ], "nvcc": [ "-O3" if not debug_mode else "-O0", diff --git a/torchao/__init__.py b/torchao/__init__.py index c982e09a0c..c8f04c1d9e 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,9 +1,3 @@ -from torchao.quantization import ( - apply_weight_only_int8_quant, - apply_dynamic_quant, - autoquant, -) -from . import dtypes import torch _IS_FBCODE = ( hasattr(torch._utils_internal, "IS_FBSOURCE") and @@ -14,6 +8,13 @@ from . import _C from . import ops +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + autoquant, +) +from . import dtypes + __all__ = [ "dtypes", "apply_dynamic_quant", diff --git a/torchao/csrc/fp6_llm/fp6.cpp b/torchao/csrc/fp6_llm/fp6.cpp new file mode 100644 index 0000000000..df21b478f6 --- /dev/null +++ b/torchao/csrc/fp6_llm/fp6.cpp @@ -0,0 +1,206 @@ +#include +#include +#include + +#include +#include +#include + + +class fp6_nan_inf : public std::invalid_argument { +public: + fp6_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6.") { } +}; + +class fp6_overflow : public std::invalid_argument { +public: + fp6_overflow() : std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf. Make sure input < 30.0") { } +}; + +// we need to do this because C++17 does not allow using struct as template non-type parameter +// use the upper 16 bits for num exponent, lower 16 bits for num mantissa +static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) { return (n_exp << 16u) | n_man; } +static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); +static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); +static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); + +// NOTE: only works for len < 32 +static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } + +// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" +template +static uint8_t to_fp6_bits(T bits) { + constexpr uint32_t N_EXP = FP_SPEC >> 16u; + constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); + constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; + + // sanity checks. will be removed in template instantiation. + // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. 
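+    // For the supported inputs: FP32/BF16 have N_EXP=8, so EXP_BIAS_DIFF = 127 - 3 = 124;
+    // FP16 has N_EXP=5, so EXP_BIAS_DIFF = 15 - 3 = 12.
+    // Illustrative worked example (FP16): 3.0 is 0|10000|1000000000 = 0x4200. It takes the
+    // "FP6 normal number" branch below: subtract (EXP_BIAS_DIFF << N_MAN) and shift right by
+    // N_MAN - 2 = 8, giving 0b010010 (sign 0, exponent 100, mantissa 10), i.e. 3.0 in FP6 E3M2.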
+ static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); + static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); + + T remainder = 0u; + T sign = bits >> N_EXP_MAN << 5u; + bits &= ones_mask(N_EXP_MAN); // clear sign bit + T result; + + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; + + // all exponent bits are 1s + if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf(); + + // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw fp6_overflow(); + + // FP6 normal number (E>=001) + if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { + remainder = bits << (1u + N_EXP + 2u); + bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent + result = sign | (bits >> (N_MAN - 2u)); + } + // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) + else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) { + T exp = bits >> N_MAN; + T man = bits & ones_mask(N_MAN); + + // to make subnormal FP6 from normal FP16 + // step 1: add implicit 1 to mantissa + man |= (1u << N_MAN); + + // step 2: shift mantissa right so that exponent value is equal to + // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) + T shift = EXP_BIAS_DIFF + 1u - exp; + remainder = man << (1u + N_EXP + 2u - shift); + result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 + } + // FP6 underflow. E=000, M=00 + else { + result = sign; + } + + // round to nearest even + constexpr T HALF_REMAINDER = 1u << N_EXP_MAN; + if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) { + result += 1; + } + return result; +} + +namespace torchao { + +template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // exception within OpenMP parallel region must be caught. + // set a flag when exception occurs, then re-raise it. + bool found_nan_inf = false; + bool found_overflow = false; + +#pragma omp parallel for + for (int i = 0; i < n; i++) { + try { fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); } + catch (fp6_nan_inf const &) { found_nan_inf = true; } + catch (fp6_overflow const &) { found_overflow = true; } + } + + if (found_nan_inf) throw fp6_nan_inf(); + if (found_overflow) throw fp6_overflow(); +} + +// this is useful for debugging +at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cpu()); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); + at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); + } + + return fp6_tensor; +} + +template void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { + // exception within OpenMP parallel region must be caught. + // set a flag when exception occurs, then re-raise it. 
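+    // (OpenMP does not allow an exception to escape a parallel region: it must be caught by the
+    // same thread inside that region, otherwise the program typically terminates. Hence the flags
+    // below instead of throwing directly from the loop body.)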
+ bool found_nan_inf = false; + bool found_overflow = false; + +#pragma omp parallel for + for (int i = 0; i < n / 4; i++) { + try { + uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); + uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); + uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); + uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); + + fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 + fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 + fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 + } + catch (fp6_nan_inf const &) { found_nan_inf = true; } + catch (fp6_overflow const &) { found_overflow = true; } + } + + if (found_nan_inf) throw fp6_nan_inf(); + if (found_overflow) throw fp6_overflow(); +} + +at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { + TORCH_CHECK(fp_tensor.is_contiguous()); + TORCH_CHECK(fp_tensor.is_cpu()); + TORCH_CHECK(fp_tensor.ndimension() == 2); + + int M = fp_tensor.size(0); + int N = fp_tensor.size(1); + TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device()); + at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options); + uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + + int n = fp_tensor.numel(); + auto dtype = fp_tensor.dtype(); + + if (dtype == torch::kFloat32) { + const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_packed_cpu_impl(fp32_ptr, fp6_ptr, n); + + } else if (dtype == torch::kFloat16) { + const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_packed_cpu_impl(fp16_ptr, fp6_ptr, n); + + } else if (dtype == torch::kBFloat16) { + const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + to_fp6_packed_cpu_impl(bf16_ptr, fp6_ptr, n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); + } + + return fp6_tensor; +} + + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::to_fp6_unpacked_cpu", &to_fp6_unpacked_cpu); + m.impl("torchao::to_fp6_packed_cpu", &to_fp6_packed_cpu); +} + +} diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index ccde481764..a250d9ce20 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -7,4 +7,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); + + m.def("to_fp6_unpacked_cpu(Tensor tensor) -> Tensor"); + m.def("to_fp6_packed_cpu(Tensor tensor) -> Tensor"); } diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index bf8e1f0cc4..e597c0f761 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -1,6 +1,7 @@ import torch from torch import Tensor from torch.utils._triton import has_triton +from torchao.ops import to_fp6_packed_cpu, to_fp6_unpacked_cpu # some useful constants @@ -63,6 +64,8 @@ def _to_float6_e3m2_triton(tensor: Tensor) -> Tensor: _to_float6_e3m2_triton = None +# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. +# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). 
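+# (The 2.0 ** (-127 + 3) scaling below maps inputs with magnitude below 0.25 into the FP32
+# subnormal range, so flushing would silently turn the smallest FP6 values into zero.)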
def _to_float6_e3m2_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: tensor = tensor.float() @@ -110,16 +113,18 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: not have +/-inf or NaN values, and no values with magnitude >= 30 (largest number in FP6 is 28. All numbers >= 28 and < 30 will be rounded down to 28, while >= 30 will overflow). - This implementation requires FP32 denormal numbers to be handled correctly. On CPU, you can use - :func:`torch.set_flush_denormal` to disable flushing denormal numbers to zero. Other code or - libraries might set it to ``True`` for performance gain. On CUDA, this is not necessary since - CUDA always handle denormal numbers correctly. - See also :func:`from_float6_e3m2` """ if not no_bit_packing: assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4" + if tensor.is_cpu: + if no_bit_packing: + return to_fp6_unpacked_cpu(tensor) + + *leading_dims, last_dim = tensor.shape + return to_fp6_packed_cpu(tensor.view(-1, last_dim)).view(*leading_dims, -1) + # torch.compile() cannot generate fused bit-packing triton kernel, # thus we write custom triton kernel for this specific case. if tensor.is_cuda and not no_bit_packing and _to_float6_e3m2_triton is not None: diff --git a/torchao/ops.py b/torchao/ops.py index 7823ca1312..b1fe14ea2b 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -45,7 +45,7 @@ def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) -@register_custom_op("torchao::fp16_to_fp6") +@register_custom_op("torchao::fp16_to_fp6_original") def _(fp16_tensor): torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") @@ -85,3 +85,11 @@ def _(_in_feats, _weights, _scales, splitK = 1): torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) + + +def to_fp6_unpacked_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_fp6_unpacked_cpu.default(tensor) + + +def to_fp6_packed_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_fp6_packed_cpu.default(tensor) From d287eb36b8ee5f0a38981858404142fea6d0a3d8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 15:08:59 +0800 Subject: [PATCH 75/80] add from_fp6 cpu C++ --- test/dtypes/test_float6_e3m2.py | 4 -- torchao/csrc/fp6_llm/fp6.cpp | 117 ++++++++++++++++++++++++++++++- torchao/csrc/fp6_llm/fp6_llm.cpp | 2 + torchao/dtypes/float6_e3m2.py | 20 ++++-- torchao/ops.py | 8 +++ 5 files changed, 138 insertions(+), 13 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index 13020c0556..2df6224542 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -33,7 +33,6 @@ class TestFp6(TestCase): ], ) def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): - torch.set_flush_denormal(False) input, output = input_output input = torch.tensor(input, device=device, dtype=dtype) assert to_float6_e3m2(input, no_bit_packing=True).item() == output @@ -71,7 +70,6 @@ def test_to_float6_e3m2_bit_packing_shape(self, device, shape): @parametrize("dtype", _DTYPES) @parametrize("no_bit_packing", [False, True]) def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): - torch.set_flush_denormal(False) x = torch.randn(20, 20, device=device, dtype=dtype) expected = to_float6_e3m2(x, 
no_bit_packing=no_bit_packing) @@ -92,7 +90,6 @@ def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): ], ) def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): - torch.set_flush_denormal(False) input, output = input_output input = torch.tensor(input, device=device, dtype=torch.uint8) assert from_float6_e3m2(input, no_bit_packing=True).item() == output @@ -115,7 +112,6 @@ def test_from_float6_e3m2_bit_packing_correctness(self, device): @parametrize("device", _DEVICES) @parametrize("no_bit_packing", [False, True]) def test_from_float6_e3m2_compile(self, device, no_bit_packing): - torch.set_flush_denormal(False) x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) diff --git a/torchao/csrc/fp6_llm/fp6.cpp b/torchao/csrc/fp6_llm/fp6.cpp index df21b478f6..10efc32411 100644 --- a/torchao/csrc/fp6_llm/fp6.cpp +++ b/torchao/csrc/fp6_llm/fp6.cpp @@ -33,6 +33,7 @@ static uint8_t to_fp6_bits(T bits) { constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; // sanity checks. will be removed in template instantiation. // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. @@ -44,8 +45,6 @@ static uint8_t to_fp6_bits(T bits) { bits &= ones_mask(N_EXP_MAN); // clear sign bit T result; - constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; - // all exponent bits are 1s if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf(); @@ -86,6 +85,32 @@ static uint8_t to_fp6_bits(T bits) { return result; } +// assume the lower 6 bits contain the data. +template +static T from_fp6_bits(uint8_t a) { + constexpr uint32_t N_EXP = FP_SPEC >> 16u; + constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); + constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; + + uint32_t bits = a; // bit extension + uint32_t sign = bits >> 5u; + uint32_t exp = (bits >> 2u) & 0x7u; + uint32_t man = bits & 0x3u; + + if (exp > 0u) { // FP6 normal numbers + exp += EXP_BIAS_DIFF; + } else if (man > 0u) { // FP6 denormal numbers + uint32_t shift = (man >= 0b10u) ? 
1u : 2u; + man = (man << shift) & 0x3u; // shift and remove explicit 1 + exp = 1u + EXP_BIAS_DIFF - shift; + } + // don't need to handle zero, since E=000 and M=00 + + uint32_t result = (sign << N_EXP_MAN) | (exp << N_MAN) | (man << (N_MAN - 2u)); + return static_cast(result); +} + namespace torchao { template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { @@ -197,10 +222,98 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { return fp6_tensor; } +template +void from_fp6_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n; i++) + fp_ptr[i] = from_fp6_bits(fp6_ptr[i]); +} + +at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu()); + + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (dtype == torch::kFloat32) { + uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_unpacked_cpu_impl(fp6_ptr, fp32_ptr, n); + + } else if (dtype == torch::kFloat16) { + uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_unpacked_cpu_impl(fp6_ptr, fp16_ptr, n); + + } else if (dtype == torch::kBFloat16) { + uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_unpacked_cpu_impl(fp6_ptr, bf16_ptr, n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); + } + + return fp_tensor; +} + +template +void from_fp6_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +#pragma omp parallel for + for (int i = 0; i < n / 3; i++) { + uint8_t bits0 = fp6_ptr[i * 3]; // 0000 0011 + uint8_t bits1 = fp6_ptr[i * 3 + 1]; // 1111 2222 + uint8_t bits2 = fp6_ptr[i * 3 + 2]; // 2233 3333 + + fp_ptr[i * 4] = from_fp6_bits(bits0 >> 2); + fp_ptr[i * 4 + 1] = from_fp6_bits(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp_ptr[i * 4 + 2] = from_fp6_bits(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp_ptr[i * 4 + 3] = from_fp6_bits(bits2 & 0x3Fu); + } +} + +at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { + TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); + TORCH_CHECK(fp6_tensor.is_contiguous()); + TORCH_CHECK(fp6_tensor.is_cpu()); + TORCH_CHECK(fp6_tensor.ndimension() == 2); + + int M = fp6_tensor.size(0); + int N = fp6_tensor.size(1); + TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N); + + at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device()); + at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options); + + const uint8_t *fp6_ptr = fp6_tensor.data_ptr(); + int n = fp6_tensor.numel(); + + if (dtype == torch::kFloat32) { + uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_packed_cpu_impl(fp6_ptr, fp32_ptr, n); + + } else if (dtype == torch::kFloat16) { + uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_packed_cpu_impl(fp6_ptr, fp16_ptr, n); + + } else if (dtype == torch::kBFloat16) { + uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); + from_fp6_packed_cpu_impl(fp6_ptr, bf16_ptr, n); + + } else { + throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); + } + + return fp_tensor; +} TORCH_LIBRARY_IMPL(torchao, CPU, m) { 
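+  // CPU implementations for the operator schemas declared with m.def() in fp6_llm.cpp.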
m.impl("torchao::to_fp6_unpacked_cpu", &to_fp6_unpacked_cpu); m.impl("torchao::to_fp6_packed_cpu", &to_fp6_packed_cpu); + m.impl("torchao::from_fp6_unpacked_cpu", &from_fp6_unpacked_cpu); + m.impl("torchao::from_fp6_packed_cpu", &from_fp6_packed_cpu); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index a250d9ce20..f0f670ef02 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -10,4 +10,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("to_fp6_unpacked_cpu(Tensor tensor) -> Tensor"); m.def("to_fp6_packed_cpu(Tensor tensor) -> Tensor"); + m.def("from_fp6_unpacked_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); + m.def("from_fp6_packed_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); } diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index e597c0f761..e9563c79f7 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torch.utils._triton import has_triton -from torchao.ops import to_fp6_packed_cpu, to_fp6_unpacked_cpu +from torchao.ops import to_fp6_packed_cpu, to_fp6_unpacked_cpu, from_fp6_packed_cpu, from_fp6_unpacked_cpu # some useful constants @@ -144,13 +144,14 @@ def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: return results * 2.0 ** (127 - 3) # exponent bias correction -def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: +def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch.dtype = torch.float32) -> Tensor: """Convert an FP6 tensor (created by :func:`to_float6_e3m2`) to FP32. Args: tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must be divisible by 3. no_bit_packing: whether the input does not have bit packing. + dtype: returned dtype. Returns: :class:`torch.Tensor`: FP32 tensor. 
If ``no_bit_packing=False``, the last dimension of output @@ -164,13 +165,18 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: """ assert tensor.dtype == torch.uint8 if no_bit_packing: - return _pt_float6_e3m2_to_float32(tensor) + if tensor.is_cpu: + return from_fp6_unpacked_cpu(tensor, dtype) + + return _pt_float6_e3m2_to_float32(tensor).to(dtype) assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" + if tensor.is_cpu: + return from_fp6_packed_cpu(tensor, dtype) bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) - val0 = _pt_float6_e3m2_to_float32(bits0 >> 2) - val1 = _pt_float6_e3m2_to_float32(((bits0 & 0x3) << 4) | (bits1 >> 4)) - val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)) - val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F) + val0 = _pt_float6_e3m2_to_float32(bits0 >> 2).to(dtype) + val1 = _pt_float6_e3m2_to_float32(((bits0 & 0x3) << 4) | (bits1 >> 4)).to(dtype) + val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)).to(dtype) + val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F).to(dtype) return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) diff --git a/torchao/ops.py b/torchao/ops.py index b1fe14ea2b..dfdcf1068b 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -93,3 +93,11 @@ def to_fp6_unpacked_cpu(tensor: Tensor) -> Tensor: def to_fp6_packed_cpu(tensor: Tensor) -> Tensor: return torch.ops.torchao.to_fp6_packed_cpu.default(tensor) + + +def from_fp6_unpacked_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_fp6_unpacked_cpu.default(tensor, dtype) + + +def from_fp6_packed_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_fp6_packed_cpu.default(tensor, dtype) From ce7e09ad3bd4721f4297f153a181681d6d9daa7f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 15:17:47 +0800 Subject: [PATCH 76/80] rename --- .../csrc/fp6_llm/{fp6.cpp => float6_e3m2.cpp} | 100 +++++++++--------- torchao/csrc/fp6_llm/fp6_llm.cpp | 8 +- torchao/dtypes/float6_e3m2.py | 18 ++-- torchao/ops.py | 16 +-- 4 files changed, 69 insertions(+), 73 deletions(-) rename torchao/csrc/fp6_llm/{fp6.cpp => float6_e3m2.cpp} (69%) diff --git a/torchao/csrc/fp6_llm/fp6.cpp b/torchao/csrc/fp6_llm/float6_e3m2.cpp similarity index 69% rename from torchao/csrc/fp6_llm/fp6.cpp rename to torchao/csrc/fp6_llm/float6_e3m2.cpp index 10efc32411..a9bb59bf2b 100644 --- a/torchao/csrc/fp6_llm/fp6.cpp +++ b/torchao/csrc/fp6_llm/float6_e3m2.cpp @@ -7,14 +7,14 @@ #include -class fp6_nan_inf : public std::invalid_argument { +class float6_e3m2_nan_inf : public std::invalid_argument { public: - fp6_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in FP6.") { } + float6_e3m2_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in float6_e3m2.") { } }; -class fp6_overflow : public std::invalid_argument { +class float6_e3m2_overflow : public std::invalid_argument { public: - fp6_overflow() : std::invalid_argument("FP6 overflow. FP6 cannot represent +/-inf. Make sure input < 30.0") { } + float6_e3m2_overflow() : std::invalid_argument("float6_e3m2 overflow. float6_e3m2 cannot represent +/-inf. 
Make sure input < 30.0") { } }; // we need to do this because C++17 does not allow using struct as template non-type parameter @@ -29,7 +29,7 @@ static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" template -static uint8_t to_fp6_bits(T bits) { +static uint8_t to_float6_e3m2_bits(T bits) { constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; @@ -46,10 +46,10 @@ static uint8_t to_fp6_bits(T bits) { T result; // all exponent bits are 1s - if (bits >= (ones_mask(N_EXP) << N_MAN)) throw fp6_nan_inf(); + if (bits >= (ones_mask(N_EXP) << N_MAN)) throw float6_e3m2_nan_inf(); // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3) - if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw fp6_overflow(); + if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw float6_e3m2_overflow(); // FP6 normal number (E>=001) if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { @@ -87,7 +87,7 @@ static uint8_t to_fp6_bits(T bits) { // assume the lower 6 bits contain the data. template -static T from_fp6_bits(uint8_t a) { +static T from_float6_e3m2_bits(uint8_t a) { constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; @@ -113,7 +113,7 @@ static T from_fp6_bits(uint8_t a) { namespace torchao { -template void to_fp6_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { +template void to_float6_e3m2_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { // exception within OpenMP parallel region must be caught. // set a flag when exception occurs, then re-raise it. 
bool found_nan_inf = false; @@ -121,17 +121,17 @@ template void to_fp6_unpacked_cpu_impl(const T *b #pragma omp parallel for for (int i = 0; i < n; i++) { - try { fp6_ptr[i] = to_fp6_bits(bits_ptr[i]); } - catch (fp6_nan_inf const &) { found_nan_inf = true; } - catch (fp6_overflow const &) { found_overflow = true; } + try { fp6_ptr[i] = to_float6_e3m2_bits(bits_ptr[i]); } + catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; } + catch (float6_e3m2_overflow const &) { found_overflow = true; } } - if (found_nan_inf) throw fp6_nan_inf(); - if (found_overflow) throw fp6_overflow(); + if (found_nan_inf) throw float6_e3m2_nan_inf(); + if (found_overflow) throw float6_e3m2_overflow(); } // this is useful for debugging -at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { +at::Tensor to_float6_e3m2_unpacked_cpu(at::Tensor fp_tensor) { TORCH_CHECK(fp_tensor.is_contiguous()); TORCH_CHECK(fp_tensor.is_cpu()); @@ -144,15 +144,15 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); + to_float6_e3m2_unpacked_cpu_impl(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); + to_float6_e3m2_unpacked_cpu_impl(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); + to_float6_e3m2_unpacked_cpu_impl(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -161,7 +161,7 @@ at::Tensor to_fp6_unpacked_cpu(at::Tensor fp_tensor) { return fp6_tensor; } -template void to_fp6_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { +template void to_float6_e3m2_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) { // exception within OpenMP parallel region must be caught. // set a flag when exception occurs, then re-raise it. 
bool found_nan_inf = false; @@ -170,24 +170,24 @@ template void to_fp6_packed_cpu_impl(const T *bit #pragma omp parallel for for (int i = 0; i < n / 4; i++) { try { - uint8_t val0 = to_fp6_bits(bits_ptr[i * 4]); - uint8_t val1 = to_fp6_bits(bits_ptr[i * 4 + 1]); - uint8_t val2 = to_fp6_bits(bits_ptr[i * 4 + 2]); - uint8_t val3 = to_fp6_bits(bits_ptr[i * 4 + 3]); + uint8_t val0 = to_float6_e3m2_bits(bits_ptr[i * 4]); + uint8_t val1 = to_float6_e3m2_bits(bits_ptr[i * 4 + 1]); + uint8_t val2 = to_float6_e3m2_bits(bits_ptr[i * 4 + 2]); + uint8_t val3 = to_float6_e3m2_bits(bits_ptr[i * 4 + 3]); fp6_ptr[i * 3] = (val0 << 2) | (val1 >> 4); // 0000 0011 fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2); // 1111 2222 fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3); // 2233 3333 } - catch (fp6_nan_inf const &) { found_nan_inf = true; } - catch (fp6_overflow const &) { found_overflow = true; } + catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; } + catch (float6_e3m2_overflow const &) { found_overflow = true; } } - if (found_nan_inf) throw fp6_nan_inf(); - if (found_overflow) throw fp6_overflow(); + if (found_nan_inf) throw float6_e3m2_nan_inf(); + if (found_overflow) throw float6_e3m2_overflow(); } -at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { +at::Tensor to_float6_e3m2_packed_cpu(at::Tensor fp_tensor) { TORCH_CHECK(fp_tensor.is_contiguous()); TORCH_CHECK(fp_tensor.is_cpu()); TORCH_CHECK(fp_tensor.ndimension() == 2); @@ -205,15 +205,15 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { if (dtype == torch::kFloat32) { const uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(fp32_ptr, fp6_ptr, n); + to_float6_e3m2_packed_cpu_impl(fp32_ptr, fp6_ptr, n); } else if (dtype == torch::kFloat16) { const uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(fp16_ptr, fp6_ptr, n); + to_float6_e3m2_packed_cpu_impl(fp16_ptr, fp6_ptr, n); } else if (dtype == torch::kBFloat16) { const uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - to_fp6_packed_cpu_impl(bf16_ptr, fp6_ptr, n); + to_float6_e3m2_packed_cpu_impl(bf16_ptr, fp6_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -223,13 +223,13 @@ at::Tensor to_fp6_packed_cpu(at::Tensor fp_tensor) { } template -void from_fp6_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +void from_float6_e3m2_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { #pragma omp parallel for for (int i = 0; i < n; i++) - fp_ptr[i] = from_fp6_bits(fp6_ptr[i]); + fp_ptr[i] = from_float6_e3m2_bits(fp6_ptr[i]); } -at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { +at::Tensor from_float6_e3m2_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); TORCH_CHECK(fp6_tensor.is_cpu()); @@ -242,15 +242,15 @@ at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { if (dtype == torch::kFloat32) { uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_fp6_unpacked_cpu_impl(fp6_ptr, fp32_ptr, n); + from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, fp32_ptr, n); } else if (dtype == torch::kFloat16) { uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_fp6_unpacked_cpu_impl(fp6_ptr, fp16_ptr, n); + from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, fp16_ptr, n); } else if (dtype == torch::kBFloat16) { uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - 
from_fp6_unpacked_cpu_impl(fp6_ptr, bf16_ptr, n); + from_float6_e3m2_unpacked_cpu_impl(fp6_ptr, bf16_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -260,21 +260,21 @@ at::Tensor from_fp6_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { } template -void from_fp6_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { +void from_float6_e3m2_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) { #pragma omp parallel for for (int i = 0; i < n / 3; i++) { uint8_t bits0 = fp6_ptr[i * 3]; // 0000 0011 uint8_t bits1 = fp6_ptr[i * 3 + 1]; // 1111 2222 uint8_t bits2 = fp6_ptr[i * 3 + 2]; // 2233 3333 - fp_ptr[i * 4] = from_fp6_bits(bits0 >> 2); - fp_ptr[i * 4 + 1] = from_fp6_bits(((bits0 & 0x3u) << 4) | (bits1 >> 4)); - fp_ptr[i * 4 + 2] = from_fp6_bits(((bits1 & 0xFu) << 2) | (bits2 >> 6)); - fp_ptr[i * 4 + 3] = from_fp6_bits(bits2 & 0x3Fu); + fp_ptr[i * 4] = from_float6_e3m2_bits(bits0 >> 2); + fp_ptr[i * 4 + 1] = from_float6_e3m2_bits(((bits0 & 0x3u) << 4) | (bits1 >> 4)); + fp_ptr[i * 4 + 2] = from_float6_e3m2_bits(((bits1 & 0xFu) << 2) | (bits2 >> 6)); + fp_ptr[i * 4 + 3] = from_float6_e3m2_bits(bits2 & 0x3Fu); } } -at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { +at::Tensor from_float6_e3m2_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8); TORCH_CHECK(fp6_tensor.is_contiguous()); TORCH_CHECK(fp6_tensor.is_cpu()); @@ -292,15 +292,15 @@ at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { if (dtype == torch::kFloat32) { uint32_t *fp32_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_fp6_packed_cpu_impl(fp6_ptr, fp32_ptr, n); + from_float6_e3m2_packed_cpu_impl(fp6_ptr, fp32_ptr, n); } else if (dtype == torch::kFloat16) { uint16_t *fp16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_fp6_packed_cpu_impl(fp6_ptr, fp16_ptr, n); + from_float6_e3m2_packed_cpu_impl(fp6_ptr, fp16_ptr, n); } else if (dtype == torch::kBFloat16) { uint16_t *bf16_ptr = reinterpret_cast(fp_tensor.data_ptr()); - from_fp6_packed_cpu_impl(fp6_ptr, bf16_ptr, n); + from_float6_e3m2_packed_cpu_impl(fp6_ptr, bf16_ptr, n); } else { throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted."); @@ -310,10 +310,10 @@ at::Tensor from_fp6_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) { } TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::to_fp6_unpacked_cpu", &to_fp6_unpacked_cpu); - m.impl("torchao::to_fp6_packed_cpu", &to_fp6_packed_cpu); - m.impl("torchao::from_fp6_unpacked_cpu", &from_fp6_unpacked_cpu); - m.impl("torchao::from_fp6_packed_cpu", &from_fp6_packed_cpu); + m.impl("torchao::to_float6_e3m2_unpacked_cpu", &to_float6_e3m2_unpacked_cpu); + m.impl("torchao::to_float6_e3m2_packed_cpu", &to_float6_e3m2_packed_cpu); + m.impl("torchao::from_float6_e3m2_unpacked_cpu", &from_float6_e3m2_unpacked_cpu); + m.impl("torchao::from_float6_e3m2_packed_cpu", &from_float6_e3m2_packed_cpu); } } diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index f0f670ef02..5239593bb6 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -8,8 +8,8 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); - m.def("to_fp6_unpacked_cpu(Tensor tensor) -> Tensor"); - m.def("to_fp6_packed_cpu(Tensor tensor) -> Tensor"); - m.def("from_fp6_unpacked_cpu(Tensor tensor, 
ScalarType dtype) -> Tensor"); - m.def("from_fp6_packed_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); + m.def("to_float6_e3m2_unpacked_cpu(Tensor tensor) -> Tensor"); + m.def("to_float6_e3m2_packed_cpu(Tensor tensor) -> Tensor"); + m.def("from_float6_e3m2_unpacked_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); + m.def("from_float6_e3m2_packed_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); } diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py index e9563c79f7..0c27838d06 100644 --- a/torchao/dtypes/float6_e3m2.py +++ b/torchao/dtypes/float6_e3m2.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torch.utils._triton import has_triton -from torchao.ops import to_fp6_packed_cpu, to_fp6_unpacked_cpu, from_fp6_packed_cpu, from_fp6_unpacked_cpu +from torchao.ops import to_float6_e3m2_packed_cpu, to_float6_e3m2_unpacked_cpu, from_float6_e3m2_packed_cpu, from_float6_e3m2_unpacked_cpu # some useful constants @@ -120,10 +120,10 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: if tensor.is_cpu: if no_bit_packing: - return to_fp6_unpacked_cpu(tensor) + return to_float6_e3m2_unpacked_cpu(tensor) *leading_dims, last_dim = tensor.shape - return to_fp6_packed_cpu(tensor.view(-1, last_dim)).view(*leading_dims, -1) + return to_float6_e3m2_packed_cpu(tensor.view(-1, last_dim)).view(*leading_dims, -1) # torch.compile() cannot generate fused bit-packing triton kernel, # thus we write custom triton kernel for this specific case. @@ -134,6 +134,8 @@ def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor: return _to_float6_e3m2_pt(tensor, no_bit_packing=no_bit_packing) +# NOTE: This implementation requires FP32 denormal numbers to be handled correctly. +# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: bits = tensor.to(torch.int32) # bit extension sign = bits >> 5 << 31 @@ -156,23 +158,17 @@ def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch. Returns: :class:`torch.Tensor`: FP32 tensor. If ``no_bit_packing=False``, the last dimension of output tensor is 4/3 of that of input tensor. - - Note: - This implementation requires FP32 denormal numbers to be handled correctly. On CPU, you can use - :func:`torch.set_flush_denormal` to disable flushing denormal numbers to zero. Other code or - libraries might set it to ``True`` for performance gain. On CUDA, this is not necessary since - CUDA always handle denormal numbers correctly. 
""" assert tensor.dtype == torch.uint8 if no_bit_packing: if tensor.is_cpu: - return from_fp6_unpacked_cpu(tensor, dtype) + return from_float6_e3m2_unpacked_cpu(tensor, dtype) return _pt_float6_e3m2_to_float32(tensor).to(dtype) assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" if tensor.is_cpu: - return from_fp6_packed_cpu(tensor, dtype) + return from_float6_e3m2_packed_cpu(tensor, dtype) bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) val0 = _pt_float6_e3m2_to_float32(bits0 >> 2).to(dtype) diff --git a/torchao/ops.py b/torchao/ops.py index dfdcf1068b..7fce2de22f 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -87,17 +87,17 @@ def _(_in_feats, _weights, _scales, splitK = 1): return _in_feats.new_empty((BS, OC)) -def to_fp6_unpacked_cpu(tensor: Tensor) -> Tensor: - return torch.ops.torchao.to_fp6_unpacked_cpu.default(tensor) +def to_float6_e3m2_unpacked_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_float6_e3m2_unpacked_cpu.default(tensor) -def to_fp6_packed_cpu(tensor: Tensor) -> Tensor: - return torch.ops.torchao.to_fp6_packed_cpu.default(tensor) +def to_float6_e3m2_packed_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_float6_e3m2_packed_cpu.default(tensor) -def from_fp6_unpacked_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: - return torch.ops.torchao.from_fp6_unpacked_cpu.default(tensor, dtype) +def from_float6_e3m2_unpacked_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_float6_e3m2_unpacked_cpu.default(tensor, dtype) -def from_fp6_packed_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: - return torch.ops.torchao.from_fp6_packed_cpu.default(tensor, dtype) +def from_float6_e3m2_packed_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_float6_e3m2_packed_cpu.default(tensor, dtype) From 22007a13ffaca99d2b1fbfaf2b8f9c1ae4bfa169 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 15:26:26 +0800 Subject: [PATCH 77/80] add some comments --- test/dtypes/test_float6_e3m2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py index 2df6224542..b821504731 100644 --- a/test/dtypes/test_float6_e3m2.py +++ b/test/dtypes/test_float6_e3m2.py @@ -83,10 +83,10 @@ def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): [ (0b000000, 0.0), (0b001100, 1.0), - (0b011111, 28.0), - (0b000001, 0.0625), + (0b011111, 28.0), # max + (0b000001, 0.0625), # min (0b001110, 1.5), - (0b000011, 0.1875), + (0b000011, 0.1875), # subnormal ], ) def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): From f97421a98767b352e3a7c1c09e412de631a9a65d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 16:16:01 +0800 Subject: [PATCH 78/80] small cleanup --- torchao/csrc/fp6_llm/float6_e3m2.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/fp6_llm/float6_e3m2.cpp b/torchao/csrc/fp6_llm/float6_e3m2.cpp index a9bb59bf2b..d1a58da4d6 100644 --- a/torchao/csrc/fp6_llm/float6_e3m2.cpp +++ b/torchao/csrc/fp6_llm/float6_e3m2.cpp @@ -40,10 +40,9 @@ static uint8_t to_float6_e3m2_bits(T bits) { static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); - T remainder = 0u; T sign = bits >> N_EXP_MAN << 5u; bits &= ones_mask(N_EXP_MAN); // clear sign bit - T result; + T result, remainder; // all exponent bits are 1s if (bits >= (ones_mask(N_EXP) << 
N_MAN)) throw float6_e3m2_nan_inf(); @@ -74,6 +73,7 @@ static uint8_t to_float6_e3m2_bits(T bits) { } // FP6 underflow. E=000, M=00 else { + remainder = 0u; result = sign; } From f727de0c5f5cfd50b6838f9a7f6a77deafd39b01 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 18:38:46 +0800 Subject: [PATCH 79/80] always use uint32_t for bit manipulation --- torchao/csrc/fp6_llm/float6_e3m2.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/fp6_llm/float6_e3m2.cpp b/torchao/csrc/fp6_llm/float6_e3m2.cpp index d1a58da4d6..16d71f51d9 100644 --- a/torchao/csrc/fp6_llm/float6_e3m2.cpp +++ b/torchao/csrc/fp6_llm/float6_e3m2.cpp @@ -29,7 +29,7 @@ static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } // inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" template -static uint8_t to_float6_e3m2_bits(T bits) { +static uint8_t to_float6_e3m2_bits(T bits_) { constexpr uint32_t N_EXP = FP_SPEC >> 16u; constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; @@ -40,9 +40,10 @@ static uint8_t to_float6_e3m2_bits(T bits) { static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4."); static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3."); - T sign = bits >> N_EXP_MAN << 5u; + uint32_t bits = bits_; // bit extension + uint32_t sign = bits >> N_EXP_MAN << 5u; bits &= ones_mask(N_EXP_MAN); // clear sign bit - T result, remainder; + uint32_t result, remainder; // all exponent bits are 1s if (bits >= (ones_mask(N_EXP) << N_MAN)) throw float6_e3m2_nan_inf(); @@ -52,14 +53,14 @@ static uint8_t to_float6_e3m2_bits(T bits) { // FP6 normal number (E>=001) if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) { - remainder = bits << (1u + N_EXP + 2u); - bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent + remainder = bits << (32u - (N_MAN - 2u)); // shift the truncated bits to most significant position + bits -= (EXP_BIAS_DIFF << N_MAN); // update exponent result = sign | (bits >> (N_MAN - 2u)); } // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5) else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) { - T exp = bits >> N_MAN; - T man = bits & ones_mask(N_MAN); + uint32_t exp = bits >> N_MAN; + uint32_t man = bits & ones_mask(N_MAN); // to make subnormal FP6 from normal FP16 // step 1: add implicit 1 to mantissa @@ -67,8 +68,8 @@ static uint8_t to_float6_e3m2_bits(T bits) { // step 2: shift mantissa right so that exponent value is equal to // exponent value of FP6 subnormal, which is -2 (equivalent to E=001) - T shift = EXP_BIAS_DIFF + 1u - exp; - remainder = man << (1u + N_EXP + 2u - shift); + uint32_t shift = EXP_BIAS_DIFF + 1u - exp; + remainder = man << (32u - (N_MAN - 2u + shift)); // shift the truncated bits to most significant position result = sign | (man >> (shift + (N_MAN - 2u))); // implicit E=000 } // FP6 underflow. 
E=000, M=00 @@ -78,8 +79,7 @@ static uint8_t to_float6_e3m2_bits(T bits) { } // round to nearest even - constexpr T HALF_REMAINDER = 1u << N_EXP_MAN; - if ((remainder > HALF_REMAINDER) || ((remainder == HALF_REMAINDER) && (result & 0x1u))) { + if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 0x1u))) { result += 1; } return result; From 78e79ac2c9e302474fe0b965c4810052469ad1e1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 25 May 2024 18:40:57 +0800 Subject: [PATCH 80/80] simplify test --- test/test_ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4aa7fddbd2..4e463b4e26 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -99,9 +99,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) - fp32_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8)) * fp16_scale[:, None] - fp16_weight = fp32_weight.half().cuda() - results_fp16 = act_cuda @ fp16_weight.T + fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] + results_fp16 = act_cuda @ fp16_weight.cuda().T error = (results_fp6 - results_fp16).abs() relative_error = error / results_fp16.abs()
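
For reference, the float32 to float6_e3m2 conversion that patches 78 and 79 settle on can be sketched in pure Python. This is an illustrative re-derivation of the same branches as the C++ to_float6_e3m2_bits above, not code from the series: reject NaN/inf, reject magnitudes of 30.0 and above, re-bias normal numbers, shift subnormals, then round to nearest even using the truncated bits held in the top of a 32-bit remainder. The function name is hypothetical.

import struct

def fp32_to_float6_e3m2_bits(x: float) -> int:
    """Return the 6-bit E3M2 encoding of a Python float (sketch of the C++ logic above)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # IEEE-754 bit pattern of x

    N_EXP, N_MAN = 8, 23          # float32 field widths
    EXP_BIAS_DIFF = 127 - 3       # float32 bias minus float6_e3m2 bias

    sign = (bits >> (N_EXP + N_MAN)) << 5
    bits &= (1 << (N_EXP + N_MAN)) - 1          # clear the sign bit
    remainder = 0

    if bits >= (((1 << N_EXP) - 1) << N_MAN):   # all exponent bits set: inf or NaN
        raise ValueError("+/-inf and NaN are not representable in float6_e3m2")
    if bits >= (((EXP_BIAS_DIFF + 7) << N_MAN) | (0x7 << (N_MAN - 3))):   # |x| >= 30.0
        raise ValueError("float6_e3m2 overflow: make sure |x| < 30.0")

    if bits >= ((EXP_BIAS_DIFF + 1) << N_MAN):           # float6_e3m2 normal number
        remainder = (bits << (32 - (N_MAN - 2))) & 0xFFFF_FFFF   # truncated bits, left-aligned
        bits -= EXP_BIAS_DIFF << N_MAN                   # re-bias the exponent
        result = sign | (bits >> (N_MAN - 2))
    elif bits > ((EXP_BIAS_DIFF - 2) << N_MAN):          # float6_e3m2 subnormal number
        exp = bits >> N_MAN
        man = (bits & ((1 << N_MAN) - 1)) | (1 << N_MAN) # add the implicit leading 1
        shift = EXP_BIAS_DIFF + 1 - exp
        remainder = (man << (32 - (N_MAN - 2 + shift))) & 0xFFFF_FFFF
        result = sign | (man >> (shift + N_MAN - 2))
    else:                                                # underflow to +/-0
        result = sign

    # round to nearest, ties to even, using the truncated bits kept in `remainder`
    if remainder > 0x8000_0000 or (remainder == 0x8000_0000 and (result & 1)):
        result += 1
    return result

assert fp32_to_float6_e3m2_bits(1.0) == 0b001100      # matches the unit test cases above
assert fp32_to_float6_e3m2_bits(0.0625) == 0b000001
assert fp32_to_float6_e3m2_bits(28.0) == 0b011111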
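
The packed layout used throughout the series (the "0000 0011", "1111 2222", "2233 3333" comments) stores four 6-bit codes in three bytes. A minimal PyTorch sketch of that packing is below; it is the inverse of the unpacking bit math already present in float6_e3m2.py and in from_float6_e3m2_packed_cpu_impl, and the helper name pack_float6_e3m2_codes is illustrative, not an API in the series.

import torch

def pack_float6_e3m2_codes(codes: torch.Tensor) -> torch.Tensor:
    """Pack 6-bit codes (uint8, last dim divisible by 4) into 3 bytes per 4 values."""
    v0, v1, v2, v3 = codes.unflatten(-1, (-1, 4)).unbind(-1)
    b0 = (v0 << 2) | (v1 >> 4)           # 0000 0011
    b1 = ((v1 & 0xF) << 4) | (v2 >> 2)   # 1111 2222
    b2 = ((v2 & 0x3) << 6) | v3          # 2233 3333
    return torch.stack([b0, b1, b2], dim=-1).flatten(-2)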
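
Putting the renamed Python API together (names as of the rename patch, with the dtype argument added earlier in the series), a usage sketch follows. It assumes both to_float6_e3m2 and from_float6_e3m2 are importable from torchao.dtypes, as the test's use of torchao.dtypes.from_float6_e3m2 suggests; the shapes and scaling are illustrative only.

import torch
from torchao.dtypes import to_float6_e3m2, from_float6_e3m2

w = torch.randn(256, 64) * 0.1                 # keep |values| < 30.0; larger values raise the overflow error
w_fp6 = to_float6_e3m2(w)                      # packed uint8, last dim 64 * 3/4 = 48
w_fp16 = from_float6_e3m2(w_fp6, dtype=torch.float16)   # decode straight to fp16, as in the matmul test
w_dbg = to_float6_e3m2(w, no_bit_packing=True)          # one 6-bit code per uint8 element, handy for debugging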