1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
+  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/fp8_cuda_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -174,6 +174,11 @@ void dynamic_scaled_fp8_quant(
   torch::Tensor& input,
   torch::Tensor& scale);
 
+void quant_per_tensor(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  float scale);
+
 void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
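In plain terms, the new declaration commits to out[i] = saturate_int8(round(input[i] / scale)) for every element, with out a preallocated, contiguous int8 tensor of the same shape as input (the kernel added below checks contiguity via assert).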
9 changes: 9 additions & 0 deletions csrc/pybind.cpp
@@ -82,6 +82,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &moe_align_block_size,
     "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
 
+  ops.def(
+    "quant_per_tensor",
+    py::overload_cast<
+      torch::Tensor&,
+      torch::Tensor&,
+      float>(&quant_per_tensor),
+    "Per-tensor Quantization");
+
+
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
   cache_ops.def(
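For context, a host-side smoke test of the new op could look like the sketch below. This is not part of the PR; it assumes you build against the extension's headers, link the compiled op, and run on a CUDA device.

#include <torch/torch.h>

// Declaration from csrc/ops.h (added in this PR).
void quant_per_tensor(torch::Tensor& out, torch::Tensor& input, float scale);

int main() {
  // Float input and a matching, preallocated int8 output on the GPU.
  torch::Tensor input = torch::randn(
      {4, 128}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  torch::Tensor out = torch::empty_like(input, torch::kInt8);

  // Each out[i] becomes the saturating round-to-nearest of input[i] / 0.05f.
  quant_per_tensor(out, input, /*scale=*/0.05f);
  return 0;
}

Judging by the def_submodule pattern visible above, the binding is presumably reachable from Python as the quant_per_tensor attribute of the extension's ops submodule.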
50 changes: 50 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -0,0 +1,50 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <assert.h>
+
+#include "../../dispatch_utils.h"
+
+static inline __device__ int8_t float_to_int8_rn(float x)
+{
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+}
+
+namespace vllm {
+
+template <typename scalar_t, typename scale_type>
+__global__ void quant_kernel(
+  const scalar_t* __restrict__ input,
+  int8_t* __restrict__ out,
+  scale_type scale,
+  const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int token_idx = blockIdx.x;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    out[token_idx * hidden_size + i] =
+      float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
+  }
+}
+} // namespace vllm
+
+void quant_per_tensor(
+  torch::Tensor& out,    // [..., hidden_size]
+  torch::Tensor& input,  // [..., hidden_size]
+  float scale) {
+  assert(input.is_contiguous());
+  assert(out.is_contiguous());
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
+    vllm::quant_kernel<scalar_t, float><<<grid, block, 0, stream>>>(
+      input.data_ptr<scalar_t>(),
+      out.data_ptr<int8_t>(),
+      scale,
+      hidden_size);
+  });
+}
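A note on the conversion and launch: the inline PTX cvt.rni.sat.s8.f32 converts a float to a signed 8-bit integer using round-to-nearest-even (rni) and saturation to [-128, 127] (sat), and the grid puts one block per token with up to 1024 threads striding across hidden_size. A portable host reference of the same mapping, handy for unit tests, might look like this (a sketch; float_to_int8_rn_ref is a name introduced here, not in the PR, and it assumes the default round-to-nearest FP environment):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Mirrors cvt.rni.sat.s8.f32: round to nearest even, then saturate
// to the int8 range before the final cast.
static int8_t float_to_int8_rn_ref(float x) {
  float r = std::nearbyintf(x);                // nearest-even under default rounding
  r = std::min(std::max(r, -128.0f), 127.0f);  // saturate to the int8 range
  return static_cast<int8_t>(r);
}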
2 changes: 2 additions & 0 deletions requirements-cuda.txt
@@ -7,3 +7,5 @@ nvidia-ml-py # for pynvml package
 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
+nvidia-cutlass == 3.5.0
+