60 changes: 60 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh
@@ -7,6 +7,43 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {

///////////////////////////////////////////////////////////////////////////////
// A more general block-wise dequantization implementation that supports
// different block sizes and block orientations (row-wise/column-wise).
template <
int Row_, ///< rows of a matrix
int Column_ ///< columns of a matrix
>
struct Shape2D {
static int const kRow = Row_; ///< rows of a matrix
static int const kColumn = Column_; ///< columns of a matrix
static int const kCount = Row_ * Column_; ///< total number of elements in a matrix
};

/**
* @brief Blockwise quantization constants
* @tparam ElementT source data type, e.g. fp32/fp16
 * @tparam block_size number of elements quantized together
* @tparam qbits number of bits in each quantized element
* @tparam Columnwise true: elements in a block come from one single column
* false: elements in a block come from one single row
*/
template <
typename ElementT,
int32_t block_size,
int32_t qbits,
bool Columnwise>
struct BlkQuantTraits {
// number of qbit elements to pack into whole bytes
static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0;
static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!");

using QuantBlk = std::conditional_t<Columnwise, Shape2D<block_size, 1>, Shape2D<1, block_size>>;

using ThreadBlk = Shape2D<QuantBlk::kRow * kPackSize, QuantBlk::kColumn>;
};
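
A minimal sketch (not part of this PR) of how these traits resolve for a 4-bit, block-size-32, column-wise configuration; the include path is an assumption based on this repo's layout:

#include <cuda_fp16.h>
#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh"  // path assumed

using Traits = onnxruntime::contrib::cuda::BlkQuantTraits<half, /*block_size=*/32, /*qbits=*/4, /*Columnwise=*/true>;
static_assert(Traits::kPackSize == 2, "two 4-bit values are packed per byte");
static_assert(Traits::QuantBlk::kRow == 32 && Traits::QuantBlk::kColumn == 1,
              "column-wise: a quant block spans 32 rows of a single column");
static_assert(Traits::ThreadBlk::kRow == 64 && Traits::ThreadBlk::kColumn == 1,
              "a thread block covers kPackSize quant blocks along the rows");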

template <class T, typename ZeroT>
Status Dequantize4Bits(
T* output,
@@ -19,6 +56,18 @@ Status Dequantize4Bits(
int block_size,
cudaStream_t stream);

template <class T, typename ZeroT>
Status Dequantize8Bits(
T* output,
const uint8_t* quant_data,
const T* scales_data,
const ZeroT* zero_points,
const int32_t* reorder_idx,
int k,
int n,
int block_size,
cudaStream_t stream);
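
For orientation, a hedged call sketch for the new 8-bit entry point (buffer names, the block size, and passing nullptr for reorder_idx are illustrative assumptions, not taken from this PR; assumes cuda_fp16.h and this header are included):

// Sketch only: dequantize an 8-bit block-quantized [k, n] weight into fp16.
// Assumes device buffers are already allocated/filled and that a null
// reorder_idx means "no reordering" (mirroring the 4-bit path).
onnxruntime::common::Status RunDequantize8Bits(
    half* output, const uint8_t* quant_data, const half* scales,
    const uint8_t* zero_points, int k, int n, cudaStream_t stream) {
  return onnxruntime::contrib::cuda::Dequantize8Bits<half, uint8_t>(
      output, quant_data, scales, zero_points,
      /*reorder_idx=*/nullptr, k, n, /*block_size=*/32, stream);
}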

/**
* @brief Dequantize a block-wise quantized matrix, and store the result in a
* column major matrix for use in subsequent GEMM. This implementation supports
@@ -45,6 +94,17 @@ Status DequantizeBlockwise4b(
int columns,
cudaStream_t stream);

template <typename T>
Status DequantizeBlockwise8b(
T* dst,
const uint8_t* qelements,
const T* scales,
const uint8_t* zero_points,
int block_size,
bool columnwise,
int rows,
int columns,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
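
A similar hedged sketch for the column-major helper added at the bottom of the header: per its doc comment, DequantizeBlockwise8b dequantizes a rows x columns 8-bit weight into a column-major buffer ready for a subsequent GEMM. The block size and buffer names below are illustrative assumptions:

onnxruntime::common::Status RunDequantizeBlockwise8b(
    float* dst, const uint8_t* qelements, const float* scales,
    const uint8_t* zero_points, int rows, int columns, cudaStream_t stream) {
  // Column-wise blocks of 64 elements; dst is written column major.
  return onnxruntime::contrib::cuda::DequantizeBlockwise8b<float>(
      dst, qelements, scales, zero_points,
      /*block_size=*/64, /*columnwise=*/true, rows, columns, stream);
}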