From 7bf01173a886f533acbe33217cacb761c8fbe66f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 11 Apr 2024 03:48:23 +0000 Subject: [PATCH 1/8] rebased w8a8 --- CMakeLists.txt | 4 + benchmarks/benchmark_throughput.py | 2 +- csrc/attention/attention_dtypes.h | 1 + csrc/attention/dtype_float32.cuh | 8 + csrc/attention/dtype_int8.cuh | 49 ++ csrc/dispatch_utils.h | 10 +- csrc/ops.h | 21 + csrc/pybind.cpp | 37 + .../quantization/smoothquant/fused_kernels.cu | 162 +++++ .../smoothquant/int8gemm/allocator.h | 232 ++++++ .../smoothquant/int8gemm/cublasAlgoMap.cc | 188 +++++ .../smoothquant/int8gemm/cublasAlgoMap.h | 108 +++ .../int8gemm/cublasINT8MMWrapper.cc | 676 ++++++++++++++++++ .../int8gemm/cublasINT8MMWrapper.h | 71 ++ .../smoothquant/int8gemm/cuda_utils.cc | 45 ++ .../smoothquant/int8gemm/cuda_utils.h | 158 ++++ .../smoothquant/int8gemm/int8_gemm.h | 127 ++++ csrc/quantization/smoothquant/quant_utils.cuh | 243 +++++++ csrc/reduction_utils.cuh | 25 + examples/offline_profile.py | 268 +++++++ experiments.sh | 211 ++++++ neuralmagic/tools/profiler/print_table.py | 77 ++ neuralmagic/tools/profiler/visualize_trace.py | 209 ++++++ requirements-dev.txt | 5 + tests/kernels/test_fusion.py | 94 +++ vllm/config.py | 3 +- vllm/engine/arg_utils.py | 23 +- .../layers/quantization/__init__.py | 2 + .../layers/quantization/smoothquant.py | 348 +++++++++ vllm/model_executor/model_loader.py | 24 +- vllm/model_executor/models/__init__.py | 10 + vllm/model_executor/models/llama.py | 159 +++- vllm/profiler/__init__.py | 5 + vllm/profiler/nm_profile.py | 346 +++++++++ vllm/profiler/utils.py | 146 ++++ vllm/worker/model_runner.py | 4 +- 36 files changed, 4046 insertions(+), 55 deletions(-) create mode 100644 csrc/attention/dtype_int8.cuh create mode 100644 csrc/quantization/smoothquant/fused_kernels.cu create mode 100644 csrc/quantization/smoothquant/int8gemm/allocator.h create mode 100644 csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc create mode 100644 csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h create mode 100644 csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc create mode 100644 csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h create mode 100644 csrc/quantization/smoothquant/int8gemm/cuda_utils.cc create mode 100644 csrc/quantization/smoothquant/int8gemm/cuda_utils.h create mode 100644 csrc/quantization/smoothquant/int8gemm/int8_gemm.h create mode 100644 csrc/quantization/smoothquant/quant_utils.cuh create mode 100644 examples/offline_profile.py create mode 100755 experiments.sh create mode 100644 neuralmagic/tools/profiler/print_table.py create mode 100644 neuralmagic/tools/profiler/visualize_trace.py create mode 100644 tests/kernels/test_fusion.py create mode 100644 vllm/model_executor/layers/quantization/smoothquant.py create mode 100644 vllm/profiler/__init__.py create mode 100644 vllm/profiler/nm_profile.py create mode 100644 vllm/profiler/utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 184515118128..4cb48c119e6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,10 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc" + "csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc" + "csrc/quantization/smoothquant/int8gemm/cuda_utils.cc" + "csrc/quantization/smoothquant/fused_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d6bf18c82e46..a708dbde4f50 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -261,7 +261,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'gptq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 64f86381d9db..1efd714a26e1 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -5,3 +5,4 @@ #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" #include "dtype_fp8.cuh" +#include "dtype_int8.cuh" diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb..51407f35e2d0 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) { return c; } +// for compiling, the above function seems to be useless +inline __device__ Float4_ add(Float4_ a, Float4_ b) { + Float4_ c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + // Vector multiplication. template<> inline __device__ float mul(float a, float b) { diff --git a/csrc/attention/dtype_int8.cuh b/csrc/attention/dtype_int8.cuh new file mode 100644 index 000000000000..91e6ec40b038 --- /dev/null +++ b/csrc/attention/dtype_int8.cuh @@ -0,0 +1,49 @@ +#pragma once + +#include +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +namespace vllm { +// define int8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = int8_t; +}; + +template<> +struct Vec { + using Type = int16_t; +}; + +template<> +struct Vec { + using Type = int32_t; +}; + +template<> +struct Vec { + using Type = int64_t; +}; + +template<> +struct FloatVec { + using Type = float; +}; + +template<> +struct FloatVec { + using Type = float2; +}; + +template<> +struct FloatVec { + using Type = Float4_; +}; + +template<> +struct FloatVec { + using Type = Float8_; +}; +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4b..20ea511a8529 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -9,12 +9,20 @@ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index 41ecc1e89371..76ba6188c7a8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -131,6 +131,27 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); +void dequant( + torch::Tensor& out, + torch::Tensor& input, + float scale); + +void dequant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale, + float weight_dequant_scale); + +void quant( + torch::Tensor& out, + torch::Tensor& input, + float scale); + +void quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc16211..8c4fbdaed105 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -1,6 +1,7 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" +#include "quantization/smoothquant/int8gemm/int8_gemm.h" #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -49,6 +50,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "fused_add_rms_norm", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); + ops.def( + "dequant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + float>(&dequant), + "Dequant."); + ops.def( + "dequant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + torch::Tensor&, + float>(&dequant), + "Per-token dequant."); + ops.def( + "quant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + float>(&quant), + "Quant."); + ops.def( + "quant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + torch::Tensor&>( + &quant), + "Per-token quant."); // Rotary embedding ops.def( @@ -71,6 +102,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + pybind11::class_(ops, "I8CUGEMM") + .def(pybind11::init<>()) + .def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32) + .def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8) + .def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_) + .def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/smoothquant/fused_kernels.cu b/csrc/quantization/smoothquant/fused_kernels.cu new file mode 100644 index 000000000000..1e9d1acf8f47 --- /dev/null +++ b/csrc/quantization/smoothquant/fused_kernels.cu @@ -0,0 +1,162 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" +#include "quant_utils.cuh" + +namespace vllm { +template +__global__ void dequant_kernel( + const int32_t* __restrict__ input, + scalar_t* __restrict__ out, + const float scale, + const int m, + const int hidden_size, + const int input_stride, + const int out_stride, + const float* __restrict__ act_scale = nullptr) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + float scale_ = scale; + if constexpr (use_per_token_dequant) { + scale_ = scale * act_scale[token_idx]; + } + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * out_stride + i] = + (scalar_t)(((float)input[token_idx * input_stride + i]) * scale_); + } +} + +template +__global__ void quant_kernel( + const scalar_t* __restrict__ input, + int8_t* __restrict__ out, + scale_type scale, + const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + if constexpr (use_per_token_quant) { + float amax_val = 0.0f; + const float zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = (float)input[token_idx * hidden_size + i]; + val = val > zero ? val : -val; + if (val > amax_val) + amax_val = val; + } + + __shared__ float s_amax; + const float block_amax_val = blockReduceMax(amax_val); + if (tid == 0) { + s_amax = block_amax_val; + scale[token_idx] = block_amax_val / 127.0f; + } + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale); + } + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } + } +} +} // namespace vllm + +void dequant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + int input_stride = input.stride(-2); + int out_stride = out.stride(-2); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] { + vllm::dequant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale, + num_tokens, + hidden_size, + input_stride, + out_stride); + }); +} + +void dequant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale, + float weight_dequant_scale) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + int input_stride = input.stride(-2); + int out_stride = out.stride(-2); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] { + vllm::dequant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + weight_dequant_scale, + num_tokens, + hidden_size, + input_stride, + out_stride, + scale.data_ptr()); + }); +} + +void quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale, + hidden_size); + }); +} + +void quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale) { // [num_tokens] + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), + hidden_size); + }); +} \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/allocator.h b/csrc/quantization/smoothquant/int8gemm/allocator.h new file mode 100644 index 000000000000..79be2e99e29c --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/allocator.h @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Memory Allocator + **/ + +#pragma once + +#include "cuda_utils.h" +#include +#include +#include + +#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 +#define CUDA_MEMORY_POOL_DISABLED +#endif + +enum class AllocatorType { CUDA, TF, TH }; + +enum class ReallocType { + INCREASE, + REUSE, + DECREASE, +}; + +class IAllocator { +public: + virtual ~IAllocator(){}; + + virtual void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) = 0; + virtual void free(void **ptr, bool is_host = false) const = 0; + virtual void setStream(cudaStream_t stream) = 0; + virtual cudaStream_t returnStream() = 0; + virtual void memSet(void *ptr, const int val, const size_t size) = 0; + + template + void *reMalloc(T *ptr, size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + size = ((size + 31) / 32) * 32; // make the buffer align with 32 bytes + void *void_ptr = (void *)ptr; + void *ptr_address = getAddress(void_ptr); + if (isExist(ptr_address)) { + ReallocType realloc_type = isReMalloc(ptr_address, size); + if (realloc_type == ReallocType::INCREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", + // void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#if !defined(CUDA_MEMORY_POOL_DISABLED) + else if (realloc_type == ReallocType::DECREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to + // memory pools.", void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#endif + else { + // FT_LOG_DEBUG("Reuse original buffer %p with size %d and do nothing + // for reMalloc.", void_ptr, size); + if (is_set_zero) { + memSet(void_ptr, 0, size); + } + return void_ptr; + } + } else { + // FT_LOG_DEBUG("Cannot find buffer %p, mallocing new one.", void_ptr); + return malloc(size, is_set_zero, is_host); + } + } + +protected: + virtual bool isExist(void *address) const = 0; + virtual ReallocType isReMalloc(void *address, size_t size) const = 0; + + void *getAddress(void *ptr) const { return ptr; } +}; + +template class Allocator; + +template <> class Allocator : public IAllocator { +private: + const int device_id_; + cudaStream_t stream_ = 0; // initialize as default stream + std::unordered_map *pointer_mapping_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + if (pointer_mapping_->at(address) < size) { + return ReallocType::INCREASE; + } else if (pointer_mapping_->at(address) == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator(int device_id) : device_id_(device_id) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + pointer_mapping_ = new std::unordered_map(); +#if defined(CUDA_MEMORY_POOL_DISABLED) + // FT_LOG_WARNING( + // "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync + // cudaMalloc/Free." "Note this may lead to hang with NCCL kernels + // launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); +#else + int device_count = 1; + check_cuda_error(cudaGetDeviceCount(&device_count)); + cudaMemPool_t mempool; + check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); + cudaMemAccessDesc desc = {}; + int peer_access_available = 0; + for (int i = 0; i < device_count; i++) { + if (i == device_id) { + continue; + } + check_cuda_error( + cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); + if (!peer_access_available) { + // FT_LOG_WARNING("Device " + std::to_string(device_id) + " peer access + // Device " + std::to_string(i) + // + " is not available."); + continue; + } + desc.location.type = cudaMemLocationTypeDevice; + desc.location.id = i; + desc.flags = cudaMemAccessFlagsProtReadWrite; + check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); + } + // set memory pool threshold to avoid shrinking the pool + uint64_t setVal = UINT64_MAX; + check_cuda_error(cudaMemPoolSetAttribute( + mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); +#endif + } + + virtual ~Allocator() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + while (!pointer_mapping_->empty()) { + free((void **)(&pointer_mapping_->begin()->first)); + } + delete pointer_mapping_; + } + + void setStream(cudaStream_t stream) { stream_ = stream; } + + cudaStream_t returnStream() { return stream_; }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (size == 0) { + return nullptr; + } + void *ptr = nullptr; + int o_device = 0; + + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaMallocHost(&ptr, (size_t)(ceil(size / 32.)) * 32)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); +#else + check_cuda_error( + cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); +#endif + } + if (is_set_zero) { + check_cuda_error( + cudaMemsetAsync(ptr, 0, (size_t)(ceil(size / 32.)) * 32, stream_)); + } + check_cuda_error(getSetDevice(o_device)); + // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); + + pointer_mapping_->insert({getAddress(ptr), size}); + + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + if (*ptr != nullptr) { + int o_device = 0; + if (pointer_mapping_->count(address)) { + // FT_LOG_DEBUG("Free buffer %p", address); + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaFreeHost(*ptr)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaFree(*ptr)); +#else + check_cuda_error(cudaFreeAsync(*ptr, stream_)); + cudaStreamSynchronize(stream_); +#endif + } + check_cuda_error(getSetDevice(o_device)); + pointer_mapping_->erase(address); + } else { + // FT_LOG_WARNING("pointer_mapping_ does not have information of ptr at + // %p.", address); + } + } + *ptr = nullptr; + return; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); + } +}; \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc new file mode 100644 index 000000000000..61e41438c6a8 --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasAlgoMap.h" + +cublasAlgoMap::cublasAlgoMap(const std::string filename, + const std::string sp_config_filename) + : config_filename_(filename), sp_config_filename_(sp_config_filename) { + loadGemmConfig(); + loadSpGemmConfig(); +} + +cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap &algo_map) + : config_filename_(algo_map.config_filename_), + sp_config_filename_(algo_map.sp_config_filename_), + algo_map_(algo_map.algo_map_), sp_algo_map_(algo_map.sp_algo_map_) {} + +cublasAlgoMap::~cublasAlgoMap() { algo_map_.clear(); } + +void cublasAlgoMap::loadGemmConfig() { + FILE *fd; + fd = fopen(config_filename_.c_str(), "r"); + if (fd == NULL) { + std::cout << "[WARNING] " << config_filename_ + << " is not found; using default GEMM algo" << std::endl; + return; + } + + int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; + int batch_size, seq_len, head_num, size_per_head, dataType; + int swizzle, reductionScheme, workspaceSize, stages; + int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, + "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d " +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + "%d %d " +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + "%d %d %d " +#endif + "%f\n", + &batch_size, &seq_len, &head_num, &size_per_head, &dataType, + &batchCount2, &n2, &m2, &k2, &algoId, &customOption, &tile, + &splitK_val, &swizzle, &reductionScheme, &workspaceSize, + &stages, +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + &inner_shapeId, &cluster_shapeId, +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + &mma_shapeId, &cga_shapeId, &sche_mode, +#endif + &exec_time) != EOF) { + if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && + dataType != BFLOAT16_DATATYPE && dataType != INT8_DATATYPE && + dataType != FP8_DATATYPE) { + printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType); + continue; + } + cublasAlgoConfig_t markStr{batchCount2, m2, n2, k2, + static_cast(dataType)}; + // workspaceSize should be zero + if (algo_map_.find(markStr) == algo_map_.end()) { + algo_map_[markStr].algoId = algoId; + algo_map_[markStr].customOption = customOption; + algo_map_[markStr].tile = tile; + algo_map_[markStr].splitK_val = splitK_val; + algo_map_[markStr].swizzle = swizzle; + algo_map_[markStr].reductionScheme = reductionScheme; + algo_map_[markStr].workspaceSize = workspaceSize; + algo_map_[markStr].stages = stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + algo_map_[markStr].inner_shapeId = (uint16_t)inner_shapeId; + algo_map_[markStr].cluster_shapeId = (uint16_t)cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + algo_map_[markStr].mma_shapeId = (uint16_t)mma_shapeId; + algo_map_[markStr].cga_shapeId = (uint16_t)cga_shapeId; + algo_map_[markStr].sche_mode = (uint16_t)sche_mode; +#endif + algo_map_[markStr].exec_time = exec_time; + } + } + fclose(fd); +} + +bool cublasAlgoMap::isExist(const int batch_count, const int m, const int n, + const int k, const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + return algo_map_.find(mark) != algo_map_.end(); +} + +cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(const int batch_count, + const int m, const int n, + const int k, + const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + if (algo_map_.find(mark) != algo_map_.end()) { + return algo_map_[mark]; + } else { + cublasLtMatmulAlgo_info tmp_algo; + tmp_algo.algoId = static_cast(data_type == FLOAT_DATATYPE + ? CUBLAS_GEMM_DEFAULT + : CUBLAS_GEMM_DEFAULT_TENSOR_OP); + tmp_algo.customOption = -1; + tmp_algo.tile = -1; + tmp_algo.splitK_val = -1; + tmp_algo.swizzle = -1; + tmp_algo.reductionScheme = -1; + tmp_algo.workspaceSize = -1; + tmp_algo.stages = -1; + tmp_algo.exec_time = -1.0f; + return tmp_algo; + } +} + +void cublasAlgoMap::loadSpGemmConfig() { + if (sp_config_filename_.empty()) { + return; + } + FILE *fd = fopen(sp_config_filename_.c_str(), "r"); + if (fd == NULL) { + printf("[WARNING] %s is not found; using SPGEMM algo id 0\n", + sp_config_filename_.c_str()); + return; + } + sp_algo_map_.clear(); + int batch_size, seq_len, head_num, size_per_head, data_type; + int batchCount, m, n, k, algoId; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, + &seq_len, &head_num, &size_per_head, &data_type, &batchCount, + &m, &n, &k, &algoId, &exec_time) != EOF) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); + std::string markStr(mark); + sp_algo_map_[markStr] = algoId; + } + fclose(fd); +} + +int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, + const int k) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark]; + } else { + // for remove padding, select algo 1 for simplicity + return 0; + } +} + +bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, + const int k) { + // not available to use cusparselt. + if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0) { + return false; + } + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark] != -1; + } else { + // no gemm test case, choose sparse according to sparse flag + return true; + } +} diff --git a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h new file mode 100644 index 000000000000..beb9d3a23d90 --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#pragma once + +#define GEMM_NUM 6 +#define GEMM_CONFIG "gemm_config.in" +#define IGEMM_CONFIG "igemm_config.in" +#define SPGEMM_CONFIG "spgemm_config.in" +#define SPIGEMM_CONFIG "spigemm_config.in" + +typedef struct { + int algoId, customOption, tile, splitK_val; + int swizzle, reductionScheme, workspaceSize; + // only used in cublasLt >= 11.0 + int stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + uint16_t inner_shapeId, cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + uint16_t mma_shapeId, cga_shapeId, sche_mode; +#endif + float exec_time; +} cublasLtMatmulAlgo_info; + +/* Structure to store information about different run trials */ +typedef struct { + cublasLtMatmulAlgo_t algo; + cublasStatus_t status; + float time; + size_t workspaceSize; // actual memory workspace needed + cublasMath_t mathMode; + cublasLtReductionScheme_t reductionScheme; + int customOption; + float wavesCount; +} customMatmulPerf_t; + +struct cublasAlgoConfig_t { + int batch_count; + int m; + int n; + int k; + CublasDataType data_type; + bool operator==(cublasAlgoConfig_t const &config) const { + return (batch_count == config.batch_count) && (m == config.m) && + (n == config.n) && (k == config.k) && + (data_type == config.data_type); + } +}; + +class cublasAlgoConfig_hasher { +public: + std::size_t operator()(cublasAlgoConfig_t const &config) const { + return config.batch_count * 98317ull ^ config.m * 49157ull ^ + config.n * 24593ull ^ config.k * 196613ull ^ + static_cast(config.data_type) * 6151ull; + } +}; + +class cublasAlgoMap { +private: + std::unordered_map + algo_map_; + std::string config_filename_; + std::string sp_config_filename_; + std::map sp_algo_map_; + +public: + cublasAlgoMap(){}; + explicit cublasAlgoMap(const std::string filename, + const std::string sp_config_filename = ""); + cublasAlgoMap(const cublasAlgoMap &map); + ~cublasAlgoMap(); + void loadGemmConfig(); + void loadSpGemmConfig(); + int getSpAlgo(const int batch_count, const int m, const int n, const int k); + bool isUseSparse(const int batch_count, const int m, const int n, + const int k); + + bool isExist(const int batch_count, const int m, const int n, const int k, + const CublasDataType data_type); + + cublasLtMatmulAlgo_info getAlgo(const int batch_count, const int m, + const int n, const int k, + const CublasDataType data_type); +}; diff --git a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc new file mode 100644 index 000000000000..03c656b10cbd --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc @@ -0,0 +1,676 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasINT8MMWrapper.h" + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#endif + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublas_handle_(nullptr), cublaslt_handle_(cublaslt_handle), + stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), + allocator_(nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), + stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), + allocator_(nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + + +cublasINT8MMWrapper::~cublasINT8MMWrapper() { mu_ = nullptr; } + +cublasINT8MMWrapper::cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper) + : cublas_handle_(nullptr), cublaslt_handle_(wrapper.cublaslt_handle_), + stream_(wrapper.stream_), cublas_algo_map_(wrapper.cublas_algo_map_), mu_(wrapper.mu_), + allocator_(wrapper.allocator_), use_ORDER_COL32_2R_4R4_(wrapper.use_ORDER_COL32_2R_4R4_) { +} + +// for int8 cublasLtMM with algo +// ATransform should be m*n, CUBLASLT_ORDER_COL32 +// kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or +// CUBLASLT_ORDER_COL32_2R_4R4 res is m*n, CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, ATransform, + AtransformDesc, kernel, BtransformDesc, &betaI, res, + CtransformDesc, res, CtransformDesc, + (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_32I +// alpha: CUDA_R_32I should be 1 +// beta: CUDA_R_32I should be 0 +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, kernel, AtransformDesc, + ATransform, BtransformDesc, &betaI, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// for int8 IO cublasLtMM with algo +// ATransform should be m*k CUBLASLT_ORDER_COL32 +// kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C +// res is m*n CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int8_t *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_8I +// alpha: CUDA_R_32F +// beta: CUDA_R_32F +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int8_t *res, int batchCount, int m, int n, + int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { + findAlgo = 1; + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +bool cublasINT8MMWrapper::getUseOrderCol322R4R4() { + return use_ORDER_COL32_2R_4R4_; +} diff --git a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h new file mode 100644 index 000000000000..8bc209f58b91 --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "allocator.h" +#include "cublasAlgoMap.h" +#include +#include +#include +#include +#include +#include + +#pragma once + +class cublasINT8MMWrapper{ +protected: + cublasHandle_t cublas_handle_; + cublasLtHandle_t cublaslt_handle_; + cudaStream_t stream_; + cublasAlgoMap *cublas_algo_map_; + std::mutex *mu_; + IAllocator *allocator_ = nullptr; + +private: + bool use_ORDER_COL32_2R_4R4_; + +public: + cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + ~cublasINT8MMWrapper(); + + cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper); + + void Gemm(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm_(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + void Gemm_(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + bool getUseOrderCol322R4R4(); +}; \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc b/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc new file mode 100644 index 000000000000..588375570937 --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_utils.h" + +cudaError_t getSetDevice(int i_device, int *o_device) { + int current_dev_id = 0; + cudaError_t err = cudaSuccess; + + if (o_device != NULL) { + err = cudaGetDevice(¤t_dev_id); + if (err != cudaSuccess) { + return err; + } + if (current_dev_id == i_device) { + *o_device = i_device; + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + *o_device = current_dev_id; + } + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + } + + return cudaSuccess; +} diff --git a/csrc/quantization/smoothquant/int8gemm/cuda_utils.h b/csrc/quantization/smoothquant/int8gemm/cuda_utils.h new file mode 100644 index 000000000000..f1d9bba4ab06 --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/cuda_utils.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +enum CublasDataType { + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +static const char *_cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); +} + +static const char *_cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline void syncAndCheck(const char *const file, int const line) { + // When FT_DEBUG_LEVEL=DEBUG, must check error + static char *level_name = std::getenv("FT_DEBUG_LEVEL"); + if (level_name != nullptr) { + static std::string level = std::string(level_name); + if (level == "DEBUG") { + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error( + std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } + // FT_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line)); + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +#endif +} + +#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__) + + +[[noreturn]] inline void throwRuntimeError(const char *const file, + int const line, + std::string const &info = "") { + throw std::runtime_error(std::string("[FT][ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); +} + +inline void myAssert(bool result, const char *const file, int const line, + std::string const &info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } +} + +#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) + +#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) + +cudaError_t getSetDevice(int i_device, int *o_device = NULL); + +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/int8_gemm.h b/csrc/quantization/smoothquant/int8gemm/int8_gemm.h new file mode 100644 index 000000000000..2e80d4efe22a --- /dev/null +++ b/csrc/quantization/smoothquant/int8gemm/int8_gemm.h @@ -0,0 +1,127 @@ +/* + gemm methods are adapted from ft +*/ +#include +#include "cublasAlgoMap.h" +#include "cublasINT8MMWrapper.h" + +class I8CUGEMM { +private: + cublasINT8MMWrapper *int8_gemm_wrapper = nullptr; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +public: + I8CUGEMM(); + ~I8CUGEMM(); + + void linear_a8_w8_o32( + torch::Tensor& input, + torch::Tensor& weight, + torch::Tensor& output); + void linear_a8_w8_o32_( + torch::Tensor& input, + torch::Tensor& weight, + torch::Tensor& output); + void linear_a8_w8_o8( + torch::Tensor& input, + torch::Tensor& weight, + torch::Tensor& output, + float alpha); + void linear_a8_w8_o8_( + torch::Tensor& input, + torch::Tensor& weight, + torch::Tensor& output, + float alpha); +}; +I8CUGEMM::I8CUGEMM() { + // cublasAlgoMap *cublas_algo_map = new cublasAlgoMap("igemm_config.in"); + cublasAlgoMap *cublas_algo_map = new cublasAlgoMap(); + std::mutex *cublas_wrapper_mutex = new std::mutex(); + bool use_ORDER_COL32_2R_4R4 = true; + + cublasLtHandle_t cublaslt_handle; + cublasLtCreate(&cublaslt_handle); + + int8_gemm_wrapper = new cublasINT8MMWrapper( + cublaslt_handle, + this->stream, + cublas_algo_map, + cublas_wrapper_mutex, + use_ORDER_COL32_2R_4R4); +} + +I8CUGEMM::~I8CUGEMM() {} + +void I8CUGEMM::linear_a8_w8_o32( + torch::Tensor& input, // INT8 + torch::Tensor& weight, // INT8 + torch::Tensor& out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t* input_ptr = input.data_ptr(); + int8_t* weight_ptr = weight.data_ptr(); + int32_t* output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_o32_( + torch::Tensor& input, // INT8 + torch::Tensor& weight, // INT8 + torch::Tensor& out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t* input_ptr = input.data_ptr(); + int8_t* weight_ptr = weight.data_ptr(); + int32_t* output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_o8( + torch::Tensor& input, // INT8 + torch::Tensor& weight, // INT8 + torch::Tensor& out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t* input_ptr = input.data_ptr(); + int8_t* weight_ptr = weight.data_ptr(); + int8_t* output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_o8_( + torch::Tensor& input, // INT8 + torch::Tensor& weight, // INT8 + torch::Tensor& out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t* input_ptr = input.data_ptr(); + int8_t* weight_ptr = weight.data_ptr(); + int8_t* output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} diff --git a/csrc/quantization/smoothquant/quant_utils.cuh b/csrc/quantization/smoothquant/quant_utils.cuh new file mode 100644 index 000000000000..dcbf3da3fcb7 --- /dev/null +++ b/csrc/quantization/smoothquant/quant_utils.cuh @@ -0,0 +1,243 @@ +// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +#pragma once + +#include +#include +#include +#include +#include "../../attention/attention_dtypes.h" +#include "../../attention/dtype_float32.cuh" +using namespace vllm; + +// this function is for function matching, delete it after writing customized dispatch functions +inline __device__ int8_t quant(double a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ int8_t quant(float a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ short quant(float2 a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + return int16; +} + +inline __device__ int32_t quant(float4 a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale))); + int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale))); + return int32; +} + +// float16 to int8 +inline __device__ int8_t quant(uint16_t a, const float scale, const float zp) +{ + int8_t int8; + float b = half_to_float(a); + int8 = round(max(-128.f, min(127.f, (b - zp) / scale))); + return int8; +} + +// float16x2 to int8x2 +inline __device__ int16_t quant(uint32_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = half2_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale))); + return int16; +} + +// float16x4 to int8x4 +inline __device__ int32_t quant(uint2 a, const float scale, const float zp) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + return int32; +} + +// float16x8 to int8x8 +inline __device__ int64_t quant(uint4 a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + int16[2] = quant(a.z, scale, zp); + int16[3] = quant(a.w, scale, zp); + return int64; +} + +// int8 to float32, then `vec_conversion` to target format +inline __device__ float dequant(int8_t a, const float scale, const float zp) +{ + float b = a * scale + zp; + return b; +} + +// int8x2 to float32x2 +inline __device__ float2 dequant(int16_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + int16 = a; + + float2 b; + b.x = int8[0] * scale + zp; + b.y = int8[1] * scale + zp; + return b; +} + +// int8x4 to float32x4 +inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int32 = a; + + Float4_ b; + b.x.x = (int8[0] * scale) + zp; + b.x.y = (int8[1] * scale) + zp; + b.y.x = (int8[2] * scale) + zp; + b.y.y = (int8[3] * scale) + zp; + return b; +} + +inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + int64 = a; + + Float8_ b; + b.x = dequant(int16[0], scale, zp); + b.y = dequant(int16[1], scale, zp); + b.z = dequant(int16[2], scale, zp); + b.w = dequant(int16[3], scale, zp); + return b; +} + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { + return __float22bfloat162_rn(a); +} + +template<> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { + bf16_4_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + return b; +} + +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { + bf16_8_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + b.z = vec_conversion<__nv_bfloat162, float2>(a.z); + b.w = vec_conversion<__nv_bfloat162, float2>(a.w); + return b; +} + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} \ No newline at end of file diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bb5171f854d5..4065efc8efa2 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -62,4 +62,29 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} } // namespace vllm diff --git a/examples/offline_profile.py b/examples/offline_profile.py new file mode 100644 index 000000000000..da0b700909d2 --- /dev/null +++ b/examples/offline_profile.py @@ -0,0 +1,268 @@ +import argparse +import torch +import sys +import json +import inspect + +from dataclasses import dataclass, asdict +from typing import Optional +from vllm import LLM, SamplingParams +from vllm.profiler import nm_profile + +BATCH_SIZE_DEFAULT = 1 +PROMPT_LEN_DEFAULT = 256 +MAX_SEQ_LEN_DEFAULT = 1024 + + +@dataclass +class ProfileContext: + model: str + tokenizer: str + model_revision: str + quantization: str + max_seq_len: int + max_num_batched_tokens: int + prompt_len: int + batch_size: int + tensor_parallel_size: int + kv_cache_dtype: str + kv_quant_params_path: str + allow_cuda_graphs: bool + + +def run_profile(context: ProfileContext, csv_output: Optional[str], + json_output: Optional[str]): + print("Run profile with:") + for key, value in asdict(context).items(): + print(f" {key} = {value}") + + # Create sampling params + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8) + + # Sparsity is in the future + # Create LLM + llm = None + if context.kv_quant_params_path is not None: + llm = LLM( + model=context.model, + tokenizer=context.tokenizer if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + max_num_batched_tokens=context.max_num_batched_tokens, + kv_cache_dtype=context.kv_cache_dtype, + kv_quant_params_path=context.kv_quant_params_path) + else: + llm = LLM( + model=context.model, + tokenizer=context.tokenizer if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + max_num_batched_tokens=context.max_num_batched_tokens) + + batch_size = context.batch_size + prompt_len = context.prompt_len + + scheduler_config = llm.llm_engine.scheduler_config + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + if batch_size * prompt_len > max_num_batched_tokens: + print(f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max_num_batched_tokens") + sys.exit(-1) + if batch_size >= max_num_seqs: + print( + f"ERROR: chosen batch_size ({batch_size}) is larger than " + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " + f"single profile step, please choose a smaller batch size") + sys.exit(-1) + + for i in range(batch_size): + llm.llm_engine.add_request( + request_id=f"seq{i}", + prompt=None, + prompt_token_ids=torch.randint( + 128, # 128 to skip over special tokens + llm.llm_engine.model_config.get_vocab_size() // 2, + size=(prompt_len, )).tolist(), + sampling_params=sampling_params) + + with nm_profile() as prefill_prof: + llm.llm_engine.step() # First step is prefill + + with nm_profile() as decode_prof: + llm.llm_engine.step() + + prefill_results = prefill_prof.results + decode_results = decode_prof.results + + print("=" * 80) + print(f"= Prefill Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_model_table() + print() + print("=" * 80) + print(f"= Decode Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results.print_model_table() + print() + print("=" * 80) + print(f"= Prefill Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_summary_table() + print() + print("=" * 80) + print(f"= Decode Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results.print_summary_table() + + if csv_output: + csv_filename_base = csv_output.rstrip(".csv") + prefill_results.export_model_stats_table_csv( + csv_filename_base + "_prefill_model_table.csv") + prefill_results.export_summary_stats_table_csv( + csv_filename_base + "_prefill_summary_table.csv") + decode_results.export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results.export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") + + if json_output: + cuda_devices = [ + torch.cuda.get_device_properties(dev_idx) + for dev_idx in range(torch.cuda.device_count()) + ] + + json_dict = { + "context": { + "python_version": f"{sys.version}", + "torch_version": f"{torch.__version__}", + "torch_cuda_version": f"{torch.version.cuda}", + "cuda_devices": f"{cuda_devices}", + **asdict(context) + }, + "prefill": prefill_results.convert_stats_to_dict(), + "decode": decode_results.convert_stats_to_dict() + } + + with open(json_output.rstrip(".json") + ".json", "w+") as f: + json.dump(json_dict, f, indent=2) + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model", + type=str, + required=True, + help='The name or path of a HuggingFace Transformers model.') + parser.add_argument( + "--tokenizer", + type=str, + default=None, + help="path to the tokenizer") + + parser.add_argument("--model-revision", type=str, default=None) + parser.add_argument( + "--csv", + type=str, + default=None, + help="Export the results as multiple csv file. This should be the root " + "filename, will create _prefill_model_table.csv, " + "_prefill_summary_table.csv, " + "_decode_model_table.csv, and " + "_decode_summary_table.csv") + parser.add_argument( + "--json", + type=str, + default=None, + help="Export the results as a json file. This should be the filename") + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None], + default=None, + help="The method used to quantize the model weights, " + "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=['auto', 'fp8_e5m2', 'int8'], + default='auto', + help= + 'Data type for kv cache storage. If "auto", will use model data type.') + parser.add_argument( + "--kv-quant-params-path", + type=str, + default=None, + help='Path to scales and zero points of kv cache quantizaiton ' + 'when kv cache dtype is int8.') + parser.add_argument( + "--max-seq-len", + type=int, + default=MAX_SEQ_LEN_DEFAULT, + help=f"Maximum length of a sequence (including prompt and output), " + f"default={MAX_SEQ_LEN_DEFAULT}") + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="Maximum number of tokens to be processed in a single iteration. " + " Should be greater than batch-size * prompt-len so the prefill can " + " run in a single iteration.") + parser.add_argument( + "--prompt-len", + type=int, + default=PROMPT_LEN_DEFAULT, + help=f"Length of the random prompt to use when profiling, all batched " + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument("--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}") + parser.add_argument("--tensor-parallel-size", + "-tp", + type=int, + default=1, + help="Number of GPUs to use i.e. tensor parallelism, " + "default=1") + parser.add_argument( + "--allow-cuda-graphs", + action='store_true', + help="Enables cuda graphs to be used, well remove a lot of the module " + "level info in the profiler results since almost everything runs in " + "the graph where we do not have access to an informative stack trace") + + args = parser.parse_args() + + context = ProfileContext( + **{ + k: v + for k, v in vars(args).items() + if k in inspect.signature(ProfileContext).parameters + }) + run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/experiments.sh b/experiments.sh new file mode 100755 index 000000000000..efdd85d523f4 --- /dev/null +++ b/experiments.sh @@ -0,0 +1,211 @@ +#! /bin/bash + +set -e +set -u +set -x + +# global args +model_path=/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b +tokenizer=/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/ +max_seq_len=2048 +max_num_batched_tokens=7000 +tensor_parallel_size=1 + +# experiment args +prefill_prompt_len=512 +decode_batch_sizes=(1 2 8 16 32 64 128) + +# quantization specific args +quant_model_path=$model_path/quantized_model/llama-13b/Nous-Hermes-Llama2-13b-smoothquant/ + +# kv quant specific args +kv_cache_dtype=int8 +kv_quant_params_path=/home/varun/code/vllm/act_quant_data/exported_kv/ + +run_quantized_prefill() { + output_directory=$1 + + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_quantized-${now} + + echo "Running prefill ${prefill_prompt_len} store at ${out_base}" + python3 examples/offline_profile.py --model $quant_model_path \ + --tokenizer $tokenizer \ + --batch-size 1 \ + --prompt-len $prefill_prompt_len \ + --quantization smoothquant \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 +} + +run_quantized_decode() { + output_directory=$1 + + for bs in "${decode_batch_sizes[@]}" + do + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/decode_bs_${bs}_llama13_quantized-${now} + + echo "Running decode bs ${bs} store at ${out_base}" + python3 examples/offline_profile.py --model $quant_model_path \ + --tokenizer $tokenizer \ + --batch-size $bs \ + --prompt-len 1 \ + --quantization smoothquant \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 + done +} + +run_kv_quant_prefill() { + + output_directory=$1 + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_kv_quant-${now} + + echo "Running prefill ${prefill_prompt_len} store at ${out_base}" + + python3 examples/offline_profile.py --model $model_path \ + --tokenizer $tokenizer \ + --batch-size 1 \ + --prompt-len $prefill_prompt_len \ + --kv-cache-dtype $kv_cache_dtype \ + --kv-quant-params-path $kv_quant_params_path \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 +} + +run_kv_quant_decode() { + output_directory=$1 + + for bs in "${decode_batch_sizes[@]}" + do + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/decode_bs_${bs}_llama13_kv_quant-${now} + + echo "Running decode bs ${bs} store at ${out_base}" + python3 examples/offline_profile.py --model $model_path \ + --tokenizer $tokenizer \ + --batch-size $bs \ + --prompt-len 1 \ + --kv-cache-dtype $kv_cache_dtype \ + --kv-quant-params-path $kv_quant_params_path \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 + done +} + +run_fp16_prefill() { + + output_directory=$1 + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_fp16-${now} + + echo "Running prefill ${prefill_prompt_len} store at ${out_base}" + + python3 examples/offline_profile.py --model $model_path \ + --tokenizer $tokenizer \ + --batch-size 1 \ + --prompt-len $prefill_prompt_len \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 +} + +run_fp16_decode() { + output_directory=$1 + + for bs in "${decode_batch_sizes[@]}" + do + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/decode_bs_${bs}_llama13_fp16-${now} + + echo "Running decode bs ${bs} store at ${out_base}" + python3 examples/offline_profile.py --model $model_path \ + --tokenizer $tokenizer \ + --batch-size $bs \ + --prompt-len 1 \ + --max-seq-len $max_seq_len \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 + done +} + +## Arg parser and invocation + +usage() { + echo`` + echo "Run profiler" + echo + echo "usage: ${0} " + echo + echo " -t - pass in w8a8 or kv_quant" + echo " -n - pass in num_benchmark_iterations" + echo " -o - out directory" + echo +} + +exp_type="" # should be either w8a8 or kv_quant +num_benchmark_iterations=1 +output_directory="./" + +while getopts ':t:n:o:h:' OPT; do + case "${OPT}" in + t) + exp_type="${OPTARG}" + ;; + n) + num_benchmark_iterations=${OPTARG} + ;; + o) + output_directory="${OPTARG}" + ;; + h) + usage + exit 1 + ;; + esac +done + +if [ "$exp_type" != "w8a8" -a "$exp_type" != "kv_quant" -a "$exp_type" != "fp16" ]; +then + echo "Invalid arg $exp_type" + usage + exit 1 +fi + +for i in $(seq 1 $num_benchmark_iterations); +do + echo "Running benchmark iteration ${i} ..." + if [[ "${exp_type}" == "w8a8" ]]; + then + run_quantized_prefill $output_directory + run_quantized_decode $output_directory + fi + if [[ "${exp_type}" == "kv_quant" ]]; + then + run_kv_quant_prefill $output_directory + run_kv_quant_decode $output_directory + fi + if [[ "${exp_type}" == "fp16" ]]; + then + run_fp16_prefill $output_directory + run_fp16_decode $output_directory + fi +done diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py new file mode 100644 index 000000000000..728a5e386089 --- /dev/null +++ b/neuralmagic/tools/profiler/print_table.py @@ -0,0 +1,77 @@ +import argparse +import json + +from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry +from vllm.profiler.utils import indent_string, TablePrinter +from typing import Dict + + +def flatten_entries(entry_cls, profile_dict: Dict): + entries_and_depth = [] + + def get_entries(node, curr_depth=0): + entries_and_depth.append((entry_cls(**node["entry"]), curr_depth)) + + for child in node["children"]: + get_entries( + child, + curr_depth=curr_depth + 1, + ) + + for root in profile_dict: + get_entries(root) + + return entries_and_depth + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--json-trace", + type=str, + required=True, + help="json trace file output by " + "examples/offline_profile.py") + parser.add_argument("--phase", + type=str, + choices=["prefill", "decode"], + required=True, + help="The phase to print the table for.") + parser.add_argument("--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the " + "layerwise model table") + + args = parser.parse_args() + + with open(args.json_trace, "r") as f: + profile_data = json.load(f) + + if args.table == "summary": + entries_and_depths = flatten_entries( + SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # ident entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py new file mode 100644 index 000000000000..9d0f7f328543 --- /dev/null +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -0,0 +1,209 @@ +import argparse +import json +import pandas as pd +import matplotlib.pyplot as plt + + +def trim_string_back(string: str, width: int): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +def abbreviate_known_names(name: str): + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument( + "--output", + type=str, + required=False, + help="Output figure file, should be a image file such as pdf, " + "jpeg, png, etc., defaults to .pdf") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top_k", + type=int, + default=9, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--ignore_sampler", + action='store_true', + help="Ignore everything under the \"Sampler\" module") + + args = parser.parse_args() + + ignore_sampler = args.ignore_sampler + make_names_unique = False + top_k = args.top_k + + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + if ignore_sampler: + print("WARNING: ignoring Sampler time so the pct_cuda_time will not " + "add up to 100%") + + json_trace = args.json_trace + output = args.output if args.output else json_trace.strip(".json") + ".pdf" + + with open(json_trace, "r") as f: + profile_data = json.load(f) + + prefill_entries_and_traces = [] + decode_entries_and_traces = [] + + def largest_dist_from_leaf(node, depth=0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + def get_entries_at_depth(depth, + entries_and_traces, + node, + curr_depth=0, + trace=()): + if ignore_sampler and node["entry"]["name"] == "Sampler": + return + + if (depth >= 0 and depth == curr_depth) or ( + depth < 0 + and largest_dist_from_leaf(node) == (abs(depth) - 1)): + entries_and_traces.append((node["entry"], trace)) + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + for root in profile_data["prefill"]["summary_stats"]: + get_entries_at_depth(depth, prefill_entries_and_traces, root) + for root in profile_data["decode"]["summary_stats"]: + get_entries_at_depth(depth, decode_entries_and_traces, root) + + def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [ + (entry, trace) for entry, trace in entries_and_traces + if entry["name"] == name + ] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + if make_names_unique: + attempt_to_make_names_unique(prefill_entries_and_traces) + attempt_to_make_names_unique(decode_entries_and_traces) + + def keep_only_top_entries(df, metric, top_k=9): + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + + prefill_df = pd.DataFrame( + [entry for entry, _ in prefill_entries_and_traces]) + prefill_df["phase"] = "prefill" + decode_df = pd.DataFrame([entry for entry, _ in decode_entries_and_traces]) + decode_df["phase"] = "decode" + + if top_k: + keep_only_top_entries(prefill_df, "cuda_time_us", top_k) + keep_only_top_entries(decode_df, "cuda_time_us", top_k) + + df = pd.concat([prefill_df, decode_df]) + df["cuda_time_ms"] = df["cuda_time_us"] / 1000 + + fig, axes = plt.subplots(2, figsize=(5, 8), sharex=True) + + def plot_metric(metric: str, ax, add_totals=False): + pivoted_df = df.pivot_table(index="phase", + columns="name", + values=metric, + aggfunc="sum") + pivoted_df.plot.bar(stacked=True, legend=False, ax=ax) + ax.set_ylabel(metric) + + if add_totals: + ax.bar_label(ax.containers[-1]) + + plot_metric("cuda_time_ms", ax=axes[0], add_totals=True) + plot_metric("pct_cuda_time", ax=axes[1]) + + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(0.93, 0.5)) + shorten_plot_legend_strings(legend, 50) + + context = profile_data["context"] + plt.suptitle( + f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + context['sparsity'] if context['sparsity'] else ''}" + ) + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) diff --git a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1..f78949c24718 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,8 @@ aiohttp # Multimodal pillow + +# Profiling +matplotlib +pandas +pyarrow diff --git a/tests/kernels/test_fusion.py b/tests/kernels/test_fusion.py new file mode 100644 index 000000000000..07d9ce60a403 --- /dev/null +++ b/tests/kernels/test_fusion.py @@ -0,0 +1,94 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) + + out1 = (x * scale).to(dtype) + out2 = torch.empty_like(x, dtype=dtype) + ops.dequant(out2, x, scale) + assert torch.allclose(out1, out2, atol=0.001) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_per_token_dequant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) + scale = torch.rand(num_tokens, 1, dtype=torch.float32, device="cuda") + out1 = (x * scale).to(dtype) + out2 = torch.empty_like(x, dtype=dtype) + scale = torch.squeeze(scale) + ops.dequant(out2, x, scale) + assert torch.allclose(out1, out2, atol=0.001) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp(-128, 127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.quant(out2, x, scale) + assert torch.allclose(out1, out2, atol=1) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_per_token_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + scale1 = torch.max(x, dim=1)[0].to(torch.float32) / 127.0 + out1 = (x / scale1.view(-1, 1)).round().clamp(-128, 127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale2 = torch.empty(num_tokens, dtype=torch.float32, device="cuda") + ops.quant(out2, x, scale2) + assert torch.allclose(out1, out2, atol=1) diff --git a/vllm/config.py b/vllm/config.py index 6762a75f25f2..3149aaf68914 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,7 +173,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "marlin"] + supported_quantization = ["awq", "gptq", "squeezellm", "smoothquant"] rocm_not_supported_quantization = ["awq", "marlin"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -868,6 +868,7 @@ def get_image_input_enum_type( _STR_DTYPE_TO_TORCH_DTYPE = { + "int8": torch.int8, "half": torch.float16, "float16": torch.float16, "float": torch.float32, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a6197942645e..3e6742eea2af 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -270,17 +270,18 @@ def add_cli_args( action='store_true', help='disable logging statistics') # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'gptq', 'squeezellm', None], - default=EngineArgs.quantization, - help='Method used to quantize the weights. If ' - 'None, we first check the `quantization_config` ' - 'attribute in the model config file. If that is ' - 'None, we assume the model weights are not ' - 'quantized and use `dtype` to determine the data ' - 'type of the weights.') + parser.add_argument( + '--quantization', + '-q', + type=str, + choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], + default=None, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ad988d48755b..301bcf48de6b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -6,11 +6,13 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.smoothquant import SmoothQuantConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "smoothquant": SmoothQuantConfig, "marlin": MarlinConfig, } diff --git a/vllm/model_executor/layers/quantization/smoothquant.py b/vllm/model_executor/layers/quantization/smoothquant.py new file mode 100644 index 000000000000..d9d82e6cfbc3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant.py @@ -0,0 +1,348 @@ +from typing import Any, Dict, List, Tuple, Optional + +import torch +from torch._tensor import Tensor +from torch.nn.parameter import Parameter +import threading + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant + + Reference: https://github.com/mit-han-lab/smoothquant + """ + + def __init__(self, + weight_bits: int = 8, + quant_map: dict[str:str] = None) -> None: + self.weight_bits = weight_bits + self.quant_map = quant_map + + if self.weight_bits != 8: + raise ValueError( + "Currently, only w8a8 quantization is supported for " + f"SmoothQuant, but got {self.weight_bits} bits.") + if self.quant_map is None or self.quant_map == {}: + raise ValueError( + 'Quant_map for SmoothQuant should not be None or an empty dict. ' + 'For example, when using llama, you should set a quant_config.json in model directory, like ' + '{ "qkv": "per-tensor", "out": "per-token", "fc1": "per-tensor", "fc2": "per-token" }' + ) + + def __repr__(self) -> str: + return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, " + f"quant_map={self.quant_map})") + + def get_name(self) -> str: + return "smoothquant" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.float] + + def get_min_capability(self) -> int: + # The smoothquant kernel only supports Ampere or newer GPUs. + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + try: + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + except ValueError as e: + weight_bits = 8 + print(str(e) + " Set weight_bits = 8 by default.") + + quant_map = {} + for key, value in config.items(): + if value in ["per-tensor", "per-token"]: + quant_map[key] = value + return cls(weight_bits, quant_map) + + def get_linear_method(self) -> "SQLinearMethod": + return SQLinearMethod(Int8GEMM) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Int8GEMM(object): + _instance_lock = threading.Lock() + + def __init__(self): + if not hasattr(self, "i8cugemm"): + self.i8cugemm = ops.I8CUGEMM() + + def __new__(cls, *args, **kwargs): + if not hasattr(Int8GEMM, "_instance"): + with Int8GEMM._instance_lock: + if not hasattr(Int8GEMM, "_instance"): + Int8GEMM._instance = object.__new__(cls) + return Int8GEMM._instance + + def get_i8cugemm(self): + return self.i8cugemm + + +class SQLinearMethod(LinearMethodBase): + """Linear method for SmoothQuant. + """ + + def __init__(self, gemm): + i8_gemm = gemm() + self.i8cugemm = i8_gemm.get_i8cugemm() + + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Tensor]: + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + device="cuda", + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + # q k v dequant_scales are used in QKVParallelLinear + q_dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + k_dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + v_dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + # gate up dequant_scales are used in MergedColumnParallelLinear + gate_dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + up_dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + # dequant_scale is used in RowParallelLinear + dequant_scale = Parameter( + torch.tensor(1.0, dtype=torch.float32, device='cpu'), + requires_grad=False, + ) + return { + "weight": weight, + "q_dequant_scale": q_dequant_scale, + "k_dequant_scale": k_dequant_scale, + "v_dequant_scale": v_dequant_scale, + "gate_dequant_scale": gate_dequant_scale, + "up_dequant_scale": up_dequant_scale, + "dequant_scale": dequant_scale + } + + def apply_weights(self, + weights: Dict[str, Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> Tensor: + assert bias is None + weight = weights["weight"] + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], weight.shape[0]), + dtype=torch.int32, + device=x.device) + self.i8cugemm.linear_a8_w8_o32_(x, weight, y) + y = y.view(*x_shape[:-1], -1) + return y + + +class SQLinearMethodQKV(SQLinearMethod): + + def __init__(self, + gemm, + qkv_sizes : Tuple[int, int, int], + quant_dtype : torch.dtype = torch.int8, + dequant_dtype : torch.dtype = torch.float): + super().__init__(gemm) + self.qkv_sizes = qkv_sizes + self.quant_dtype = quant_dtype + self.dequant_dtype = dequant_dtype + + def quantize(self, x): + assert x.dtype != self.quant_dtype + x_q = torch.empty_like(x, dtype=self.quant_dtype) + ops.quant(x_q, x, 1.0) + return x_q + + def dequantize(self, x_q, weights : Dict[str, Tensor]): + # split to get the quantized qkv + q_q, k_q, v_q = x_q.split(list(self.qkv_sizes), dim=-1) + + # create dequant qkv buffer and split to get the individual dequant qkv + # buffers + qkv = torch.empty_like(x_q, dtype=self.dequant_dtype) + q, k, v = qkv.split(list(self.qkv_sizes), dim=-1) + + q_scale, k_scale, v_scale = (weights['q_dequant_scale'], + weights['k_dequant_scale'], + weights['v_dequant_scale']) + ops.dequant(q, q_q, q_scale) + ops.dequant(k, k_q, k_scale) + ops.dequant(v, v_q, v_scale) + + return qkv + + def apply_weights(self, + weights: Dict[str, Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> Tensor: + x_q = self.quantize(x) + y_q = super().apply_weights(weights, x_q, bias) + return self.dequantize(y_q, weights) + +class SQLinearMethodOProj(SQLinearMethod): + + def __init__(self, + gemm, + use_per_token_quant:bool, + quant_dtype : torch.dtype = torch.int8, + dequant_dtype : torch.dtype = torch.float): + super().__init__(gemm) + self.use_per_token_quant = use_per_token_quant + self.quant_dtype = quant_dtype + self.dequant_dtype = dequant_dtype + + def quantize(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # x is the paged-attention output + assert x.dtype != self.quant_dtype + act_scale = None + x_q = torch.empty_like(x, dtype=self.quant_dtype) + if self.use_per_token_quant: + act_scale = torch.empty(x.numel() // x.shape[-1], + dtype=torch.float32, + device=x.device) + ops.quant(x_q, x, act_scale) + else: + ops.quant(x_q, x, 1.0) + return x_q, act_scale + + def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: + o_dequant_scale = weights['dequant_scale'] + x = torch.empty_like( + x_q, + dtype=self.dequant_dtype, + device=x_q.device) + ops.dequant(x, x_q, act_scale, o_dequant_scale) + return x + + def apply_weights(self, + weights: Dict[str, Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> Tensor: + pass + x_q, act_scale = self.quantize(x) + y_q = super().apply_weights(weights, x_q, bias) + return self.dequantize(y_q, weights, act_scale) + +class SQLinearMethodGateUpProj(SQLinearMethod): + + def __init__(self, + gemm, + quant_dtype : torch.dtype = torch.int8, + dequant_dtype : torch.dtype = torch.float): + super().__init__(gemm) + self.quant_dtype = quant_dtype + self.dequant_dtype = dequant_dtype + + def quantize(self, x) -> torch.Tensor: + # x is the attention output + assert x.dtype != self.quant_dtype + x_q = torch.empty_like(x, dtype=self.quant_dtype, device=x.device) + ops.quant(x_q, x, 1.0) + return x_q + + def dequantize(self, gate_up_q: torch.Tensor, weights : Dict[str, Tensor]) -> torch.Tensor: + + def split_gate_up(gate_up : torch.Tensor): + d = gate_up.shape[-1] + return (torch.narrow(gate_up, 1, 0, d//2), + torch.narrow(gate_up, 1, d//2, d//2)) + + # create a dequant gate_up buffer and split it into constituent parts. + gate_up = torch.empty_like(gate_up_q, + dtype=self.dequant_dtype, + device=gate_up_q.device) + + # split quantized gate_up into constituent parts. + gate_q, up_q = split_gate_up(gate_up_q) + # split output gate_up buffer into constituent parts. + gate, up = split_gate_up(gate_up) + + gate_scale, up_scale = (weights['gate_dequant_scale'], + weights['up_dequant_scale']) + ops.dequant(gate, gate_q, gate_scale) + ops.dequant(up, up_q, up_scale) + + return gate_up + + def apply_weights(self, + weights: Dict[str, Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> Tensor: + x_q = self.quantize(x) + gate_up_q = super().apply_weights(weights, x_q, bias) + return self.dequantize(gate_up_q, weights) + +class SQLinearMethodDownProj(SQLinearMethod): + + def __init__(self, + gemm, + quant_dtype : torch.dtype = torch.int8, + dequant_dtype : torch.dtype = torch.float): + super().__init__(gemm) + self.quant_dtype = quant_dtype + self.dequant_dtype = dequant_dtype + + def quantize(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dtype != self.quant_dtype + # TODO (varun) : This is per-token quant - Read from config + x_q = torch.empty_like(x, dtype=self.quant_dtype) + scale = torch.empty(x.numel() // x.shape[-1], + dtype=torch.float32, + device=x.device) + ops.quant(x_q, x, scale) + return x_q, scale + + def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: + down_dequant_scale = weights['dequant_scale'] + x = torch.empty_like( + x_q, + dtype=self.dequant_dtype, + device=x_q.device) + ops.dequant(x, x_q, act_scale, down_dequant_scale) + return x + + def apply_weights(self, + weights: Dict[str, Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x_q, act_scale = self.quantize(x) + y_q = super().apply_weights(weights, x_q, bias) + return self.dequantize(y_q, weights, act_scale) \ No newline at end of file diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0..b191dc4009b5 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -46,6 +46,10 @@ def _get_model_architecture( def get_architecture_class_name(model_config: ModelConfig) -> str: return _get_model_architecture(model_config)[1] +def _is_support_smoothquant(model_config: ModelConfig) -> bool: + architectures = getattr(model_config.hf_config, "architectures", []) + supported_archs = ModelRegistry.get_supported_smoothquant_archs() + return any(arch in supported_archs for arch in architectures) def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: @@ -55,6 +59,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # Get the (maybe quantized) linear method. linear_method = None + quant_config = None if model_config.quantization is not None: quant_config = get_quant_config(model_config) capability = torch.cuda.get_device_capability() @@ -77,21 +82,26 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): + if _is_support_smoothquant(model_config): + model = model_class(model_config.hf_config, linear_method, + quant_config) + elif hasattr(model_class, "supported_lora_modules"): model = model_class(model_config.hf_config, linear_method, lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") else: if model_class not in _VISION_MODEL_CLASSES: model = model_class(model_config.hf_config, linear_method) else: model = model_class(model_config.hf_config, vision_language_config, linear_method) + + if not hasattr(model_class, "supported_lora_modules") and lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17fc97056804..f488711d1158 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -74,6 +74,13 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } +# Models supported by smoothquant +_SUPPORTED_SMOOTHQUANT_MODELS = { + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), +} + class ModelRegistry: @@ -112,6 +119,9 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def get_supported_smoothquant_archs() -> List[str]: + return list(_SUPPORTED_SMOOTHQUANT_MODELS.keys()) __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ef19c41e67ae..1fffbc5fa30c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,12 +29,21 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.smoothquant import ( + Int8GEMM, + SQLinearMethod, + SQLinearMethodQKV, + SQLinearMethodOProj, + SQLinearMethodGateUpProj, + SQLinearMethodDownProj) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) + +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -49,7 +58,6 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip - class LlamaMLP(nn.Module): def __init__( @@ -58,19 +66,42 @@ def __init__( intermediate_size: int, hidden_act: str, linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + self.hidden_size = hidden_size + self.use_int8 = quant_config is not None and quant_config.get_name( + ) == "smoothquant" + + gate_up_linear_method = linear_method + if self.use_int8: + # override gate_up linear method + assert isinstance(linear_method, SQLinearMethod) + gate_up_linear_method = SQLinearMethodGateUpProj( + gemm=Int8GEMM, + quant_dtype=torch.int8, + dequant_dtype=torch.float) self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=gate_up_linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") + + down_proj_linear_method = linear_method + if self.use_int8: + # override gate_up linear method + assert isinstance(linear_method, SQLinearMethod) + down_proj_linear_method = SQLinearMethodDownProj( + gemm=Int8GEMM, + quant_dtype=torch.int8, + dequant_dtype=torch.float) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=down_proj_linear_method) + self.act_fn = SiluAndMul() def forward(self, x): @@ -79,7 +110,6 @@ def forward(self, x): x, _ = self.down_proj(x) return x - class LlamaAttention(nn.Module): def __init__( @@ -91,31 +121,38 @@ def __init__( rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: super().__init__() self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: + self.default_dtype = torch.get_default_dtype() + + if self.total_num_kv_heads >= self.tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % self.tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.use_int8 = quant_config is not None and quant_config.get_name( + ) == "smoothquant" + # Needs to be ironed out!! + self.use_per_token_quant = self.use_int8 # This will be overwritten by model initialization if we are using it. # N.B. currently we only support per tensor scalar scaling factors @@ -126,33 +163,55 @@ def __init__( # scaling_factor = tensor_amax / FPtype_max self.kv_scale = 1.0 + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + qkv_linear_method = linear_method + if self.use_int8: + # override qkv linear method + assert isinstance(linear_method, SQLinearMethod) + qkv_linear_method = SQLinearMethodQKV( + gemm=Int8GEMM, + qkv_sizes=(self.q_size, self.kv_size, self.kv_size), + quant_dtype=torch.int8, + dequant_dtype=self.rotary_emb.cos_sin_cache.dtype) self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + linear_method=qkv_linear_method, ) + + o_proj_linear_method = linear_method + if self.use_int8: + # override o_proj linear method + assert isinstance(linear_method, SQLinearMethod) + o_proj_linear_method = SQLinearMethodOProj( + gemm=Int8GEMM, + use_per_token_quant=True, # TODO (varun) : Read from config + quant_dtype = torch.int8, + dequant_dtype= torch.float) + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + linear_method=o_proj_linear_method, ) - self.rotary_emb = get_rope( + self.attn = Attention( + self.num_heads, self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, @@ -169,16 +228,19 @@ def forward( output, _ = self.o_proj(attn_output) return output - class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.use_int8 = quant_config is not None and quant_config.get_name( + ) == "smoothquant" + self.tp_size = get_tensor_model_parallel_world_size() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -193,6 +255,7 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, + quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) @@ -201,6 +264,7 @@ def __init__( intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -215,6 +279,7 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention if residual is None: residual = hidden_states @@ -222,6 +287,7 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -230,8 +296,8 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, + residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -242,6 +308,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -257,7 +324,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, linear_method, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -323,12 +390,15 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.quant_config = quant_config + self.model = LlamaModel(config, linear_method, lora_config=lora_config, + quant_config = quant_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -377,6 +447,9 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + # For SmoothQuant + int8_fusion = self.quant_config is not None and \ + self.quant_config.get_name() == "smoothquant" stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -395,6 +468,26 @@ def load_weights(self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + # bias is useless for llama + if "bias" in name: + continue + # load dequant scale for qkv_proj and gate_up_proj + if int8_fusion: + is_fusion_scale = False + if "scale" in name: + for (param_name, weight_name, _) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + prefix = weight_name.split('_')[0] + suffix = name.split('.')[-1] + new_name = prefix + '_' + suffix + param = params_dict[name.replace(suffix, new_name)] + param.copy_(loaded_weight) + is_fusion_scale = True + break + if is_fusion_scale: + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 000000000000..93ec4a800e60 --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,5 @@ +from .nm_profile import nm_profile + +__all__ = [ + "nm_profile", +] diff --git a/vllm/profiler/nm_profile.py b/vllm/profiler/nm_profile.py new file mode 100644 index 000000000000..1912a300c02b --- /dev/null +++ b/vllm/profiler/nm_profile.py @@ -0,0 +1,346 @@ +import pandas as pd +import copy + +from collections import defaultdict +from dataclasses import dataclass, field, asdict +from vllm.profiler.utils import (indent_string, TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace) +from typing import Dict, List, Union, Optional, Tuple, Callable, TypeAlias +from torch.profiler import profile, ProfilerActivity +from torch.autograd.profiler import FunctionEvent +from torch._C._autograd import _ProfilerResult, _KinetoEvent, DeviceType +from torch._C._profiler import _EventType, _ProfilerEvent, _ExperimentalConfig + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class NMProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> str: + return { + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_us() + else: + cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, self.StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us = node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def _convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + df_traversal(root, root_dicts) + + return root_dicts + + +class nm_profile(profile): + + def __init__(self): + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = NMProfileResults(self.profiler.kineto_results) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 000000000000..f8ead593d178 --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,146 @@ +import dataclasses + +from typing import Callable, Dict, Type, List, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False + + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag == _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', '.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a..f984a8bda4d7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -900,9 +900,10 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() -class CUDAGraphRunner: +class CUDAGraphRunner(nn.Module): def __init__(self, model: nn.Module): + super().__init__() self.model = model self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} @@ -984,6 +985,7 @@ def forward( # Return the output tensor. return self.output_buffers["hidden_states"] + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) From d39af9625bfb903472f73b722cceb8929663db6c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 11 Apr 2024 04:17:10 +0000 Subject: [PATCH 2/8] add offline quantized inference --- examples/offline_quantized_inference.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 examples/offline_quantized_inference.py diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py new file mode 100644 index 000000000000..8276a8ceabc6 --- /dev/null +++ b/examples/offline_quantized_inference.py @@ -0,0 +1,46 @@ +from vllm import LLM, SamplingParams +import torch + +model_path="/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/quantized_model/llama-13b/Nous-Hermes-Llama2-13b-smoothquant/" +tokenizer="/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create a sampling params object. +#sampling_params = SamplingParams(temperature=0.0, top_p=0.95) +sampling_params = SamplingParams(temperature=0.0, top_k = 1,max_tokens=20) + +# Create an LLM. +llm = LLM( + model=model_path, + tokenizer=tokenizer, + gpu_memory_utilization=0.9, + max_model_len=2048, + quantization="smoothquant", + dtype=torch.float, + enforce_eager=True, + tensor_parallel_size=1, + max_num_batched_tokens=7000) + +#llm = LLM( +# model=tokenizer, +# tokenizer=tokenizer, +# gpu_memory_utilization=0.9, +# max_model_len=2048, +# max_num_batched_tokens=7000) + + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From c2c59a9135dfdf322e3866dc0be818b2d213b31e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 11 Apr 2024 18:19:30 +0000 Subject: [PATCH 3/8] use hf-model --- examples/offline_quantized_inference.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py index 8276a8ceabc6..124a99468704 100644 --- a/examples/offline_quantized_inference.py +++ b/examples/offline_quantized_inference.py @@ -1,8 +1,8 @@ from vllm import LLM, SamplingParams import torch -model_path="/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/quantized_model/llama-13b/Nous-Hermes-Llama2-13b-smoothquant/" -tokenizer="/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/" +hf_path="nm-testing/Nous-Hermes-Llama2-13b-smoothquant" +model_path=hf_path # Sample prompts. prompts = [ @@ -13,13 +13,11 @@ ] # Create a sampling params object. -#sampling_params = SamplingParams(temperature=0.0, top_p=0.95) sampling_params = SamplingParams(temperature=0.0, top_k = 1,max_tokens=20) # Create an LLM. llm = LLM( model=model_path, - tokenizer=tokenizer, gpu_memory_utilization=0.9, max_model_len=2048, quantization="smoothquant", @@ -28,14 +26,6 @@ tensor_parallel_size=1, max_num_batched_tokens=7000) -#llm = LLM( -# model=tokenizer, -# tokenizer=tokenizer, -# gpu_memory_utilization=0.9, -# max_model_len=2048, -# max_num_batched_tokens=7000) - - # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) From 3e7d1c89ae3da48d54fc32ac7439d3a4a42e96e0 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:43:08 -0400 Subject: [PATCH 4/8] Cleanup / Refactor to support non-uniform quantization via config. This merge is a combination of 2 PRs - #186 and #188: - #188 is based on #186 and #188 is squash-merged onto #186. #186 : [1/N] Rs/vllm quantization - Refactor to minimize llama.py changes #188 : [2/N] Rs/vllm quantization - Refactor refactor to support non-uniform via config The PR description from both the PRs are included here for context. #188 's PR description should be the most relevant as it is the most recent. [2/N] Rs/vllm quantization - Refactor refactor to support non-uniform via config Refactored to support nonuniform quantization by adding a new layer of Abstraction. Now, SmoothQuantLinearMethod can hold a SmoothQuantFormat, which implements the details of how to do quant and dequant operations. There are two SmoothQuantFormat classes: SmoothQuantDynamicPerToken SmoothQuantStaticPerTensor We have the following lifecycle: LinearMethod is created during get_model, has access to QuantizationConfig Layer is initialized and passed a LinearMethod Layer calls LinearMethod.create_weights, which creates a dictionary of weights and metadata Layer calls LinearMethod.apply_weights during inference, passing the dictionary created during create_weights This PR modifies the LinearMethod.create_weights API to receive a layer_name as argument. The LinearMethod then looks in the config to determine which SmoothQuantFormat to use for the layer with layer_name As a result, the LinearMethod is responsible for parsing the config from disk and making decisions about what the inference format should look like. In this specific case, since the SmoothQuantConfig is not very good, we just match on the suffix qkv to determine what each layer should use --> but for SparseMLConfig, we could use a similar structure In this PR, the SmoothQuantFormat is passed in the dictionary returned by create_weights and then is used by apply_weights In Summary I think this is a good overall structure because it: (a) allows us to make minimal changes to the existing models (b) allows us to make no changes to the model loading lifecycle (i.e. config / constructor / linear method) ** critically requires having one LinearMethod that propagates through the whole model (c) encapsulates the nonuniform logic into the LinearMethod, allowing us to have a clean interface into For SparseML Models We could imagine the following architecture: Config Config is responsible for: loading config from disk mapping layer_names --> SparseMLFormat class SparseMLConfig def from_dict() def get_layer_format(layer_name): return SparseMLFormat LinearMethod Config is responsible for: interface between layers and kernels (so LinearMethod is what is used by the model) class SparseMLLinearMethod: def __init__(self, sparseml_config) self.sparseml_config = sparseml_config def create_weights(layer_name, ...): # this, e.g. is where nonuniform might be supported format = self.sparseml_config.get_layer_format(layer_name) weights = format.get_weights() weights["format"] = format return weights # wrapper around the SparseML format def apply_weights(x, weights, ...) format = weights["format"] weights = weights["weights"] return format.apply_weights(x, weights) SparseMLFormat Format is responsible for: actual weight creation and forward class SparseMLLinearMethod: def __init__(self, sparseml_config) self.sparseml_config = sparseml_config def get_weights(sizes): # returns dictionary , e.g. return { "weights": x "scales": y } def apply_weights(weights, x): # calls cuda kernel return output Sample Formats: - W8A8DynamicPerToken - SparseW8A8StaticPerTensorAsymmetric - W4A8DynamicPerToken - ... [1/N] Rs/vllm quantization - Refactor to minimize llama.py changes #186 Paired with @dsikka to refactor `SmoothQuantLinearMethod` to avoid making changes to `llama.py` - Removed all the "layer specific" `SmoothQuantLinearMethod` by making the indexing (splitting QKV into logical shards generic and explicitly handling state_dict converion - Successfully whittled down to only add one LOC to `llama.py` Many todos left, including: - We currently have hardcoded `use_per_token`, need to use the quant config for this - We need a way to pass different quantconfigs to each layer to support nonuniform quantization --- examples/offline_inference.py | 2 +- examples/offline_quantized_inference.py | 2 +- examples/simple_test.py | 35 ++ vllm/config.py | 4 +- vllm/model_executor/layers/linear.py | 150 ++++++-- .../model_executor/layers/quantization/awq.py | 15 +- .../layers/quantization/base_config.py | 6 +- .../layers/quantization/gptq.py | 7 +- .../layers/quantization/marlin.py | 6 +- .../layers/quantization/smoothquant.py | 348 ------------------ .../quantization/smoothquant/__init__.py | 14 + .../layers/quantization/smoothquant/config.py | 306 +++++++++++++++ .../quantization/smoothquant/formats.py | 100 +++++ .../layers/quantization/squeezellm.py | 15 +- vllm/model_executor/model_loader.py | 9 +- vllm/model_executor/models/llama.py | 195 +++------- 16 files changed, 682 insertions(+), 532 deletions(-) create mode 100644 examples/simple_test.py delete mode 100644 vllm/model_executor/layers/quantization/smoothquant.py create mode 100644 vllm/model_executor/layers/quantization/smoothquant/__init__.py create mode 100644 vllm/model_executor/layers/quantization/smoothquant/config.py create mode 100644 vllm/model_executor/layers/quantization/smoothquant/formats.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..6b548d5e8921 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py index 124a99468704..8b3dbea72ae6 100644 --- a/examples/offline_quantized_inference.py +++ b/examples/offline_quantized_inference.py @@ -17,7 +17,7 @@ # Create an LLM. llm = LLM( - model=model_path, + model="nm-testing/Nous-Hermes-Llama2-13b-smoothquant", gpu_memory_utilization=0.9, max_model_len=2048, quantization="smoothquant", diff --git a/examples/simple_test.py b/examples/simple_test.py new file mode 100644 index 000000000000..dcf8b8c7ed1e --- /dev/null +++ b/examples/simple_test.py @@ -0,0 +1,35 @@ +import argparse +from vllm import LLM, SamplingParams + +MODELS = { + "tinyllama-fp16": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "tinyllama-marlin": "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", + "tinyllama-gptq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "tinyllama-awq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", +} + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str) +parser.add_argument("--tensor-parallel-size", type=int, default=1) +args = parser.parse_args() + +if args.model not in MODELS: + print(f"Got model id of {args.model}; Must be in {list(MODELS.keys())}") + raise ValueError +else: + model_id = MODELS[args.model] + print(f"Using model_id = {model_id}") + +messages=[{ + "role": "system", + "content": "You are a helpful assistant." +}, { + "role": "user", + "content": "What is deep learning?" +}] + +model = LLM(model_id, enforce_eager=True, max_model_len=2048, tensor_parallel_size=args.tensor_parallel_size, dtype="float16") +prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +out = model.generate(prompt, SamplingParams(max_tokens=50)) +print(f"\n-----prompt\n{prompt}") +print(f"\n-----generation\n{out[0].outputs[0].text}") diff --git a/vllm/config.py b/vllm/config.py index 3149aaf68914..cd48fe4f1b9d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,8 +173,8 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "smoothquant"] - rocm_not_supported_quantization = ["awq", "marlin"] + supported_quantization = ["awq", "gptq", "marlin", "squeezellm", "smoothquant"] + rocm_not_supported_quantization = ["awq", "marlin", "smoothquant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..2598156bbed3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -29,8 +29,11 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" @@ -43,6 +46,12 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights to the input tensor.""" raise NotImplementedError + + def maybe_update_loaded_weight_name(self, name: str) -> str: + """Update the name of a loaded weight to enable generic handling of + cases where serialized state_dict does not match vllm model definition. + """ + return name class UnquantizedLinearMethod(LinearMethodBase): @@ -56,17 +65,20 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, + def create_weights(self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_sizes_per_partition), input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) return {"weight": weight} + def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, @@ -83,6 +95,7 @@ class ReplicatedLinear(torch.nn.Module): """Replicated linear layer. Args: + layer_name: name of the layer in the state dict. input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. bias: If true, add bias. @@ -93,6 +106,7 @@ class ReplicatedLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -103,6 +117,7 @@ def __init__( super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add @@ -113,8 +128,8 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) + self.layer_name, self.input_size, [self.output_size], + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -139,6 +154,7 @@ class ColumnParallelLinear(torch.nn.Module): its second dimension as A = [A_1, ..., A_p]. Args: + layer_name: name of the layer in the state dict. input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias. @@ -150,10 +166,14 @@ class ColumnParallelLinear(torch.nn.Module): skip adding bias but instead return it. params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. + logical_widths: Optional list of widths for logical weight matrices. + E.g. for QKVParallelLinear, this parameter defines + the width """ def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -165,12 +185,20 @@ def __init__( super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_sizes_per_partition = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if self.output_sizes is not None: + self.output_sizes_per_partition = [ + divide(output_size, tp_size) for output_size in self.output_sizes + ] + self.skip_bias_add = skip_bias_add if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -179,8 +207,13 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) + layer_name=self.layer_name, + input_size_per_partition=self.input_size, + output_sizes_per_partition=self.output_sizes_per_partition, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + ) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -246,6 +279,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): def __init__( self, + layer_name: str, input_size: int, output_sizes: List[int], bias: bool = True, @@ -257,8 +291,15 @@ def __init__( self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + super().__init__( + layer_name=layer_name, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -266,6 +307,18 @@ def weight_loader(self, loaded_shard_id: Optional[int] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + param_shard_splitter = getattr(param, "shard_splitter", None) + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -318,6 +371,10 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -325,6 +382,7 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -340,6 +398,7 @@ class QKVParallelLinear(ColumnParallelLinear): be replicated while the query heads are partitioned. Args: + layer_name: name of the layer in the state dict. hidden_size: input hidden state size of the transformer. head_size: size of each attention head. total_num_heads: total number of attention query heads. @@ -355,6 +414,7 @@ class QKVParallelLinear(ColumnParallelLinear): def __init__( self, + layer_name: str, hidden_size: int, head_size: int, total_num_heads: int, @@ -383,8 +443,21 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -392,6 +465,18 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) if loaded_shard_id is None: # Loaded weight is already packed. @@ -427,6 +512,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -450,15 +537,19 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow( + output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -466,7 +557,11 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") - assert param_data.shape == loaded_weight.shape + + assert ( + param_data.shape == loaded_weight.shape or + (len(param_data.shape) == 0 and len(loaded_weight.shape) == 0) + ) param_data.copy_(loaded_weight) @@ -483,6 +578,7 @@ class RowParallelLinear(torch.nn.Module): | A_p | - - Arguments: + layer_name: name of the layer in the state dict. input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias. Note that bias is not parallelized. @@ -498,6 +594,7 @@ class RowParallelLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -509,6 +606,7 @@ def __init__( ): super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.input_is_parallel = input_is_parallel @@ -525,8 +623,13 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) + layer_name=self.layer_name, + input_size_per_partition=self.input_size_per_partition, + output_sizes_per_partition=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + ) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -555,6 +658,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # TODO: canon + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..7cf94ae9f44e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,17 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights( + self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size, output_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be95..868e09252bb2 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -51,14 +51,12 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: + def get_linear_method(self, name) -> LinearMethodBase: """Get the linear method to use for the quantized linear layer.""" raise NotImplementedError - @abstractmethod def get_scaled_act_names(self) -> List[str]: """Returns the activation function names that should be post-scaled. - For now, this is only used by AWQ. """ - raise NotImplementedError + return [] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..8c3492ae67d8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,13 +89,16 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del output_size, layer_name # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf..59d217567919 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,13 +91,15 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del layer_name, input_size, output_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) if params_dtype != torch.float16: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/smoothquant.py b/vllm/model_executor/layers/quantization/smoothquant.py deleted file mode 100644 index d9d82e6cfbc3..000000000000 --- a/vllm/model_executor/layers/quantization/smoothquant.py +++ /dev/null @@ -1,348 +0,0 @@ -from typing import Any, Dict, List, Tuple, Optional - -import torch -from torch._tensor import Tensor -from torch.nn.parameter import Parameter -import threading - -from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - - -class SmoothQuantConfig(QuantizationConfig): - """Config class for SmoothQuant - - Reference: https://github.com/mit-han-lab/smoothquant - """ - - def __init__(self, - weight_bits: int = 8, - quant_map: dict[str:str] = None) -> None: - self.weight_bits = weight_bits - self.quant_map = quant_map - - if self.weight_bits != 8: - raise ValueError( - "Currently, only w8a8 quantization is supported for " - f"SmoothQuant, but got {self.weight_bits} bits.") - if self.quant_map is None or self.quant_map == {}: - raise ValueError( - 'Quant_map for SmoothQuant should not be None or an empty dict. ' - 'For example, when using llama, you should set a quant_config.json in model directory, like ' - '{ "qkv": "per-tensor", "out": "per-token", "fc1": "per-tensor", "fc2": "per-token" }' - ) - - def __repr__(self) -> str: - return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, " - f"quant_map={self.quant_map})") - - def get_name(self) -> str: - return "smoothquant" - - def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half, torch.float] - - def get_min_capability(self) -> int: - # The smoothquant kernel only supports Ampere or newer GPUs. - return 80 - - @classmethod - def get_config_filenames(cls) -> List[str]: - """List of filenames to search for in the model directory.""" - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": - try: - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - except ValueError as e: - weight_bits = 8 - print(str(e) + " Set weight_bits = 8 by default.") - - quant_map = {} - for key, value in config.items(): - if value in ["per-tensor", "per-token"]: - quant_map[key] = value - return cls(weight_bits, quant_map) - - def get_linear_method(self) -> "SQLinearMethod": - return SQLinearMethod(Int8GEMM) - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class Int8GEMM(object): - _instance_lock = threading.Lock() - - def __init__(self): - if not hasattr(self, "i8cugemm"): - self.i8cugemm = ops.I8CUGEMM() - - def __new__(cls, *args, **kwargs): - if not hasattr(Int8GEMM, "_instance"): - with Int8GEMM._instance_lock: - if not hasattr(Int8GEMM, "_instance"): - Int8GEMM._instance = object.__new__(cls) - return Int8GEMM._instance - - def get_i8cugemm(self): - return self.i8cugemm - - -class SQLinearMethod(LinearMethodBase): - """Linear method for SmoothQuant. - """ - - def __init__(self, gemm): - i8_gemm = gemm() - self.i8cugemm = i8_gemm.get_i8cugemm() - - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Tensor]: - weight = Parameter( - torch.empty( - output_size_per_partition, - input_size_per_partition, - device="cuda", - dtype=torch.int8, - ), - requires_grad=False, - ) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - }) - # q k v dequant_scales are used in QKVParallelLinear - q_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - k_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - v_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - # gate up dequant_scales are used in MergedColumnParallelLinear - gate_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - up_dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - # dequant_scale is used in RowParallelLinear - dequant_scale = Parameter( - torch.tensor(1.0, dtype=torch.float32, device='cpu'), - requires_grad=False, - ) - return { - "weight": weight, - "q_dequant_scale": q_dequant_scale, - "k_dequant_scale": k_dequant_scale, - "v_dequant_scale": v_dequant_scale, - "gate_dequant_scale": gate_dequant_scale, - "up_dequant_scale": up_dequant_scale, - "dequant_scale": dequant_scale - } - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - assert bias is None - weight = weights["weight"] - x_shape = x.shape - x = x.view(-1, x_shape[-1]) - y = torch.empty((x.shape[0], weight.shape[0]), - dtype=torch.int32, - device=x.device) - self.i8cugemm.linear_a8_w8_o32_(x, weight, y) - y = y.view(*x_shape[:-1], -1) - return y - - -class SQLinearMethodQKV(SQLinearMethod): - - def __init__(self, - gemm, - qkv_sizes : Tuple[int, int, int], - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.qkv_sizes = qkv_sizes - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x): - assert x.dtype != self.quant_dtype - x_q = torch.empty_like(x, dtype=self.quant_dtype) - ops.quant(x_q, x, 1.0) - return x_q - - def dequantize(self, x_q, weights : Dict[str, Tensor]): - # split to get the quantized qkv - q_q, k_q, v_q = x_q.split(list(self.qkv_sizes), dim=-1) - - # create dequant qkv buffer and split to get the individual dequant qkv - # buffers - qkv = torch.empty_like(x_q, dtype=self.dequant_dtype) - q, k, v = qkv.split(list(self.qkv_sizes), dim=-1) - - q_scale, k_scale, v_scale = (weights['q_dequant_scale'], - weights['k_dequant_scale'], - weights['v_dequant_scale']) - ops.dequant(q, q_q, q_scale) - ops.dequant(k, k_q, k_scale) - ops.dequant(v, v_q, v_scale) - - return qkv - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - x_q = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights) - -class SQLinearMethodOProj(SQLinearMethod): - - def __init__(self, - gemm, - use_per_token_quant:bool, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.use_per_token_quant = use_per_token_quant - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # x is the paged-attention output - assert x.dtype != self.quant_dtype - act_scale = None - x_q = torch.empty_like(x, dtype=self.quant_dtype) - if self.use_per_token_quant: - act_scale = torch.empty(x.numel() // x.shape[-1], - dtype=torch.float32, - device=x.device) - ops.quant(x_q, x, act_scale) - else: - ops.quant(x_q, x, 1.0) - return x_q, act_scale - - def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: - o_dequant_scale = weights['dequant_scale'] - x = torch.empty_like( - x_q, - dtype=self.dequant_dtype, - device=x_q.device) - ops.dequant(x, x_q, act_scale, o_dequant_scale) - return x - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - pass - x_q, act_scale = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights, act_scale) - -class SQLinearMethodGateUpProj(SQLinearMethod): - - def __init__(self, - gemm, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> torch.Tensor: - # x is the attention output - assert x.dtype != self.quant_dtype - x_q = torch.empty_like(x, dtype=self.quant_dtype, device=x.device) - ops.quant(x_q, x, 1.0) - return x_q - - def dequantize(self, gate_up_q: torch.Tensor, weights : Dict[str, Tensor]) -> torch.Tensor: - - def split_gate_up(gate_up : torch.Tensor): - d = gate_up.shape[-1] - return (torch.narrow(gate_up, 1, 0, d//2), - torch.narrow(gate_up, 1, d//2, d//2)) - - # create a dequant gate_up buffer and split it into constituent parts. - gate_up = torch.empty_like(gate_up_q, - dtype=self.dequant_dtype, - device=gate_up_q.device) - - # split quantized gate_up into constituent parts. - gate_q, up_q = split_gate_up(gate_up_q) - # split output gate_up buffer into constituent parts. - gate, up = split_gate_up(gate_up) - - gate_scale, up_scale = (weights['gate_dequant_scale'], - weights['up_dequant_scale']) - ops.dequant(gate, gate_q, gate_scale) - ops.dequant(up, up_q, up_scale) - - return gate_up - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> Tensor: - x_q = self.quantize(x) - gate_up_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(gate_up_q, weights) - -class SQLinearMethodDownProj(SQLinearMethod): - - def __init__(self, - gemm, - quant_dtype : torch.dtype = torch.int8, - dequant_dtype : torch.dtype = torch.float): - super().__init__(gemm) - self.quant_dtype = quant_dtype - self.dequant_dtype = dequant_dtype - - def quantize(self, x) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dtype != self.quant_dtype - # TODO (varun) : This is per-token quant - Read from config - x_q = torch.empty_like(x, dtype=self.quant_dtype) - scale = torch.empty(x.numel() // x.shape[-1], - dtype=torch.float32, - device=x.device) - ops.quant(x_q, x, scale) - return x_q, scale - - def dequantize(self, x_q, weights : Dict[str, Tensor], act_scale : torch.Tensor) -> torch.Tensor: - down_dequant_scale = weights['dequant_scale'] - x = torch.empty_like( - x_q, - dtype=self.dequant_dtype, - device=x_q.device) - ops.dequant(x, x_q, act_scale, down_dequant_scale) - return x - - def apply_weights(self, - weights: Dict[str, Tensor], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - x_q, act_scale = self.quantize(x) - y_q = super().apply_weights(weights, x_q, bias) - return self.dequantize(y_q, weights, act_scale) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/smoothquant/__init__.py b/vllm/model_executor/layers/quantization/smoothquant/__init__.py new file mode 100644 index 000000000000..2f62cee49d95 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/__init__.py @@ -0,0 +1,14 @@ +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat +) + +from vllm.model_executor.layers.quantization.smoothquant.config import ( + SmoothQuantConfig, + SmoothQuantLinearMethod +) + +__all__ = [ + "SmoothQuantFormat", + "SmoothQuantConfig", + "SmoothQuantLinearMethod", +] diff --git a/vllm/model_executor/layers/quantization/smoothquant/config.py b/vllm/model_executor/layers/quantization/smoothquant/config.py new file mode 100644 index 000000000000..885ffce3e36d --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/config.py @@ -0,0 +1,306 @@ +from typing import Any, Dict, List, Tuple, Type, Optional, Union +import threading + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat, + SmoothQuantDynamicPerToken, + SmoothQuantStaticPerTensor, +) + +LAYER_KEYS = ["qkv", "out", "fc1", "fc2"] +FORMAT_REGISTRY = { + "per-token": SmoothQuantDynamicPerToken, + "per-tensor": SmoothQuantStaticPerTensor, +} + +def get_sq_format_cls(format_key: str) -> Type[SmoothQuantFormat]: + if format_key not in FORMAT_REGISTRY: + raise ValueError(f"Invalid smoothquant format: {format_key}") + return FORMAT_REGISTRY[format_key] + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant. + + Reference: https://github.com/mit-han-lab/smoothquant + """ + def __init__(self, + layer_format_map: Dict[str, str]) -> None: + self.layer_format_map = layer_format_map + + for key, format in self.layer_format_map.items(): + if key not in LAYER_KEYS: + raise ValueError( + f"Found key of {key} in {self.layer_format_map}, " + f"but key must be one of {LAYER_KEYS}" + ) + if format not in FORMAT_REGISTRY: + raise ValueError( + f"Found format of {format} in {self.layer_format_map}, " + f"but format must be one of {FORMAT_REGISTRY}" + ) + for key in LAYER_KEYS: + if key not in self.layer_format_map: + raise ValueError( + f"Could not find {key} in {layer_format_map}" + ) + + def __repr__(self) -> str: + return (f"SmoothQuantConfig(layer_format_map={self.layer_format_map})") + + def get_name(self) -> str: + return "smoothquant" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + # TODO: check if we support fp16 / bf16 as well. + return [torch.float] + + def get_min_capability(self) -> int: + # TODO: check if this is right. + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + layer_format_map: Dict[str, str] = {} + for layer_key, format in config.items(): + if format in FORMAT_REGISTRY: + layer_format_map[layer_key] = format + return cls(layer_format_map) + + def get_linear_method(self) -> "SmoothQuantLinearMethod": + return SmoothQuantLinearMethod(self) + + +# TODO: why is this needed? +class Int8GEMM(object): + _instance_lock = threading.Lock() + + def __init__(self): + if not hasattr(self, "i8cugemm"): + self.i8cugemm = ops.I8CUGEMM() + + def __new__(cls, *args, **kwargs): + if not hasattr(Int8GEMM, "_instance"): + with Int8GEMM._instance_lock: + if not hasattr(Int8GEMM, "_instance"): + Int8GEMM._instance = object.__new__(cls) + return Int8GEMM._instance + + def get_i8cugemm(self): + return self.i8cugemm + + +class SmoothQuantLinearMethod(LinearMethodBase): + def __init__(self, sq_config: SmoothQuantConfig) -> None: + self.sq_config = sq_config + self.sq_type = None + self.i8cugemm = Int8GEMM().get_i8cugemm() + + def maybe_update_loaded_weight_name(self, + name: str) -> str: + """Convert serialized name k_dequant_scale to dequant_scale. + + This function is called by model_cls.load_weights() during the weight + loading process to match on disk state dict to vllm state dict. + """ + if "dequant_scale" in name: + suffix = name.split('.')[-1] + name.replace(suffix, "dequant_scale") + return name + + def scales_shard_splitter(self, + param: torch.Tensor, + loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + """Index into param for for loading. + + This function is called by QKVColumnLinear and MergedColumnParallelLinear + during weight loading to put the scales from disk in the right spot. + """ + if type(shard_id) == str: + qkv_idxs = { "q": 0, "k": 1, "v": 2 } + if shard_id not in qkv_idxs: + raise ValueError(f"Invalid shard_id {shard_id}") + shard_id = qkv_idxs[shard_id] + elif type(shard_id) != int: + raise ValueError(f"Invalid shard id {shard_id}") + + return param[shard_id], loaded_weight + + def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: + """ + Gets the SmoothQuantFormat for a specific layer. + + SmoothQuantLinearMethod uses SmoothQuantType to support non-uniform quantization + (where each layer has a different format). To determine the SmoothQuantFormat + for a layer, we match the layer_name to the layer_keys=["qkv","out","fc1","fc2"] + and use layer_format_map to to determine the SQFormat. + + Args: + layer_name: Name of the layer we are creating the LinearMethod for. + Returns + sq_linear_method: SmoothQuantLinearMethod with the right SQFormat. + """ + # Note: AutoSmoothQuant Serialization is not very good yet. + # + # It looks like the following (which does not map to layer names in the model): + # { + # "qkv": "per-tensor", + # "out": "per-token", + # "fc1": "per-tensor", + # "fc2": "per-token" + # } + # + # So, this is a hack for llama now. But with the SparseMLConfig, we can make robust, + # where we actually use the layer_name in the model to look up what the format is + # based on the config. + # + # What it would actually look like: + # layer_config is None + # for supported_key in SUPPORTED_LAYER_KEYS: + # if supported_key in layer_name: + # sq_format = self.layer_mapping[lookup_key] + # return get_sq_format_cls(sq_format)() + + HACKED_REMAP_FOR_LLAMA = { + "qkv": "qkv", + "o_proj": "out", + "gate_up": + "fc1", "down": "fc2", + } + + for match_key, lookup_key in HACKED_REMAP_FOR_LLAMA.items(): + if match_key in layer_name: + sq_format = self.sq_config.layer_format_map[lookup_key] + return get_sq_format_cls(sq_format)() + + raise ValueError + + def create_weights(self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + del input_size, output_size + + # Statically Quantized Weights. + weight = Parameter( + torch.empty( + sum(output_sizes_per_partition), + input_size_per_partition, + device="cuda", dtype=torch.int8, + ), requires_grad=False, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + + # Static scale for each logical weight (e.g. 3 for QKV). + dequant_scale = Parameter( + torch.empty( + len(output_sizes_per_partition), + device='cuda', dtype=params_dtype, + ), requires_grad=False + ) + set_weight_attrs(dequant_scale, { + "shard_splitter": self.scales_shard_splitter, + }) + + return { + "weight": weight, + "dequant_scale": dequant_scale, + "logical_widths": output_sizes_per_partition, + "sq_format": self.get_layer_format(layer_name) + } + + def _quantize(self, + x: torch.Tensor, + sq_format: SmoothQuantFormat) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Quantize activations. + + Args: + x: Activation at floating point precision. + Returns: + x_q: Quantized activation at INT8 + activation_scales: Optional dynamic scales for each token. + """ + x_q = torch.empty_like(x, dtype=torch.int8) + x_q, activation_scales = sq_format.quantize_op(x, x_q) + return x_q, activation_scales + + def _dequantize(self, + x_q: torch.Tensor, + dynamic_scales: Optional[torch.Tensor], + static_scales: torch.Tensor, + logical_widths: List[int], + dtype: torch.dtype, + sq_format: SmoothQuantFormat) -> torch.Tensor: + """Dequantize activations. + + Args: + x_q: quantized activations. + dynamic_scales: Optional dynamic scales. + static_scales: Static dequantization scales. + logical_widths: Width of each logical activation (for QKV case). + dtype: Datatype to dequantize to. + Returns: + x_dq: dequantized activation at output_dtype precision + """ + # Split X_q and X_dq buffer into logical activations (for QKV case). + x_q_split = x_q.split(logical_widths, dim=-1) + x_dq = torch.empty_like(x_q, dtype=dtype) + x_dq_split = x_dq.split(logical_widths, dim=-1) + # Dequantize in place and return. + sq_format.dequantize_op(x_q_split, x_dq_split, dynamic_scales, static_scales) + return x_dq + + + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward method. Computes Q --> GEMM --> DQ. + + Args: + weigths: Dictionary of weights, scales, and metadata. + x: Input in floating point precision. + bias: Optional bias. + Returns: + a_dq: Dequantized activation at floating point precision. + """ + if bias is not None: + raise NotImplementedError + weight_q = weights["weight"] + static_scales = weights["dequant_scale"] + logical_widths = weights["logical_widths"] + sq_format = weights["sq_format"] + + # Q + x_q, activation_scales = self._quantize(x, sq_format) + + # GEMM + x_q = x_q.view(-1, x_q.shape[-1]) + a_q = torch.empty((x_q.shape[0], weight_q.shape[0]), dtype=torch.int32, device="cuda") + self.i8cugemm.linear_a8_w8_o32_(x_q, weight_q, a_q) + a_q = a_q.view(*x_q.shape[:-1], -1) + + # DQ + return self._dequantize(a_q, activation_scales, static_scales, logical_widths, x.dtype, sq_format) diff --git a/vllm/model_executor/layers/quantization/smoothquant/formats.py b/vllm/model_executor/layers/quantization/smoothquant/formats.py new file mode 100644 index 000000000000..b8ddd642c888 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/formats.py @@ -0,0 +1,100 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type + +import torch + +from vllm._C import ops + + +class SmoothQuantFormat(ABC): + @abstractmethod + def dequantize_op(self, + x_qs: List[torch.Tensor], + x_dqs: List[torch.Tensor], + dynamic_scales: Optional[torch.Tensor], + static_scales: torch.Tensor) -> None: + """Dequantize the activations. x_dq is updated in place. + + Args: + x_qs: List of N quantized activations. + x_dqs: List of N buffers to fill with dequantized values. + dynamic_scales: Optional dynamic scales for dequantization. + static_scales: Static scales for dequantization. N values. + """ + raise NotImplementedError + + + @abstractmethod + def quantize_op(self, + x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Quantize the input and (optionally compute dequant scales). + + Args: + x: Input data in floating point format. + x_q: Buffer for quantized inputs. + Returns: + x_q: Quantized input. + activation_scales: Optional dynamic scales for the activations. + """ + raise NotImplementedError + + +class SmoothQuantDynamicPerToken(SmoothQuantFormat): + def dequantize_op(self, + x_qs: List[torch.Tensor], + x_dqs: List[torch.Tensor], + dynamic_scales: Optional[torch.Tensor], + static_scales: torch.Tensor) -> None: + """Notes: + dynamic_scales: N scales for N tokens in the activation. + static_scales: K scales for K logical activations (equals just w_scale). + """ + if dynamic_scales is None: + raise ValueError + + # Dequantize each logical activation. + # TODO: test this for case when logical_widths > 1 (may need to reshape) + for x_dq, x_q, dynamic_scale, static_scale in zip( + x_dqs, x_qs, dynamic_scales, static_scales): + + # Dequantize (updates x_dq in place). + ops.dequant(x_dq, x_q, dynamic_scale, static_scale) + + + def quantize_op(self, + x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Notes: + Returns quantized activaiton and dynamic activation scales. + """ + activation_scales = torch.empty(x.numel() // x.shape[-1], dtype=x.dtype, device=x.device) + ops.quant(x_q, x, activation_scales) + return x_q, activation_scales + + +class SmoothQuantStaticPerTensor(SmoothQuantFormat): + def dequantize_op(self, + x_qs: List[torch.Tensor], + x_dqs: List[torch.Tensor], + dynamic_scales: Optional[torch.Tensor], + static_scales: torch.Tensor) -> None: + """Notes: + dynamic_scales: None + static_scales: K scales for K logical activations (equals w_scale * a_scale). + """ + if dynamic_scales is not None: + raise ValueError + + # Dequantize each logical activation. + for xdq, xq, static_scale in zip(x_dqs, x_qs, static_scales): + ops.dequant(xdq, xq, static_scale) + + def quantize_op(self, + x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Notes: + Returns quantized activaiton and no dynamic scales. + """ + ops.quant(x_q, x, 1.0) + return x_q, None diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1..893e6781089d 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,17 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights( + self, + layer_name: str, + input_size_per_partition: int, + output_sizes_per_partition: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index b191dc4009b5..de7910c4860b 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -46,10 +46,6 @@ def _get_model_architecture( def get_architecture_class_name(model_config: ModelConfig) -> str: return _get_model_architecture(model_config)[1] -def _is_support_smoothquant(model_config: ModelConfig) -> bool: - architectures = getattr(model_config.hf_config, "architectures", []) - supported_archs = ModelRegistry.get_supported_smoothquant_archs() - return any(arch in supported_archs for arch in architectures) def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: @@ -82,10 +78,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device(device_config.device): - if _is_support_smoothquant(model_config): - model = model_class(model_config.hf_config, linear_method, - quant_config) - elif hasattr(model_class, "supported_lora_modules"): + if hasattr(model_class, "supported_lora_modules"): model = model_class(model_config.hf_config, linear_method, lora_config) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1fffbc5fa30c..0b6c75705764 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,21 +29,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.layers.quantization.smoothquant import ( - Int8GEMM, - SQLinearMethod, - SQLinearMethodQKV, - SQLinearMethodOProj, - SQLinearMethodGateUpProj, - SQLinearMethodDownProj) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) - -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -58,6 +49,7 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip + class LlamaMLP(nn.Module): def __init__( @@ -65,43 +57,25 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + parent_name: str, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.hidden_size = hidden_size - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - - gate_up_linear_method = linear_method - if self.use_int8: - # override gate_up linear method - assert isinstance(linear_method, SQLinearMethod) - gate_up_linear_method = SQLinearMethodGateUpProj( - gemm=Int8GEMM, - quant_dtype=torch.int8, - dequant_dtype=torch.float) self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=False, - linear_method=gate_up_linear_method) + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") - - down_proj_linear_method = linear_method - if self.use_int8: - # override gate_up linear method - assert isinstance(linear_method, SQLinearMethod) - down_proj_linear_method = SQLinearMethodDownProj( - gemm=Int8GEMM, - quant_dtype=torch.int8, - dequant_dtype=torch.float) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=down_proj_linear_method) - self.act_fn = SiluAndMul() def forward(self, x): @@ -110,6 +84,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class LlamaAttention(nn.Module): def __init__( @@ -117,42 +92,36 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, + parent_name: str, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: super().__init__() self.hidden_size = hidden_size - self.tp_size = get_tensor_model_parallel_world_size() + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % self.tp_size == 0 - self.num_heads = self.total_num_heads // self.tp_size + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - self.default_dtype = torch.get_default_dtype() - - if self.total_num_kv_heads >= self.tp_size: + if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % self.tp_size == 0 + assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert self.tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - # Needs to be ironed out!! - self.use_per_token_quant = self.use_int8 # This will be overwritten by model initialization if we are using it. # N.B. currently we only support per tensor scalar scaling factors @@ -163,55 +132,35 @@ def __init__( # scaling_factor = tensor_amax / FPtype_max self.kv_scale = 1.0 - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - - qkv_linear_method = linear_method - if self.use_int8: - # override qkv linear method - assert isinstance(linear_method, SQLinearMethod) - qkv_linear_method = SQLinearMethodQKV( - gemm=Int8GEMM, - qkv_sizes=(self.q_size, self.kv_size, self.kv_size), - quant_dtype=torch.int8, - dequant_dtype=self.rotary_emb.cos_sin_cache.dtype) self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, - linear_method=qkv_linear_method, + linear_method=linear_method, ) - - o_proj_linear_method = linear_method - if self.use_int8: - # override o_proj linear method - assert isinstance(linear_method, SQLinearMethod) - o_proj_linear_method = SQLinearMethodOProj( - gemm=Int8GEMM, - use_per_token_quant=True, # TODO (varun) : Read from config - quant_dtype = torch.int8, - dequant_dtype= torch.float) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, - linear_method=o_proj_linear_method, + linear_method=linear_method, ) - self.attn = Attention( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, @@ -228,19 +177,17 @@ def forward( output, _ = self.o_proj(attn_output) return output + class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + parent_name: str, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_int8 = quant_config is not None and quant_config.get_name( - ) == "smoothquant" - self.tp_size = get_tensor_model_parallel_world_size() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -254,8 +201,8 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + parent_name=f"{parent_name}.self_attn", linear_method=linear_method, - quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) @@ -263,8 +210,8 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + parent_name=f"{parent_name}.mlp", linear_method=linear_method, - quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -279,7 +226,6 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention if residual is None: residual = hidden_states @@ -287,7 +233,6 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -296,8 +241,8 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, - residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -308,8 +253,7 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, + lora_config: Optional[LoRAConfig] = None ) -> None: super().__init__() self.config = config @@ -324,8 +268,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, + parent_name=f"model.layers.{idx}", + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -390,15 +336,12 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, - quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.quant_config = quant_config - self.model = LlamaModel(config, linear_method, lora_config=lora_config, - quant_config = quant_config) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -447,9 +390,6 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - # For SmoothQuant - int8_fusion = self.quant_config is not None and \ - self.quant_config.get_name() == "smoothquant" stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -459,8 +399,13 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -468,26 +413,6 @@ def load_weights(self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - # bias is useless for llama - if "bias" in name: - continue - # load dequant scale for qkv_proj and gate_up_proj - if int8_fusion: - is_fusion_scale = False - if "scale" in name: - for (param_name, weight_name, _) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - prefix = weight_name.split('_')[0] - suffix = name.split('.')[-1] - new_name = prefix + '_' + suffix - param = params_dict[name.replace(suffix, new_name)] - param.copy_(loaded_weight) - is_fusion_scale = True - break - if is_fusion_scale: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue From c109d5f05abee4cb8f75b476847402d97066e581 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 11 Apr 2024 18:52:09 +0000 Subject: [PATCH 5/8] refactor experiments.sh and profiling fixes --- examples/offline_profile.py | 60 +++++-------- experiments.sh | 173 +++++++++++------------------------- 2 files changed, 73 insertions(+), 160 deletions(-) diff --git a/examples/offline_profile.py b/examples/offline_profile.py index da0b700909d2..ab728e01ce54 100644 --- a/examples/offline_profile.py +++ b/examples/offline_profile.py @@ -24,11 +24,15 @@ class ProfileContext: max_num_batched_tokens: int prompt_len: int batch_size: int + dtype: str tensor_parallel_size: int - kv_cache_dtype: str - kv_quant_params_path: str allow_cuda_graphs: bool +def get_dtype(dtype:str): + if dtype == "torch.float": + return torch.float + else: + return dtype def run_profile(context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]): @@ -37,35 +41,21 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], print(f" {key} = {value}") # Create sampling params - sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8) + sampling_params = SamplingParams(temperature=0.0, top_k=1, max_tokens=8) # Sparsity is in the future # Create LLM - llm = None - if context.kv_quant_params_path is not None: - llm = LLM( - model=context.model, - tokenizer=context.tokenizer if context.tokenizer is not None else context.model, - revision=context.model_revision, - enforce_eager=not context.allow_cuda_graphs, - tensor_parallel_size=context.tensor_parallel_size, - gpu_memory_utilization=0.9, - max_model_len=context.max_seq_len, - quantization=context.quantization, - max_num_batched_tokens=context.max_num_batched_tokens, - kv_cache_dtype=context.kv_cache_dtype, - kv_quant_params_path=context.kv_quant_params_path) - else: - llm = LLM( - model=context.model, - tokenizer=context.tokenizer if context.tokenizer is not None else context.model, - revision=context.model_revision, - enforce_eager=not context.allow_cuda_graphs, - tensor_parallel_size=context.tensor_parallel_size, - gpu_memory_utilization=0.9, - max_model_len=context.max_seq_len, - quantization=context.quantization, - max_num_batched_tokens=context.max_num_batched_tokens) + llm = LLM( + model=context.model, + tokenizer=context.tokenizer if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + dtype=get_dtype(context.dtype), + max_num_batched_tokens=context.max_num_batched_tokens) batch_size = context.batch_size prompt_len = context.prompt_len @@ -208,18 +198,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], help="The method used to quantize the model weights, " "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"") parser.add_argument( - "--kv-cache-dtype", - type=str, - choices=['auto', 'fp8_e5m2', 'int8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type.') - parser.add_argument( - "--kv-quant-params-path", + "--dtype", type=str, - default=None, - help='Path to scales and zero points of kv cache quantizaiton ' - 'when kv cache dtype is int8.') + default='auto', + help="model dtype") parser.add_argument( "--max-seq-len", type=int, diff --git a/experiments.sh b/experiments.sh index efdd85d523f4..a45229d31499 100755 --- a/experiments.sh +++ b/experiments.sh @@ -2,11 +2,14 @@ set -e set -u -set -x +#set -x # global args -model_path=/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b -tokenizer=/home/varun/code/vllm/llama-models/Nous-Hermes-Llama2-13b/ +model_path=NousResearch/Nous-Hermes-Llama2-13b +quant_model_path=nm-testing/Nous-Hermes-Llama2-13b-smoothquant + +# model generation args +enforce_eager=True max_seq_len=2048 max_num_batched_tokens=7000 tensor_parallel_size=1 @@ -15,47 +18,46 @@ tensor_parallel_size=1 prefill_prompt_len=512 decode_batch_sizes=(1 2 8 16 32 64 128) -# quantization specific args -quant_model_path=$model_path/quantized_model/llama-13b/Nous-Hermes-Llama2-13b-smoothquant/ - -# kv quant specific args -kv_cache_dtype=int8 -kv_quant_params_path=/home/varun/code/vllm/act_quant_data/exported_kv/ - -run_quantized_prefill() { - output_directory=$1 +run_prefill() { + model=$1 + desc=$2 + dtype=$3 + output_directory=$4 now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_quantized-${now} + out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_${desc}-${now} - echo "Running prefill ${prefill_prompt_len} store at ${out_base}" - python3 examples/offline_profile.py --model $quant_model_path \ - --tokenizer $tokenizer \ + echo "Running prefill ${prefill_prompt_len} model ${model} desc ${desc} dtype ${dtype} store at ${out_base}" + python3 examples/offline_profile.py --model $model \ --batch-size 1 \ --prompt-len $prefill_prompt_len \ --quantization smoothquant \ --max-seq-len $max_seq_len \ + --dtype $dtype \ --max-num-batched-tokens $max_num_batched_tokens \ --tensor-parallel-size $tensor_parallel_size \ --json $out_base \ --csv $out_base > ${out_base}_stdout.txt 2>&1 } -run_quantized_decode() { - output_directory=$1 +run_decode() { + model=$1 + desc=$2 + dtype=$3 + output_directory=$4 for bs in "${decode_batch_sizes[@]}" do now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/decode_bs_${bs}_llama13_quantized-${now} + out_base=${output_directory}/decode_bs_${bs}_llama13_${desc}-${now} - echo "Running decode bs ${bs} store at ${out_base}" - python3 examples/offline_profile.py --model $quant_model_path \ - --tokenizer $tokenizer \ + echo "Running decode bs ${bs} model ${model} desc ${desc} dtype ${dtype} store at ${out_base}" + python3 examples/offline_profile.py --model $model \ --batch-size $bs \ --prompt-len 1 \ --quantization smoothquant \ --max-seq-len $max_seq_len \ + --dtype $dtype --max-num-batched-tokens $max_num_batched_tokens \ --tensor-parallel-size $tensor_parallel_size \ --json $out_base \ @@ -63,90 +65,6 @@ run_quantized_decode() { done } -run_kv_quant_prefill() { - - output_directory=$1 - now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_kv_quant-${now} - - echo "Running prefill ${prefill_prompt_len} store at ${out_base}" - - python3 examples/offline_profile.py --model $model_path \ - --tokenizer $tokenizer \ - --batch-size 1 \ - --prompt-len $prefill_prompt_len \ - --kv-cache-dtype $kv_cache_dtype \ - --kv-quant-params-path $kv_quant_params_path \ - --max-seq-len $max_seq_len \ - --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size $tensor_parallel_size \ - --json $out_base \ - --csv $out_base > ${out_base}_stdout.txt 2>&1 -} - -run_kv_quant_decode() { - output_directory=$1 - - for bs in "${decode_batch_sizes[@]}" - do - now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/decode_bs_${bs}_llama13_kv_quant-${now} - - echo "Running decode bs ${bs} store at ${out_base}" - python3 examples/offline_profile.py --model $model_path \ - --tokenizer $tokenizer \ - --batch-size $bs \ - --prompt-len 1 \ - --kv-cache-dtype $kv_cache_dtype \ - --kv-quant-params-path $kv_quant_params_path \ - --max-seq-len $max_seq_len \ - --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size $tensor_parallel_size \ - --json $out_base \ - --csv $out_base > ${out_base}_stdout.txt 2>&1 - done -} - -run_fp16_prefill() { - - output_directory=$1 - now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_fp16-${now} - - echo "Running prefill ${prefill_prompt_len} store at ${out_base}" - - python3 examples/offline_profile.py --model $model_path \ - --tokenizer $tokenizer \ - --batch-size 1 \ - --prompt-len $prefill_prompt_len \ - --max-seq-len $max_seq_len \ - --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size $tensor_parallel_size \ - --json $out_base \ - --csv $out_base > ${out_base}_stdout.txt 2>&1 -} - -run_fp16_decode() { - output_directory=$1 - - for bs in "${decode_batch_sizes[@]}" - do - now=`date +"%Y-%m-%d-%I-%M-%S"` - out_base=${output_directory}/decode_bs_${bs}_llama13_fp16-${now} - - echo "Running decode bs ${bs} store at ${out_base}" - python3 examples/offline_profile.py --model $model_path \ - --tokenizer $tokenizer \ - --batch-size $bs \ - --prompt-len 1 \ - --max-seq-len $max_seq_len \ - --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size $tensor_parallel_size \ - --json $out_base \ - --csv $out_base > ${out_base}_stdout.txt 2>&1 - done -} - ## Arg parser and invocation usage() { @@ -155,17 +73,21 @@ usage() { echo echo "usage: ${0} " echo - echo " -t - pass in w8a8 or kv_quant" + echo " -t - pass in quant or base" + echo " -d - description" + echo " -m - model data type" echo " -n - pass in num_benchmark_iterations" echo " -o - out directory" echo } -exp_type="" # should be either w8a8 or kv_quant +exp_type="" # should be either quant or base num_benchmark_iterations=1 output_directory="./" +desc="" +dtype="" -while getopts ':t:n:o:h:' OPT; do +while getopts ':t:n:o:d:m:h:' OPT; do case "${OPT}" in t) exp_type="${OPTARG}" @@ -176,6 +98,12 @@ while getopts ':t:n:o:h:' OPT; do o) output_directory="${OPTARG}" ;; + d) + desc="${OPTARG}" + ;; + m) + dtype="${OPTARG}" + ;; h) usage exit 1 @@ -183,29 +111,32 @@ while getopts ':t:n:o:h:' OPT; do esac done -if [ "$exp_type" != "w8a8" -a "$exp_type" != "kv_quant" -a "$exp_type" != "fp16" ]; +if [ "$exp_type" != "quant" -a "$exp_type" != "base" ]; then echo "Invalid arg $exp_type" usage exit 1 fi +if [[ "${output_directory}" == "" || "${desc}" == "" || "${dtype}" == "" ]]; +then + echo "Either output_directory or desc is not set" + usage + exit 1 +fi + + for i in $(seq 1 $num_benchmark_iterations); do echo "Running benchmark iteration ${i} ..." - if [[ "${exp_type}" == "w8a8" ]]; - then - run_quantized_prefill $output_directory - run_quantized_decode $output_directory - fi - if [[ "${exp_type}" == "kv_quant" ]]; + if [[ "${exp_type}" == "quant" ]]; then - run_kv_quant_prefill $output_directory - run_kv_quant_decode $output_directory + run_prefill $quant_model_path "${exp_type}-${desc}" $dtype $output_directory + #run_decode $quant_model_path "${exp_type}-${desc}" $dtype $output_directory fi - if [[ "${exp_type}" == "fp16" ]]; + if [[ "${exp_type}" == "base" ]]; then - run_fp16_prefill $output_directory - run_fp16_decode $output_directory + run_prefill $model_path "${exp_type}-${desc}" $dtype $output_directory + #run_decode $model_path "${exp_type}-${desc}" $dtype $output_directory fi done From cb7b2e107d0ddc8792d270e626ae1debd821c4a9 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Mon, 22 Apr 2024 17:25:36 -0400 Subject: [PATCH 6/8] [3/N] rs/vllm quantization - Adding models to non-uniform refactor (#189) Since we changes the `LinearMethod` interface to require `layer_name`, we need to update each model.py to plumb this information through the models. We need to do this, because we need to pass the `layer_name` to `LinearMethodBase.create_weights`, such that we have have non-uniform quantization / compression (as we need to be able to consult the quantization config to determine what the weights / format should look like and we use the layer name to decide this So far, have updated: - `llama` - `gemma` - `phi-2` - `gpt2` - `starcoder2` - `qwen2` - `deepseek` and `deepseekMoE` - `baichuan` To test: ```bash python3 examples/simple_test.py --help ``` To Update: - Pass `layer_name` to `QKVParallelLinear`, `MergedColumnParallelLinear`, `ColumnParallelLinear`, `RowParallelLinear` by plumbing `parent_name` through from `Model` --> `DecoderLayer` --> `MLP` / `SelfAttention` --> `Layer` - Updated `weight_loader` with `linear_method.maybe_update_name` --- examples/simple_test.py | 24 +++++++-- vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/models/baichuan.py | 45 ++++++++++++----- vllm/model_executor/models/deepseek.py | 58 +++++++++++++++------- vllm/model_executor/models/gemma.py | 43 +++++++++++----- vllm/model_executor/models/gpt2.py | 42 +++++++++++----- vllm/model_executor/models/llama.py | 18 +++---- vllm/model_executor/models/phi.py | 43 +++++++++++----- vllm/model_executor/models/qwen2.py | 63 +++++++++++++++--------- vllm/model_executor/models/starcoder2.py | 43 ++++++++++------ 10 files changed, 260 insertions(+), 121 deletions(-) diff --git a/examples/simple_test.py b/examples/simple_test.py index dcf8b8c7ed1e..81cb82e928e3 100644 --- a/examples/simple_test.py +++ b/examples/simple_test.py @@ -6,6 +6,25 @@ "tinyllama-marlin": "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "tinyllama-gptq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "tinyllama-awq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", + "gemma-fp16": "google/gemma-1.1-2b-it", + "gemma-awq": "TechxGenus/gemma-1.1-2b-it-AWQ", + "gemma-gptq": "TechxGenus/gemma-1.1-2b-it-GPTQ", + "phi-2-fp16": "abacaj/phi-2-super", + "phi-2-marlin": "neuralmagic/phi-2-super-marlin", + "deepseek-fp16": "deepseek-ai/deepseek-coder-1.3b-instruct", + "deepseek-gptq": "TheBloke/deepseek-coder-1.3b-instruct-GPTQ", + "deepseek-awq": "TheBloke/deepseek-coder-1.3b-instruct-AWQ", + "deepseek-moe-fp16": "deepseek-ai/deepseek-moe-16b-chat", + "baichuan-fp16": "baichuan-inc/Baichuan2-7B-Chat", + "baichuan-gptq": "csdc-atl/Baichuan2-7B-Chat-GPTQ-Int4", + "qwen-fp16": "Qwen/Qwen1.5-1.8B", + "qwen-gptq": "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4", + "qwen-awq": "Qwen/Qwen1.5-1.8B-Chat-AWQ", + "gpt2-fp16": "openai-community/gpt2", + "gpt2-gptq": "etyacke/GPT2-GPTQ-int4", + "starcoder2-fp16": "bigcode/starcoder2-3b", + "starcoder2-gptq": "TechxGenus/starcoder2-3b-GPTQ", + "starcoder2-awq": "TechxGenus/starcoder2-3b-AWQ", } parser = argparse.ArgumentParser() @@ -21,14 +40,11 @@ print(f"Using model_id = {model_id}") messages=[{ - "role": "system", - "content": "You are a helpful assistant." -}, { "role": "user", "content": "What is deep learning?" }] -model = LLM(model_id, enforce_eager=True, max_model_len=2048, tensor_parallel_size=args.tensor_parallel_size, dtype="float16") +model = LLM(model_id, enforce_eager=True, max_model_len=1024, tensor_parallel_size=args.tensor_parallel_size, dtype="float16", trust_remote_code=True) prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) out = model.generate(prompt, SamplingParams(max_tokens=50)) print(f"\n-----prompt\n{prompt}") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 2598156bbed3..20891acbbfb5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -194,7 +194,7 @@ def __init__( self.output_size_per_partition = divide(self.output_size, tp_size) self.output_sizes_per_partition = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. - if self.output_sizes is not None: + if hasattr(self, "output_sizes"): self.output_sizes_per_partition = [ divide(output_size, tp_size) for output_size in self.output_sizes ] diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a697..c437f28d1fcf 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -75,6 +75,7 @@ class BaiChuanMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -82,13 +83,17 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -106,6 +111,7 @@ class BaiChuanAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, position_embedding: str, @@ -128,16 +134,18 @@ def __init__( # pylint: disable=invalid-name self.W_pack = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_heads, + layer_name=f"{parent_name}.W_pack", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -183,6 +191,7 @@ def forward( class BaiChuanDecoderLayer(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, position_embedding: str, linear_method: Optional[LinearMethodBase] = None): @@ -192,6 +201,7 @@ def __init__(self, max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = BaiChuanAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, position_embedding=position_embedding, @@ -200,6 +210,7 @@ def __init__(self, linear_method=linear_method, ) self.mlp = BaiChuanMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -255,8 +266,12 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) - for _ in range(config.num_hidden_layers) + BaiChuanDecoderLayer( + parent_name=f"model.layers.{idx}", + config=config, + position_embedding=position_embedding, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -353,6 +368,10 @@ def load_weights(self, params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4eba..973f70803b87 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -55,6 +55,7 @@ class DeepseekMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -63,14 +64,18 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method, - reduce_results=reduce_results) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + linear_method=linear_method, + reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -87,6 +92,7 @@ class DeepseekMoE(nn.Module): def __init__( self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ): @@ -102,7 +108,8 @@ def __init__( f"the number of experts {self.n_routed_experts}.") self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, + DeepseekMLP(parent_name=f"{parent_name}.experts", + hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, @@ -111,8 +118,9 @@ def __init__( ]) self.pack_params() - self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, + self.gate = ReplicatedLinear(layer_name=f"{parent_name}.gate", + input_size=config.hidden_size, + output_size=self.n_routed_experts, bias=False, linear_method=None) @@ -120,6 +128,7 @@ def __init__( intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) self.shared_experts = DeepseekMLP( + parent_name=f"{parent_name}.shared_experts", hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, @@ -173,6 +182,7 @@ class DeepseekAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -205,17 +215,19 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -251,6 +263,7 @@ class DeepseekDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: PretrainedConfig, layer_idx: int, linear_method: Optional[LinearMethodBase] = None, @@ -262,6 +275,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -273,9 +287,11 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + self.mlp = DeepseekMoE(parent_name=f"{parent_name}.mlp", + config=config, linear_method=linear_method) else: self.mlp = DeepseekMLP( + parent_name=f"{parent_name}.mlp", hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -331,10 +347,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, + DeepseekDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, layer_idx=idx, linear_method=linear_method) - for layer_idx in range(config.num_hidden_layers) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -417,6 +433,10 @@ def load_weights(self, load_format, revision, fall_back_to_pt=False): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3..9476df293f84 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -75,6 +75,7 @@ class GemmaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: Optional[str] = None, @@ -83,13 +84,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -102,6 +107,7 @@ def forward(self, x): class GemmaAttention(nn.Module): def __init__(self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -132,16 +138,18 @@ def __init__(self, self.rope_theta = rope_theta self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -177,12 +185,14 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: GemmaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size self.self_attn = GemmaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -192,6 +202,7 @@ def __init__( linear_method=linear_method, ) self.mlp = GemmaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -247,8 +258,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + GemmaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -364,6 +377,10 @@ def load_weights(self, loaded_params = set() for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be..48963eab4868 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -45,6 +45,7 @@ class GPT2Attention(nn.Module): def __init__( self, + parent_name: str, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, ): @@ -59,15 +60,17 @@ def __init__( self.scale = self.head_dim**-0.5 self.c_attn = QKVParallelLinear( - self.hidden_size, - self.head_dim, - total_num_heads, + layer_name=f"{parent_name}.c_attn", + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=total_num_heads, bias=True, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - self.hidden_size, - self.hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=self.hidden_size, + output_size=self.hidden_size, bias=True, linear_method=linear_method, ) @@ -90,6 +93,7 @@ class GPT2MLP(nn.Module): def __init__( self, + parent_name: str, intermediate_size: int, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, @@ -97,14 +101,16 @@ def __init__( super().__init__() hidden_size = config.hidden_size self.c_fc = ColumnParallelLinear( - hidden_size, - intermediate_size, + layer_name=f"{parent_name}.c_fc", + input_size=hidden_size, + output_size=intermediate_size, bias=True, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=True, linear_method=linear_method, ) @@ -123,6 +129,7 @@ class GPT2Block(nn.Module): def __init__( self, + parent_name: str, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, ): @@ -132,9 +139,12 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, linear_method) + self.attn = GPT2Attention(parent_name=f"{parent_name}.attn", + config=config, linear_method=linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, linear_method) + self.mlp = GPT2MLP(parent_name=f"{parent_name}.mlp", + intermediate_size=inner_dim, + config=config, linear_method=linear_method) def forward( self, @@ -176,8 +186,10 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, linear_method) - for _ in range(config.num_hidden_layers) + GPT2Block( + parent_name=f"transformer.h.{idx}", + config=config, linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -248,6 +260,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0b6c75705764..e396cd44a51f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,10 +54,10 @@ class LlamaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, - parent_name: str, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() @@ -89,10 +89,10 @@ class LlamaAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, - parent_name: str, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, @@ -134,7 +134,7 @@ def __init__( self.qkv_proj = QKVParallelLinear( layer_name=f"{parent_name}.qkv_proj", - hidden_size=hidden_size, + hidden_size=self.hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, @@ -144,7 +144,7 @@ def __init__( self.o_proj = RowParallelLinear( layer_name=f"{parent_name}.o_proj", input_size=self.total_num_heads * self.head_dim, - output_size=hidden_size, + output_size=self.hidden_size, bias=bias, linear_method=linear_method, ) @@ -182,8 +182,8 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, - config: LlamaConfig, parent_name: str, + config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() @@ -194,6 +194,7 @@ def __init__( 8192) sliding_window = getattr(config, "sliding_window", None) self.self_attn = LlamaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", @@ -201,16 +202,15 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - parent_name=f"{parent_name}.self_attn", linear_method=linear_method, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) self.mlp = LlamaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - parent_name=f"{parent_name}.mlp", linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, @@ -268,8 +268,8 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, - parent_name=f"model.layers.{idx}", + LlamaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7..e3379b4880be 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -63,6 +63,7 @@ class PhiAttention(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -78,15 +79,17 @@ def __init__(self, # pylint: disable=C0103 self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_size, - self.total_num_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=self.hidden_size, + head_size=self.head_size, + total_num_heads=self.total_num_heads, bias=True, linear_method=linear_method, ) self.dense = RowParallelLinear( - self.hidden_size, - self.hidden_size, + layer_name=f"{parent_name}.dense", + input_size=self.hidden_size, + output_size=self.hidden_size, linear_method=linear_method, ) @@ -126,6 +129,7 @@ def forward( class PhiMLP(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -134,13 +138,15 @@ def __init__(self, n_inner = n_inner if n_inner is not None else 4 * config.hidden_size self.fc1 = ColumnParallelLinear( - config.hidden_size, - n_inner, + layer_name=f"{parent_name}.fc1", + input_size=config.hidden_size, + output_size=n_inner, linear_method=linear_method, ) self.fc2 = RowParallelLinear( - n_inner, - config.hidden_size, + layer_name=f"{parent_name}.fc2", + input_size=n_inner, + output_size=config.hidden_size, linear_method=linear_method, ) quant_config = getattr(linear_method, "quant_config", None) @@ -156,13 +162,18 @@ def forward(self, hidden_states): class PhiLayer(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, linear_method) - self.mlp = PhiMLP(config, linear_method) + self.self_attn = PhiAttention( + parent_name=f"{parent_name}.self_attn", + config=config, linear_method=linear_method) + self.mlp = PhiMLP( + parent_name=f"{parent_name}.mlp", + config=config, linear_method=linear_method) def forward( self, @@ -195,8 +206,10 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + PhiLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -281,6 +294,10 @@ def load_weights(self, for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b..d8ca9b4c6037 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -53,6 +53,7 @@ class Qwen2MLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -60,13 +61,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -81,15 +86,18 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - use_sliding_window: bool = False, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + parent_name: str, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + use_sliding_window: bool = False, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -114,16 +122,18 @@ def __init__(self, self.sliding_window = sliding_window if use_sliding_window else None self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=True, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=False, linear_method=linear_method, ) @@ -159,6 +169,7 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: Qwen2Config, layer_idx: int, linear_method: Optional[LinearMethodBase] = None, @@ -170,6 +181,7 @@ def __init__( use_sliding_window = (config.use_sliding_window and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, @@ -179,6 +191,7 @@ def __init__( linear_method=linear_method, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -235,8 +248,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, linear_method) - for layer_idx in range(config.num_hidden_layers) + Qwen2DecoderLayer(parent_name=f"model.layers.{idx}", + config=config, layer_idx=idx, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,6 +363,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6e..e3730777014b 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -46,6 +46,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -76,16 +77,18 @@ def __init__(self, self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=self.use_bias, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=self.use_bias, linear_method=linear_method, ) @@ -122,18 +125,21 @@ def forward( class Starcoder2MLP(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.c_fc = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, + layer_name=f"{parent_name}.c_fc", + input_size=config.hidden_size, + output_size=config.intermediate_size, bias=config.use_bias, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=config.intermediate_size, + output_size=config.hidden_size, bias=config.use_bias, linear_method=linear_method, ) @@ -151,13 +157,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Starcoder2DecoderLayer(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, + self.self_attn = Starcoder2Attention(parent_name=f"{parent_name}.self_attn", + config=config, linear_method=linear_method) - self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.mlp = Starcoder2MLP(parent_name=f"{parent_name}.mlp", + config=config, linear_method=linear_method) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -204,8 +213,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, linear_method=linear_method) - for _ in range(config.num_hidden_layers) + Starcoder2DecoderLayer(parent_name=f"model.layers.{idx}", + config=config, linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -232,6 +242,7 @@ def __init__(self, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config + self.linear_method = linear_method self.model = Starcoder2Model(config, linear_method=linear_method) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size @@ -290,6 +301,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue From cc08dc44f58063ff578deba465cbd1ec7b3d480a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 24 Apr 2024 15:19:48 -0400 Subject: [PATCH 7/8] Use cutlass kernels (#202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description: Cutlass integration. - Use cutlass gemm with epilogue fusion for dequantization - Remove all existing dequant kernels and interface - Remove cublas i8gemm files Test: Run `examples/offline_quantized_inference.py` ``` (vllm-test) varun@floppy-fan:~/code/neuralmagic-vllm (vllm-quantization-cutlass) $ python3 ./examples/offline_quantized_inference.py ... Prompt: 'Hello, my name is', Generated text: ' John and I am a recovering workaholic.\nI used to work all the time' Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States of America. The president leads the executive' Prompt: 'The capital of France is', Generated text: ' Paris.\nThe capital of France is Paris.' Prompt: 'The future of AI is', Generated text: ' here, and it’s more accessible than ever.\nThe future of AI is here,' ``` Profiling results : Prefill 512 tokens, Branch : This branch, dtype : "torch.float", model : Quantized model - [results](https://drive.google.com/file/d/1GydrBmphPTrBMujIPL9K_Y-ZauQ_8FlR/view?usp=sharing) Note that this branch is better than the [previous best](https://drive.google.com/file/d/1Ga_rpnRCYUtenBUj_BDPcZVvbIRcvRB8/view?usp=drive_link) [w8a8 upstream PR with custom fused kernels] --------- Co-authored-by: Varun Sundar Rabindranath --- CMakeLists.txt | 3 - csrc/pybind.cpp | 22 - .../quantization/smoothquant/fused_kernels.cu | 71 -- .../smoothquant/int8gemm/allocator.h | 232 ------ .../smoothquant/int8gemm/cublasAlgoMap.cc | 188 ----- .../smoothquant/int8gemm/cublasAlgoMap.h | 108 --- .../int8gemm/cublasINT8MMWrapper.cc | 676 ------------------ .../int8gemm/cublasINT8MMWrapper.h | 71 -- .../smoothquant/int8gemm/cuda_utils.cc | 45 -- .../smoothquant/int8gemm/cuda_utils.h | 158 ---- .../smoothquant/int8gemm/int8_gemm.h | 127 ---- requirements-cuda.txt | 1 + tests/kernels/test_fusion.py | 48 -- vllm/model_executor/layers/linear.py | 15 +- .../layers/quantization/smoothquant/config.py | 127 ++-- .../quantization/smoothquant/cutlass_gemm.py | 75 ++ .../quantization/smoothquant/formats.py | 58 +- 17 files changed, 128 insertions(+), 1897 deletions(-) delete mode 100644 csrc/quantization/smoothquant/int8gemm/allocator.h delete mode 100644 csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc delete mode 100644 csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h delete mode 100644 csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc delete mode 100644 csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h delete mode 100644 csrc/quantization/smoothquant/int8gemm/cuda_utils.cc delete mode 100644 csrc/quantization/smoothquant/int8gemm/cuda_utils.h delete mode 100644 csrc/quantization/smoothquant/int8gemm/int8_gemm.h create mode 100644 vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cb48c119e6c..ebea46103699 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,9 +167,6 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc" - "csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc" - "csrc/quantization/smoothquant/int8gemm/cuda_utils.cc" "csrc/quantization/smoothquant/fused_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8c4fbdaed105..983455878023 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -1,7 +1,6 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" -#include "quantization/smoothquant/int8gemm/int8_gemm.h" #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -50,21 +49,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "fused_add_rms_norm", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); - ops.def( - "dequant", - py::overload_cast< - torch::Tensor&, - torch::Tensor&, - float>(&dequant), - "Dequant."); - ops.def( - "dequant", - py::overload_cast< - torch::Tensor&, - torch::Tensor&, - torch::Tensor&, - float>(&dequant), - "Per-token dequant."); ops.def( "quant", py::overload_cast< @@ -102,12 +86,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - pybind11::class_(ops, "I8CUGEMM") - .def(pybind11::init<>()) - .def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32) - .def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8) - .def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_) - .def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/smoothquant/fused_kernels.cu b/csrc/quantization/smoothquant/fused_kernels.cu index 1e9d1acf8f47..1d23a5a0653c 100644 --- a/csrc/quantization/smoothquant/fused_kernels.cu +++ b/csrc/quantization/smoothquant/fused_kernels.cu @@ -7,27 +7,6 @@ #include "quant_utils.cuh" namespace vllm { -template -__global__ void dequant_kernel( - const int32_t* __restrict__ input, - scalar_t* __restrict__ out, - const float scale, - const int m, - const int hidden_size, - const int input_stride, - const int out_stride, - const float* __restrict__ act_scale = nullptr) { - const int tid = threadIdx.x; - const int token_idx = blockIdx.x; - float scale_ = scale; - if constexpr (use_per_token_dequant) { - scale_ = scale * act_scale[token_idx]; - } - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * out_stride + i] = - (scalar_t)(((float)input[token_idx * input_stride + i]) * scale_); - } -} template __global__ void quant_kernel( @@ -71,56 +50,6 @@ __global__ void quant_kernel( } } // namespace vllm -void dequant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - float scale) { - int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); - int input_stride = input.stride(-2); - int out_stride = out.stride(-2); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] { - vllm::dequant_kernel<<>>( - input.data_ptr(), - out.data_ptr(), - scale, - num_tokens, - hidden_size, - input_stride, - out_stride); - }); -} - -void dequant( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale, - float weight_dequant_scale) { - int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); - int input_stride = input.stride(-2); - int out_stride = out.stride(-2); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(out.scalar_type(), "dequant_kernel", [&] { - vllm::dequant_kernel<<>>( - input.data_ptr(), - out.data_ptr(), - weight_dequant_scale, - num_tokens, - hidden_size, - input_stride, - out_stride, - scale.data_ptr()); - }); -} - void quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] diff --git a/csrc/quantization/smoothquant/int8gemm/allocator.h b/csrc/quantization/smoothquant/int8gemm/allocator.h deleted file mode 100644 index 79be2e99e29c..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/allocator.h +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Memory Allocator - **/ - -#pragma once - -#include "cuda_utils.h" -#include -#include -#include - -#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 -#define CUDA_MEMORY_POOL_DISABLED -#endif - -enum class AllocatorType { CUDA, TF, TH }; - -enum class ReallocType { - INCREASE, - REUSE, - DECREASE, -}; - -class IAllocator { -public: - virtual ~IAllocator(){}; - - virtual void *malloc(size_t size, const bool is_set_zero = true, - bool is_host = false) = 0; - virtual void free(void **ptr, bool is_host = false) const = 0; - virtual void setStream(cudaStream_t stream) = 0; - virtual cudaStream_t returnStream() = 0; - virtual void memSet(void *ptr, const int val, const size_t size) = 0; - - template - void *reMalloc(T *ptr, size_t size, const bool is_set_zero = true, - bool is_host = false) { - // FT_LOG_DEBUG(__PRETTY_FUNCTION__); - size = ((size + 31) / 32) * 32; // make the buffer align with 32 bytes - void *void_ptr = (void *)ptr; - void *ptr_address = getAddress(void_ptr); - if (isExist(ptr_address)) { - ReallocType realloc_type = isReMalloc(ptr_address, size); - if (realloc_type == ReallocType::INCREASE) { - // FT_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", - // void_ptr); - free((void **)(&void_ptr), is_host); - return malloc(size, is_set_zero, is_host); - } -#if !defined(CUDA_MEMORY_POOL_DISABLED) - else if (realloc_type == ReallocType::DECREASE) { - // FT_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to - // memory pools.", void_ptr); - free((void **)(&void_ptr), is_host); - return malloc(size, is_set_zero, is_host); - } -#endif - else { - // FT_LOG_DEBUG("Reuse original buffer %p with size %d and do nothing - // for reMalloc.", void_ptr, size); - if (is_set_zero) { - memSet(void_ptr, 0, size); - } - return void_ptr; - } - } else { - // FT_LOG_DEBUG("Cannot find buffer %p, mallocing new one.", void_ptr); - return malloc(size, is_set_zero, is_host); - } - } - -protected: - virtual bool isExist(void *address) const = 0; - virtual ReallocType isReMalloc(void *address, size_t size) const = 0; - - void *getAddress(void *ptr) const { return ptr; } -}; - -template class Allocator; - -template <> class Allocator : public IAllocator { -private: - const int device_id_; - cudaStream_t stream_ = 0; // initialize as default stream - std::unordered_map *pointer_mapping_; - - bool isExist(void *address) const { - return pointer_mapping_->count(address) > 0; - } - ReallocType isReMalloc(void *address, size_t size) const { - FT_CHECK(isExist(address)); - if (pointer_mapping_->at(address) < size) { - return ReallocType::INCREASE; - } else if (pointer_mapping_->at(address) == size) { - return ReallocType::REUSE; - } else { - return ReallocType::DECREASE; - } - } - -public: - Allocator(int device_id) : device_id_(device_id) { - // FT_LOG_DEBUG(__PRETTY_FUNCTION__); - pointer_mapping_ = new std::unordered_map(); -#if defined(CUDA_MEMORY_POOL_DISABLED) - // FT_LOG_WARNING( - // "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync - // cudaMalloc/Free." "Note this may lead to hang with NCCL kernels - // launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); -#else - int device_count = 1; - check_cuda_error(cudaGetDeviceCount(&device_count)); - cudaMemPool_t mempool; - check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); - cudaMemAccessDesc desc = {}; - int peer_access_available = 0; - for (int i = 0; i < device_count; i++) { - if (i == device_id) { - continue; - } - check_cuda_error( - cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); - if (!peer_access_available) { - // FT_LOG_WARNING("Device " + std::to_string(device_id) + " peer access - // Device " + std::to_string(i) - // + " is not available."); - continue; - } - desc.location.type = cudaMemLocationTypeDevice; - desc.location.id = i; - desc.flags = cudaMemAccessFlagsProtReadWrite; - check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); - } - // set memory pool threshold to avoid shrinking the pool - uint64_t setVal = UINT64_MAX; - check_cuda_error(cudaMemPoolSetAttribute( - mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); -#endif - } - - virtual ~Allocator() { - // FT_LOG_DEBUG(__PRETTY_FUNCTION__); - while (!pointer_mapping_->empty()) { - free((void **)(&pointer_mapping_->begin()->first)); - } - delete pointer_mapping_; - } - - void setStream(cudaStream_t stream) { stream_ = stream; } - - cudaStream_t returnStream() { return stream_; }; - - void *malloc(size_t size, const bool is_set_zero = true, - bool is_host = false) { - // FT_LOG_DEBUG(__PRETTY_FUNCTION__); - if (size == 0) { - return nullptr; - } - void *ptr = nullptr; - int o_device = 0; - - check_cuda_error(getSetDevice(device_id_, &o_device)); - if (is_host) { - check_cuda_error(cudaMallocHost(&ptr, (size_t)(ceil(size / 32.)) * 32)); - } else { -#if defined(CUDA_MEMORY_POOL_DISABLED) - check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); -#else - check_cuda_error( - cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); -#endif - } - if (is_set_zero) { - check_cuda_error( - cudaMemsetAsync(ptr, 0, (size_t)(ceil(size / 32.)) * 32, stream_)); - } - check_cuda_error(getSetDevice(o_device)); - // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); - - pointer_mapping_->insert({getAddress(ptr), size}); - - return ptr; - } - - void free(void **ptr, bool is_host = false) const { - // FT_LOG_DEBUG(__PRETTY_FUNCTION__); - void *address = getAddress(*ptr); - if (*ptr != nullptr) { - int o_device = 0; - if (pointer_mapping_->count(address)) { - // FT_LOG_DEBUG("Free buffer %p", address); - check_cuda_error(getSetDevice(device_id_, &o_device)); - if (is_host) { - check_cuda_error(cudaFreeHost(*ptr)); - } else { -#if defined(CUDA_MEMORY_POOL_DISABLED) - check_cuda_error(cudaFree(*ptr)); -#else - check_cuda_error(cudaFreeAsync(*ptr, stream_)); - cudaStreamSynchronize(stream_); -#endif - } - check_cuda_error(getSetDevice(o_device)); - pointer_mapping_->erase(address); - } else { - // FT_LOG_WARNING("pointer_mapping_ does not have information of ptr at - // %p.", address); - } - } - *ptr = nullptr; - return; - } - - void memSet(void *ptr, const int val, const size_t size) { - check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); - } -}; \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc deleted file mode 100644 index 61e41438c6a8..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.cc +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cublasAlgoMap.h" - -cublasAlgoMap::cublasAlgoMap(const std::string filename, - const std::string sp_config_filename) - : config_filename_(filename), sp_config_filename_(sp_config_filename) { - loadGemmConfig(); - loadSpGemmConfig(); -} - -cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap &algo_map) - : config_filename_(algo_map.config_filename_), - sp_config_filename_(algo_map.sp_config_filename_), - algo_map_(algo_map.algo_map_), sp_algo_map_(algo_map.sp_algo_map_) {} - -cublasAlgoMap::~cublasAlgoMap() { algo_map_.clear(); } - -void cublasAlgoMap::loadGemmConfig() { - FILE *fd; - fd = fopen(config_filename_.c_str(), "r"); - if (fd == NULL) { - std::cout << "[WARNING] " << config_filename_ - << " is not found; using default GEMM algo" << std::endl; - return; - } - - int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; - int batch_size, seq_len, head_num, size_per_head, dataType; - int swizzle, reductionScheme, workspaceSize, stages; - int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode; - float exec_time; - char tmp[1024]; - if (!fgets(tmp, 1024, fd)) { - printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); - exit(-1); - } - while (fscanf(fd, - "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d " -#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) - "%d %d " -#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) - "%d %d %d " -#endif - "%f\n", - &batch_size, &seq_len, &head_num, &size_per_head, &dataType, - &batchCount2, &n2, &m2, &k2, &algoId, &customOption, &tile, - &splitK_val, &swizzle, &reductionScheme, &workspaceSize, - &stages, -#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) - &inner_shapeId, &cluster_shapeId, -#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) - &mma_shapeId, &cga_shapeId, &sche_mode, -#endif - &exec_time) != EOF) { - if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && - dataType != BFLOAT16_DATATYPE && dataType != INT8_DATATYPE && - dataType != FP8_DATATYPE) { - printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType); - continue; - } - cublasAlgoConfig_t markStr{batchCount2, m2, n2, k2, - static_cast(dataType)}; - // workspaceSize should be zero - if (algo_map_.find(markStr) == algo_map_.end()) { - algo_map_[markStr].algoId = algoId; - algo_map_[markStr].customOption = customOption; - algo_map_[markStr].tile = tile; - algo_map_[markStr].splitK_val = splitK_val; - algo_map_[markStr].swizzle = swizzle; - algo_map_[markStr].reductionScheme = reductionScheme; - algo_map_[markStr].workspaceSize = workspaceSize; - algo_map_[markStr].stages = stages; -#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) - algo_map_[markStr].inner_shapeId = (uint16_t)inner_shapeId; - algo_map_[markStr].cluster_shapeId = (uint16_t)cluster_shapeId; -#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) - algo_map_[markStr].mma_shapeId = (uint16_t)mma_shapeId; - algo_map_[markStr].cga_shapeId = (uint16_t)cga_shapeId; - algo_map_[markStr].sche_mode = (uint16_t)sche_mode; -#endif - algo_map_[markStr].exec_time = exec_time; - } - } - fclose(fd); -} - -bool cublasAlgoMap::isExist(const int batch_count, const int m, const int n, - const int k, const CublasDataType data_type) { - cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; - return algo_map_.find(mark) != algo_map_.end(); -} - -cublasLtMatmulAlgo_info cublasAlgoMap::getAlgo(const int batch_count, - const int m, const int n, - const int k, - const CublasDataType data_type) { - cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; - if (algo_map_.find(mark) != algo_map_.end()) { - return algo_map_[mark]; - } else { - cublasLtMatmulAlgo_info tmp_algo; - tmp_algo.algoId = static_cast(data_type == FLOAT_DATATYPE - ? CUBLAS_GEMM_DEFAULT - : CUBLAS_GEMM_DEFAULT_TENSOR_OP); - tmp_algo.customOption = -1; - tmp_algo.tile = -1; - tmp_algo.splitK_val = -1; - tmp_algo.swizzle = -1; - tmp_algo.reductionScheme = -1; - tmp_algo.workspaceSize = -1; - tmp_algo.stages = -1; - tmp_algo.exec_time = -1.0f; - return tmp_algo; - } -} - -void cublasAlgoMap::loadSpGemmConfig() { - if (sp_config_filename_.empty()) { - return; - } - FILE *fd = fopen(sp_config_filename_.c_str(), "r"); - if (fd == NULL) { - printf("[WARNING] %s is not found; using SPGEMM algo id 0\n", - sp_config_filename_.c_str()); - return; - } - sp_algo_map_.clear(); - int batch_size, seq_len, head_num, size_per_head, data_type; - int batchCount, m, n, k, algoId; - float exec_time; - char tmp[1024]; - if (!fgets(tmp, 1024, fd)) { - printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); - exit(-1); - } - while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, - &seq_len, &head_num, &size_per_head, &data_type, &batchCount, - &m, &n, &k, &algoId, &exec_time) != EOF) { - char mark[256]; - sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); - std::string markStr(mark); - sp_algo_map_[markStr] = algoId; - } - fclose(fd); -} - -int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, - const int k) { - char mark[256]; - sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); - if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { - return sp_algo_map_[mark]; - } else { - // for remove padding, select algo 1 for simplicity - return 0; - } -} - -bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, - const int k) { - // not available to use cusparselt. - if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0) { - return false; - } - char mark[256]; - sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); - if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { - return sp_algo_map_[mark] != -1; - } else { - // no gemm test case, choose sparse according to sparse flag - return true; - } -} diff --git a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h b/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h deleted file mode 100644 index beb9d3a23d90..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cublasAlgoMap.h +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cuda_utils.h" -#include -#include -#include -#include -#include -#include -#include - -#pragma once - -#define GEMM_NUM 6 -#define GEMM_CONFIG "gemm_config.in" -#define IGEMM_CONFIG "igemm_config.in" -#define SPGEMM_CONFIG "spgemm_config.in" -#define SPIGEMM_CONFIG "spigemm_config.in" - -typedef struct { - int algoId, customOption, tile, splitK_val; - int swizzle, reductionScheme, workspaceSize; - // only used in cublasLt >= 11.0 - int stages; -#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) - uint16_t inner_shapeId, cluster_shapeId; -#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) - uint16_t mma_shapeId, cga_shapeId, sche_mode; -#endif - float exec_time; -} cublasLtMatmulAlgo_info; - -/* Structure to store information about different run trials */ -typedef struct { - cublasLtMatmulAlgo_t algo; - cublasStatus_t status; - float time; - size_t workspaceSize; // actual memory workspace needed - cublasMath_t mathMode; - cublasLtReductionScheme_t reductionScheme; - int customOption; - float wavesCount; -} customMatmulPerf_t; - -struct cublasAlgoConfig_t { - int batch_count; - int m; - int n; - int k; - CublasDataType data_type; - bool operator==(cublasAlgoConfig_t const &config) const { - return (batch_count == config.batch_count) && (m == config.m) && - (n == config.n) && (k == config.k) && - (data_type == config.data_type); - } -}; - -class cublasAlgoConfig_hasher { -public: - std::size_t operator()(cublasAlgoConfig_t const &config) const { - return config.batch_count * 98317ull ^ config.m * 49157ull ^ - config.n * 24593ull ^ config.k * 196613ull ^ - static_cast(config.data_type) * 6151ull; - } -}; - -class cublasAlgoMap { -private: - std::unordered_map - algo_map_; - std::string config_filename_; - std::string sp_config_filename_; - std::map sp_algo_map_; - -public: - cublasAlgoMap(){}; - explicit cublasAlgoMap(const std::string filename, - const std::string sp_config_filename = ""); - cublasAlgoMap(const cublasAlgoMap &map); - ~cublasAlgoMap(); - void loadGemmConfig(); - void loadSpGemmConfig(); - int getSpAlgo(const int batch_count, const int m, const int n, const int k); - bool isUseSparse(const int batch_count, const int m, const int n, - const int k); - - bool isExist(const int batch_count, const int m, const int n, const int k, - const CublasDataType data_type); - - cublasLtMatmulAlgo_info getAlgo(const int batch_count, const int m, - const int n, const int k, - const CublasDataType data_type); -}; diff --git a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc deleted file mode 100644 index 03c656b10cbd..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.cc +++ /dev/null @@ -1,676 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cublasINT8MMWrapper.h" - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#endif - -cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap *cublas_algo_map, - std::mutex *mu, - bool use_ORDER_COL32_2R_4R4) - : cublas_handle_(nullptr), cublaslt_handle_(cublaslt_handle), - stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), - allocator_(nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} - -cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, - cublasLtHandle_t cublaslt_handle, - cudaStream_t stream, - cublasAlgoMap *cublas_algo_map, - std::mutex *mu, - bool use_ORDER_COL32_2R_4R4) - : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), - stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), - allocator_(nullptr), use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} - - -cublasINT8MMWrapper::~cublasINT8MMWrapper() { mu_ = nullptr; } - -cublasINT8MMWrapper::cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper) - : cublas_handle_(nullptr), cublaslt_handle_(wrapper.cublaslt_handle_), - stream_(wrapper.stream_), cublas_algo_map_(wrapper.cublas_algo_map_), mu_(wrapper.mu_), - allocator_(wrapper.allocator_), use_ORDER_COL32_2R_4R4_(wrapper.use_ORDER_COL32_2R_4R4_) { -} - -// for int8 cublasLtMM with algo -// ATransform should be m*n, CUBLASLT_ORDER_COL32 -// kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or -// CUBLASLT_ORDER_COL32_2R_4R4 res is m*n, CUBLASLT_ORDER_COL32 -void cublasINT8MMWrapper::Gemm(int *res, int batchCount, int m, int n, int k, - int64_t stridea, int64_t strideb, - int64_t stridec, const int8_t *ATransform, - const int8_t *kernel) { - mu_->lock(); - cublasOperation_t opTranspose = CUBLAS_OP_T; -#if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; - - cublasLtOrder_t order_matrixB; -#if (CUDART_VERSION >= 11000) - if (use_ORDER_COL32_2R_4R4_) { - order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; - } else { - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; - } -#else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#endif - - int ldaTransform = 32 * m; - int ldbTransform; - if (use_ORDER_COL32_2R_4R4_) { - ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; - } else { - ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; - } - int ldcTransform = 32 * m; - - // create matmulDesc -#if (CUDART_VERSION >= 11000) - cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTranspose, sizeof(cublasOperation_t)); - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, sizeof(order_COL32)); - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_matrixB, sizeof(order_matrixB)); - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, m, n, ldcTransform); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, sizeof(order_COL32)); - if (batchCount > 1) { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, - sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, - sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, - sizeof(stridec)); - } - - int alphaI = 1; - int betaI = 0; - - // get algo - cublasLtMatmulAlgo_t algo; - int findAlgo = 0; - if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { - // printf("find algo %s\n", markStr.c_str()); - findAlgo = 1; - - cublasLtMatmulAlgo_info tmp_info = - cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); - - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, - CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, - &algo); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), - sizeof(tmp_info.customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tmp_info.tile), - sizeof(tmp_info.tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(tmp_info.splitK_val), - sizeof(tmp_info.splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), - sizeof(tmp_info.swizzle)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(tmp_info.reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(tmp_info.stages), - sizeof(tmp_info.stages)); -#endif - } else { - findAlgo = 1; - int algoId; - if (use_ORDER_COL32_2R_4R4_) { - algoId = 7; - } else { - algoId = 6; - } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int reductionScheme = 0; - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, - CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - int stages; - if (use_ORDER_COL32_2R_4R4_) { - stages = 15; - } else { - stages = 13; - } - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(stages), sizeof(stages)); -#endif - } - - cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, ATransform, - AtransformDesc, kernel, BtransformDesc, &betaI, res, - CtransformDesc, res, CtransformDesc, - (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream_); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); - sync_check_cuda_error(); - mu_->unlock(); -} - -// Atransform: mxk CUDA_R_8I -// kernel: nxk CUDA_R_8I -// res: mxn CUDA_R_32I -// alpha: CUDA_R_32I should be 1 -// beta: CUDA_R_32I should be 0 -// computeType: CUBLAS_COMPUTE_32I -void cublasINT8MMWrapper::Gemm_(int *res, int batchCount, int m, int n, int k, - int64_t stridea, int64_t strideb, - int64_t stridec, const int8_t *ATransform, - const int8_t *kernel) { - mu_->lock(); - cublasOperation_t opTranspose = CUBLAS_OP_T; -#if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - - // create matmulDesc -#if (CUDART_VERSION >= 11000) - cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTranspose, sizeof(cublasOperation_t)); - - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); - - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); - - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, n, m, n); - - if (batchCount > 1) { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, - sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, - sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, - sizeof(stridec)); - } - - int alphaI = 1; - int betaI = 0; - - // get algo - cublasLtMatmulAlgo_t algo; - int findAlgo = 0; - if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { - // printf("find algo %s\n", markStr.c_str()); - findAlgo = 1; - - cublasLtMatmulAlgo_info tmp_info = - cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); - - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, - CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, - &algo); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), - sizeof(tmp_info.customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tmp_info.tile), - sizeof(tmp_info.tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(tmp_info.splitK_val), - sizeof(tmp_info.splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), - sizeof(tmp_info.swizzle)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(tmp_info.reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(tmp_info.stages), - sizeof(tmp_info.stages)); -#endif - } else { - findAlgo = 1; - int algoId; - algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int reductionScheme = 0; - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, - CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - int stages; - stages = 17; - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(stages), sizeof(stages)); -#endif - } - - cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, kernel, AtransformDesc, - ATransform, BtransformDesc, &betaI, res, CtransformDesc, res, - CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, - stream_); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); - sync_check_cuda_error(); - mu_->unlock(); -} - -// for int8 IO cublasLtMM with algo -// ATransform should be m*k CUBLASLT_ORDER_COL32 -// kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C -// res is m*n CUBLASLT_ORDER_COL32 -void cublasINT8MMWrapper::Gemm(int8_t *res, int batchCount, int m, int n, int k, - int64_t stridea, int64_t strideb, - int64_t stridec, const float alpha, - const int8_t *ATransform, const int8_t *kernel) { - mu_->lock(); - cublasOperation_t opTranspose = CUBLAS_OP_T; - // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE - // cublasLtPointerMode_t pointerMode = - // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cudaDataType_t scaleType = CUDA_R_32F; -#if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; - - cublasLtOrder_t order_matrixB; -#if (CUDART_VERSION >= 11000) - if (use_ORDER_COL32_2R_4R4_) { - order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; - } else { - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; - } -#else - order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; -#endif - - int ldaTransform = 32 * m; - - int ldbTransform; - if (use_ORDER_COL32_2R_4R4_) { - ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; - } else { - ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; - } - - int ldcTransform = 32 * m; - - // create matmulDesc -#if (CUDART_VERSION >= 11000) - cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &opTranspose, sizeof(cublasOperation_t)); - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, - &scaleType, sizeof(scaleType)); - // cublasLtMatmulDescSetAttribute(matmulDesc, - // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, - // sizeof(cublasLtPointerMode_t)); - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); - cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, sizeof(order_COL32)); - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_matrixB, sizeof(order_matrixB)); - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, m, n, ldcTransform); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL32, sizeof(order_COL32)); - if (batchCount > 1) { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, - sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, - sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, - sizeof(stridec)); - } - - // get algo - cublasLtMatmulAlgo_t algo; - int findAlgo = 0; - if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { - findAlgo = 1; - - cublasLtMatmulAlgo_info tmp_info = - cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); - - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, - CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, - &algo); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), - sizeof(tmp_info.customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tmp_info.tile), - sizeof(tmp_info.tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(tmp_info.splitK_val), - sizeof(tmp_info.splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), - sizeof(tmp_info.swizzle)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(tmp_info.reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(tmp_info.stages), - sizeof(tmp_info.stages)); -#endif - } else { - findAlgo = 1; - int algoId; - if (use_ORDER_COL32_2R_4R4_) { - algoId = 7; - } else { - algoId = 6; - } - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int reductionScheme = 0; - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, - CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - int stages; - if (use_ORDER_COL32_2R_4R4_) { - stages = 15; - } else { - stages = 13; - } - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(stages), sizeof(stages)); -#endif - } - - float beta = 0.0f; - cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, - ATransform, BtransformDesc, &beta, res, CtransformDesc, res, - CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, - stream_); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); - sync_check_cuda_error(); - mu_->unlock(); -} - -// Atransform: mxk CUDA_R_8I -// kernel: nxk CUDA_R_8I -// res: mxn CUDA_R_8I -// alpha: CUDA_R_32F -// beta: CUDA_R_32F -// computeType: CUBLAS_COMPUTE_32I -void cublasINT8MMWrapper::Gemm_(int8_t *res, int batchCount, int m, int n, - int k, int64_t stridea, int64_t strideb, - int64_t stridec, const float alpha, - const int8_t *ATransform, - const int8_t *kernel) { - mu_->lock(); - cublasOperation_t opTranspose = CUBLAS_OP_T; - // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE - // cublasLtPointerMode_t pointerMode = - // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cudaDataType_t scaleType = CUDA_R_32F; -#if (CUDART_VERSION >= 11000) - cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; -#else - cudaDataType_t computeType = CUDA_R_32I; -#endif - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t AtransformDesc = NULL; - cublasLtMatrixLayout_t BtransformDesc = NULL; - cublasLtMatrixLayout_t CtransformDesc = NULL; - - // create matmulDesc -#if (CUDART_VERSION >= 11000) - cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); -#else - cublasLtMatmulDescCreate(&matmulDesc, computeType); -#endif - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &opTranspose, sizeof(cublasOperation_t)); - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, - &scaleType, sizeof(scaleType)); - // cublasLtMatmulDescSetAttribute(matmulDesc, - // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, - // sizeof(cublasLtPointerMode_t)); - cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); - - cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); - - cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, n, m, n); - - if (batchCount > 1) { - cublasLtMatrixLayoutSetAttribute(AtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, - sizeof(stridea)); - cublasLtMatrixLayoutSetAttribute(BtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, - sizeof(strideb)); - cublasLtMatrixLayoutSetAttribute(CtransformDesc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batchCount, sizeof(batchCount)); - cublasLtMatrixLayoutSetAttribute( - CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, - sizeof(stridec)); - } - - // get algo - cublasLtMatmulAlgo_t algo; - int findAlgo = 0; - if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { - findAlgo = 1; - cublasLtMatmulAlgo_info tmp_info = - cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); - - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, - CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, - &algo); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), - sizeof(tmp_info.customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tmp_info.tile), - sizeof(tmp_info.tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(tmp_info.splitK_val), - sizeof(tmp_info.splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), - sizeof(tmp_info.swizzle)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(tmp_info.reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(tmp_info.stages), - sizeof(tmp_info.stages)); -#endif - } else { - findAlgo = 1; - int algoId; - algoId = 21; - int swizzle = 0; - int customOption = 0; - int tile = 20; - int splitK_val = 0; - int reductionScheme = 0; - cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, - CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, - &(customOption), sizeof(customOption)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, - &(tile), sizeof(tile)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, - &(splitK_val), sizeof(splitK_val)); - cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); - cublasLtMatmulAlgoConfigSetAttribute(&algo, - CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, - &(reductionScheme), sizeof(int)); -#if (CUDART_VERSION >= 11000) - int stages; - stages = 17; - cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, - &(stages), sizeof(stages)); -#endif - } - - float beta = 0.0f; - cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, - ATransform, BtransformDesc, &beta, res, CtransformDesc, res, - CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, - stream_); - - cublasLtMatmulDescDestroy(matmulDesc); - cublasLtMatrixLayoutDestroy(AtransformDesc); - cublasLtMatrixLayoutDestroy(BtransformDesc); - cublasLtMatrixLayoutDestroy(CtransformDesc); - sync_check_cuda_error(); - mu_->unlock(); -} - -bool cublasINT8MMWrapper::getUseOrderCol322R4R4() { - return use_ORDER_COL32_2R_4R4_; -} diff --git a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h b/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h deleted file mode 100644 index 8bc209f58b91..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cublasINT8MMWrapper.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "allocator.h" -#include "cublasAlgoMap.h" -#include -#include -#include -#include -#include -#include - -#pragma once - -class cublasINT8MMWrapper{ -protected: - cublasHandle_t cublas_handle_; - cublasLtHandle_t cublaslt_handle_; - cudaStream_t stream_; - cublasAlgoMap *cublas_algo_map_; - std::mutex *mu_; - IAllocator *allocator_ = nullptr; - -private: - bool use_ORDER_COL32_2R_4R4_; - -public: - cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, - cublasAlgoMap *map, std::mutex *mu, - bool use_ORDER_COL32_2R_4R4); - - cublasINT8MMWrapper(cublasHandle_t cublas_handle, - cublasLtHandle_t cublaslt_handle, cudaStream_t stream, - cublasAlgoMap *map, std::mutex *mu, - bool use_ORDER_COL32_2R_4R4); - - ~cublasINT8MMWrapper(); - - cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper); - - void Gemm(int *res, int batchCount, int m, int n, int k, int64_t stridea, - int64_t strideb, int64_t stridec, const int8_t *ATransform, - const int8_t *kernel); - - void Gemm_(int *res, int batchCount, int m, int n, int k, int64_t stridea, - int64_t strideb, int64_t stridec, const int8_t *ATransform, - const int8_t *kernel); - - void Gemm(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, - int64_t strideb, int64_t stridec, const float alpha, - const int8_t *ATransform, const int8_t *kernel); - - void Gemm_(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, - int64_t strideb, int64_t stridec, const float alpha, - const int8_t *ATransform, const int8_t *kernel); - - bool getUseOrderCol322R4R4(); -}; \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc b/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc deleted file mode 100644 index 588375570937..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cuda_utils.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cuda_utils.h" - -cudaError_t getSetDevice(int i_device, int *o_device) { - int current_dev_id = 0; - cudaError_t err = cudaSuccess; - - if (o_device != NULL) { - err = cudaGetDevice(¤t_dev_id); - if (err != cudaSuccess) { - return err; - } - if (current_dev_id == i_device) { - *o_device = i_device; - } else { - err = cudaSetDevice(i_device); - if (err != cudaSuccess) { - return err; - } - *o_device = current_dev_id; - } - } else { - err = cudaSetDevice(i_device); - if (err != cudaSuccess) { - return err; - } - } - - return cudaSuccess; -} diff --git a/csrc/quantization/smoothquant/int8gemm/cuda_utils.h b/csrc/quantization/smoothquant/int8gemm/cuda_utils.h deleted file mode 100644 index f1d9bba4ab06..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/cuda_utils.h +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - - -enum CublasDataType { - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, - BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3, - FP8_DATATYPE = 4 -}; - -static const char *_cudaGetErrorEnum(cudaError_t error) { - return cudaGetErrorString(error); -} - -static const char *_cudaGetErrorEnum(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; -} - -template -void check(T result, char const *const func, const char *const file, - int const line) { - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + - (_cudaGetErrorEnum(result)) + " " + file + ":" + - std::to_string(line) + " \n"); - } -} - -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) -#define check_cuda_error_2(val, file, line) check((val), #val, file, line) - -inline void syncAndCheck(const char *const file, int const line) { - // When FT_DEBUG_LEVEL=DEBUG, must check error - static char *level_name = std::getenv("FT_DEBUG_LEVEL"); - if (level_name != nullptr) { - static std::string level = std::string(level_name); - if (level == "DEBUG") { - cudaDeviceSynchronize(); - cudaError_t result = cudaGetLastError(); - if (result) { - throw std::runtime_error( - std::string("[FT][ERROR] CUDA runtime error: ") + - (_cudaGetErrorEnum(result)) + " " + file + ":" + - std::to_string(line) + " \n"); - } - // FT_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line)); - } - } - -#ifndef NDEBUG - cudaDeviceSynchronize(); - cudaError_t result = cudaGetLastError(); - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + - (_cudaGetErrorEnum(result)) + " " + file + ":" + - std::to_string(line) + " \n"); - } -#endif -} - -#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__) - - -[[noreturn]] inline void throwRuntimeError(const char *const file, - int const line, - std::string const &info = "") { - throw std::runtime_error(std::string("[FT][ERROR] ") + info + - " Assertion fail: " + file + ":" + - std::to_string(line) + " \n"); -} - -inline void myAssert(bool result, const char *const file, int const line, - std::string const &info = "") { - if (!result) { - throwRuntimeError(file, line, info); - } -} - -#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) -#define FT_CHECK_WITH_INFO(val, info) \ - do { \ - bool is_valid_val = (val); \ - if (!is_valid_val) { \ - fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ - } \ - } while (0) - -#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) - -cudaError_t getSetDevice(int i_device, int *o_device = NULL); - -inline int getDevice() { - int current_dev_id = 0; - check_cuda_error(cudaGetDevice(¤t_dev_id)); - return current_dev_id; -} - -inline int getDeviceCount() { - int count = 0; - check_cuda_error(cudaGetDeviceCount(&count)); - return count; -} \ No newline at end of file diff --git a/csrc/quantization/smoothquant/int8gemm/int8_gemm.h b/csrc/quantization/smoothquant/int8gemm/int8_gemm.h deleted file mode 100644 index 2e80d4efe22a..000000000000 --- a/csrc/quantization/smoothquant/int8gemm/int8_gemm.h +++ /dev/null @@ -1,127 +0,0 @@ -/* - gemm methods are adapted from ft -*/ -#include -#include "cublasAlgoMap.h" -#include "cublasINT8MMWrapper.h" - -class I8CUGEMM { -private: - cublasINT8MMWrapper *int8_gemm_wrapper = nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - -public: - I8CUGEMM(); - ~I8CUGEMM(); - - void linear_a8_w8_o32( - torch::Tensor& input, - torch::Tensor& weight, - torch::Tensor& output); - void linear_a8_w8_o32_( - torch::Tensor& input, - torch::Tensor& weight, - torch::Tensor& output); - void linear_a8_w8_o8( - torch::Tensor& input, - torch::Tensor& weight, - torch::Tensor& output, - float alpha); - void linear_a8_w8_o8_( - torch::Tensor& input, - torch::Tensor& weight, - torch::Tensor& output, - float alpha); -}; -I8CUGEMM::I8CUGEMM() { - // cublasAlgoMap *cublas_algo_map = new cublasAlgoMap("igemm_config.in"); - cublasAlgoMap *cublas_algo_map = new cublasAlgoMap(); - std::mutex *cublas_wrapper_mutex = new std::mutex(); - bool use_ORDER_COL32_2R_4R4 = true; - - cublasLtHandle_t cublaslt_handle; - cublasLtCreate(&cublaslt_handle); - - int8_gemm_wrapper = new cublasINT8MMWrapper( - cublaslt_handle, - this->stream, - cublas_algo_map, - cublas_wrapper_mutex, - use_ORDER_COL32_2R_4R4); -} - -I8CUGEMM::~I8CUGEMM() {} - -void I8CUGEMM::linear_a8_w8_o32( - torch::Tensor& input, // INT8 - torch::Tensor& weight, // INT8 - torch::Tensor& out // INT32 -) { - int m = input.size(0); - int n = weight.size(0); - int k = input.size(1); - - // Set data types - int8_t* input_ptr = input.data_ptr(); - int8_t* weight_ptr = weight.data_ptr(); - int32_t* output_ptr = out.data_ptr(); - - int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, - weight_ptr); -} - -void I8CUGEMM::linear_a8_w8_o32_( - torch::Tensor& input, // INT8 - torch::Tensor& weight, // INT8 - torch::Tensor& out // INT32 -) { - int m = input.size(0); - int n = weight.size(0); - int k = input.size(1); - - // Set data types - int8_t* input_ptr = input.data_ptr(); - int8_t* weight_ptr = weight.data_ptr(); - int32_t* output_ptr = out.data_ptr(); - - int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, - weight_ptr); -} - -void I8CUGEMM::linear_a8_w8_o8( - torch::Tensor& input, // INT8 - torch::Tensor& weight, // INT8 - torch::Tensor& out, // INT8 - float alpha // FP32 -) { - int m = input.size(0); - int n = weight.size(0); - int k = input.size(1); - - // Set data types - int8_t* input_ptr = input.data_ptr(); - int8_t* weight_ptr = weight.data_ptr(); - int8_t* output_ptr = out.data_ptr(); - - int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, - weight_ptr); -} - -void I8CUGEMM::linear_a8_w8_o8_( - torch::Tensor& input, // INT8 - torch::Tensor& weight, // INT8 - torch::Tensor& out, // INT8 - float alpha // FP32 -) { - int m = input.size(0); - int n = weight.size(0); - int k = input.size(1); - - // Set data types - int8_t* input_ptr = input.data_ptr(); - int8_t* weight_ptr = weight.data_ptr(); - int8_t* output_ptr = out.data_ptr(); - - int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, - weight_ptr); -} diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c0..f3adcb519ed9 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,3 +8,4 @@ vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 triton >= 2.1.0 +nvidia-cutlass == 3.5.0 diff --git a/tests/kernels/test_fusion.py b/tests/kernels/test_fusion.py index 07d9ce60a403..7cebc76b248b 100644 --- a/tests/kernels/test_fusion.py +++ b/tests/kernels/test_fusion.py @@ -9,54 +9,6 @@ SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("scale", SCALE) -@torch.inference_mode() -def test_dequant(num_tokens: int, hidden_size: int, dtype: torch.dtype, - seed: int, scale: float) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - x = torch.randint( - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - (num_tokens, hidden_size), - dtype=torch.int32, - device="cuda", - ) - - out1 = (x * scale).to(dtype) - out2 = torch.empty_like(x, dtype=dtype) - ops.dequant(out2, x, scale) - assert torch.allclose(out1, out2, atol=0.001) - - -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@torch.inference_mode() -def test_per_token_dequant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - x = torch.randint( - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - (num_tokens, hidden_size), - dtype=torch.int32, - device="cuda", - ) - scale = torch.rand(num_tokens, 1, dtype=torch.float32, device="cuda") - out1 = (x * scale).to(dtype) - out2 = torch.empty_like(x, dtype=dtype) - scale = torch.squeeze(scale) - ops.dequant(out2, x, scale) - assert torch.allclose(out1, out2, atol=0.001) - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 20891acbbfb5..905b01e69d7e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -166,9 +166,6 @@ class ColumnParallelLinear(torch.nn.Module): skip adding bias but instead return it. params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. - logical_widths: Optional list of widths for logical weight matrices. - E.g. for QKVParallelLinear, this parameter defines - the width """ def __init__( @@ -308,6 +305,7 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) param_shard_splitter = getattr(param, "shard_splitter", None) + if output_dim is not None and param_shard_splitter is not None: raise NotImplementedError( "We do not currently support output_dim != None and " @@ -373,8 +371,11 @@ def weight_loader(self, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") param_data, loaded_weight = param_shard_splitter( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id, logical_widths + ) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -466,7 +467,7 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) param_shard_splitter = getattr(param, "shard_splitter", None) - + if output_dim is not None and param_shard_splitter is not None: raise NotImplementedError( "We do not currently support output_dim != None and " @@ -548,8 +549,10 @@ def weight_loader(self, shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") param_data, loaded_weight = param_shard_splitter( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: diff --git a/vllm/model_executor/layers/quantization/smoothquant/config.py b/vllm/model_executor/layers/quantization/smoothquant/config.py index 885ffce3e36d..788481a51ae1 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/config.py +++ b/vllm/model_executor/layers/quantization/smoothquant/config.py @@ -1,10 +1,8 @@ from typing import Any, Dict, List, Tuple, Type, Optional, Union -import threading import torch from torch.nn.parameter import Parameter -from vllm._C import ops from vllm.model_executor.layers.linear import ( LinearMethodBase, set_weight_attrs) @@ -15,6 +13,9 @@ SmoothQuantDynamicPerToken, SmoothQuantStaticPerTensor, ) +from vllm.model_executor.layers.quantization.smoothquant.cutlass_gemm import ( + cutlass_gemm_dq +) LAYER_KEYS = ["qkv", "out", "fc1", "fc2"] FORMAT_REGISTRY = { @@ -85,31 +86,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": def get_linear_method(self) -> "SmoothQuantLinearMethod": return SmoothQuantLinearMethod(self) - -# TODO: why is this needed? -class Int8GEMM(object): - _instance_lock = threading.Lock() - - def __init__(self): - if not hasattr(self, "i8cugemm"): - self.i8cugemm = ops.I8CUGEMM() - - def __new__(cls, *args, **kwargs): - if not hasattr(Int8GEMM, "_instance"): - with Int8GEMM._instance_lock: - if not hasattr(Int8GEMM, "_instance"): - Int8GEMM._instance = object.__new__(cls) - return Int8GEMM._instance - - def get_i8cugemm(self): - return self.i8cugemm - - class SmoothQuantLinearMethod(LinearMethodBase): def __init__(self, sq_config: SmoothQuantConfig) -> None: self.sq_config = sq_config self.sq_type = None - self.i8cugemm = Int8GEMM().get_i8cugemm() def maybe_update_loaded_weight_name(self, name: str) -> str: @@ -123,24 +103,26 @@ def maybe_update_loaded_weight_name(self, name.replace(suffix, "dequant_scale") return name - def scales_shard_splitter(self, - param: torch.Tensor, - loaded_weight: torch.Tensor, - shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: - """Index into param for for loading. - - This function is called by QKVColumnLinear and MergedColumnParallelLinear - during weight loading to put the scales from disk in the right spot. - """ - if type(shard_id) == str: - qkv_idxs = { "q": 0, "k": 1, "v": 2 } - if shard_id not in qkv_idxs: - raise ValueError(f"Invalid shard_id {shard_id}") - shard_id = qkv_idxs[shard_id] - elif type(shard_id) != int: - raise ValueError(f"Invalid shard id {shard_id}") + def shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id - return param[shard_id], loaded_weight + assert isinstance(shard_id, str) + qkv_idxs = { "q": 0, "k": 1, "v": 2 } + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter(self, + param: torch.Tensor, + loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self.shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset : offset + size], loaded_weight def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: """ @@ -213,16 +195,22 @@ def create_weights(self, "output_dim": 0, }) - # Static scale for each logical weight (e.g. 3 for QKV). - dequant_scale = Parameter( - torch.empty( - len(output_sizes_per_partition), - device='cuda', dtype=params_dtype, - ), requires_grad=False - ) - set_weight_attrs(dequant_scale, { - "shard_splitter": self.scales_shard_splitter, - }) + if len(output_sizes_per_partition) == 1: + # Single static scale for the entire tensor. + dequant_scale = Parameter( + torch.empty((1),device='cuda', dtype=params_dtype), + requires_grad=False + ) + else: + # Static scale for each logical weight (e.g. 3 for QKV). + dequant_scale = Parameter( + torch.empty((sum(output_sizes_per_partition)), + device='cuda', dtype=params_dtype), + requires_grad=False + ) + set_weight_attrs(dequant_scale, + {"shard_splitter": self.scales_shard_splitter, + "logical_widths" : output_sizes_per_partition}) return { "weight": weight, @@ -242,37 +230,10 @@ def _quantize(self, x_q: Quantized activation at INT8 activation_scales: Optional dynamic scales for each token. """ - x_q = torch.empty_like(x, dtype=torch.int8) + x_q = torch.empty_like(x, dtype=torch.int8, device="cuda") x_q, activation_scales = sq_format.quantize_op(x, x_q) return x_q, activation_scales - def _dequantize(self, - x_q: torch.Tensor, - dynamic_scales: Optional[torch.Tensor], - static_scales: torch.Tensor, - logical_widths: List[int], - dtype: torch.dtype, - sq_format: SmoothQuantFormat) -> torch.Tensor: - """Dequantize activations. - - Args: - x_q: quantized activations. - dynamic_scales: Optional dynamic scales. - static_scales: Static dequantization scales. - logical_widths: Width of each logical activation (for QKV case). - dtype: Datatype to dequantize to. - Returns: - x_dq: dequantized activation at output_dtype precision - """ - # Split X_q and X_dq buffer into logical activations (for QKV case). - x_q_split = x_q.split(logical_widths, dim=-1) - x_dq = torch.empty_like(x_q, dtype=dtype) - x_dq_split = x_dq.split(logical_widths, dim=-1) - # Dequantize in place and return. - sq_format.dequantize_op(x_q_split, x_dq_split, dynamic_scales, static_scales) - return x_dq - - def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, @@ -290,17 +251,11 @@ def apply_weights(self, raise NotImplementedError weight_q = weights["weight"] static_scales = weights["dequant_scale"] - logical_widths = weights["logical_widths"] sq_format = weights["sq_format"] # Q x_q, activation_scales = self._quantize(x, sq_format) - # GEMM - x_q = x_q.view(-1, x_q.shape[-1]) - a_q = torch.empty((x_q.shape[0], weight_q.shape[0]), dtype=torch.int32, device="cuda") - self.i8cugemm.linear_a8_w8_o32_(x_q, weight_q, a_q) - a_q = a_q.view(*x_q.shape[:-1], -1) + # GEMM and DQ + return cutlass_gemm_dq(x_q, weight_q, x.dtype, static_scales, activation_scales) - # DQ - return self._dequantize(a_q, activation_scales, static_scales, logical_widths, x.dtype, sq_format) diff --git a/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py new file mode 100644 index 000000000000..05ae38c3343e --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py @@ -0,0 +1,75 @@ + +import cutlass +from cutlass import Tensor as FakeTensor +import cutlass.epilogue + +import torch +from typing import Optional, Tuple, Dict + +from vllm.logger import init_logger + +logger = init_logger("cutlass_gemm") + +def setup_dequant_epilogue(plan : cutlass.op.Gemm, + dq: torch.Tensor, + static_scales: Optional[torch.Tensor], + activation_scales: Optional[torch.Tensor]) \ + -> Tuple[cutlass.op.Gemm, Dict]: + + if all([static_scales is None, activation_scales is None]): + return plan, None + assert static_scales is not None + + def epilog_with_scales_and_act_scales(accum, scales, act_scales): + D = accum * scales * act_scales + return D + + def epilog_with_scales(accum, scales): + D = accum * scales + return D + + epilog_tensors = { + 'scales' : static_scales, + 'D' : dq + } + epilogue_trace_tensors = { + "accum": FakeTensor(element=torch.int32, shape=dq.shape, + layout_tag=cutlass.LayoutType.RowMajor), + 'scales' : static_scales, + 'D' : dq, + } + epilog_fn = epilog_with_scales + + if activation_scales is not None: + epilog_tensors['act_scales'] = activation_scales + epilogue_trace_tensors['act_scales'] = activation_scales + epilog_fn = epilog_with_scales_and_act_scales + + plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, epilogue_trace_tensors) + return plan, epilog_tensors + +def cutlass_gemm_dq(x_q : torch.Tensor, + w_q : torch.Tensor, + dtype: torch.dtype, + static_scales: torch.Tensor, + activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + + dq = torch.empty((x_q.shape[0], w_q.shape[0]), + dtype=dtype, device="cuda") + + plan = cutlass.op.Gemm(element_A=x_q.dtype, element_B=w_q.dtype, + element_C=dq.dtype, element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32, + # TODO (varun) : lets not have kernel cc here please. + kernel_cc=80) + + plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, activation_scales) + + plan.run(x_q, w_q.t(), dq, dq, alpha=1, beta=0, + visitor_args=visitor_args, print_module=False) + + dq = dq.view(*x_q.shape[:-1], -1) + return dq diff --git a/vllm/model_executor/layers/quantization/smoothquant/formats.py b/vllm/model_executor/layers/quantization/smoothquant/formats.py index b8ddd642c888..4155ef64ffe3 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/formats.py +++ b/vllm/model_executor/layers/quantization/smoothquant/formats.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Type +from typing import Optional, Tuple import torch @@ -7,23 +7,6 @@ class SmoothQuantFormat(ABC): - @abstractmethod - def dequantize_op(self, - x_qs: List[torch.Tensor], - x_dqs: List[torch.Tensor], - dynamic_scales: Optional[torch.Tensor], - static_scales: torch.Tensor) -> None: - """Dequantize the activations. x_dq is updated in place. - - Args: - x_qs: List of N quantized activations. - x_dqs: List of N buffers to fill with dequantized values. - dynamic_scales: Optional dynamic scales for dequantization. - static_scales: Static scales for dequantization. N values. - """ - raise NotImplementedError - - @abstractmethod def quantize_op(self, x: torch.Tensor, @@ -41,55 +24,18 @@ def quantize_op(self, class SmoothQuantDynamicPerToken(SmoothQuantFormat): - def dequantize_op(self, - x_qs: List[torch.Tensor], - x_dqs: List[torch.Tensor], - dynamic_scales: Optional[torch.Tensor], - static_scales: torch.Tensor) -> None: - """Notes: - dynamic_scales: N scales for N tokens in the activation. - static_scales: K scales for K logical activations (equals just w_scale). - """ - if dynamic_scales is None: - raise ValueError - - # Dequantize each logical activation. - # TODO: test this for case when logical_widths > 1 (may need to reshape) - for x_dq, x_q, dynamic_scale, static_scale in zip( - x_dqs, x_qs, dynamic_scales, static_scales): - - # Dequantize (updates x_dq in place). - ops.dequant(x_dq, x_q, dynamic_scale, static_scale) - - def quantize_op(self, x: torch.Tensor, x_q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Notes: Returns quantized activaiton and dynamic activation scales. """ - activation_scales = torch.empty(x.numel() // x.shape[-1], dtype=x.dtype, device=x.device) + activation_scales = torch.empty((x.numel() // x.shape[-1], 1), dtype=x.dtype, device=x.device) ops.quant(x_q, x, activation_scales) return x_q, activation_scales class SmoothQuantStaticPerTensor(SmoothQuantFormat): - def dequantize_op(self, - x_qs: List[torch.Tensor], - x_dqs: List[torch.Tensor], - dynamic_scales: Optional[torch.Tensor], - static_scales: torch.Tensor) -> None: - """Notes: - dynamic_scales: None - static_scales: K scales for K logical activations (equals w_scale * a_scale). - """ - if dynamic_scales is not None: - raise ValueError - - # Dequantize each logical activation. - for xdq, xq, static_scale in zip(x_dqs, x_qs, static_scales): - ops.dequant(xdq, xq, static_scale) - def quantize_op(self, x: torch.Tensor, x_q: torch.Tensor) -> Tuple[torch.Tensor, None]: From 8f9264542ddec24ca8e6bbafc10ae1e6683390c6 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 24 Apr 2024 19:23:52 +0000 Subject: [PATCH 8/8] format.sh --- benchmarks/benchmark_throughput.py | 9 +- examples/offline_profile.py | 47 ++++---- examples/offline_quantized_inference.py | 23 ++-- examples/simple_test.py | 15 ++- tests/kernels/test_fusion.py | 1 + vllm/config.py | 4 +- vllm/model_executor/layers/linear.py | 84 ++++++------- .../model_executor/layers/quantization/awq.py | 14 +-- .../layers/quantization/marlin.py | 2 +- .../quantization/smoothquant/__init__.py | 7 +- .../layers/quantization/smoothquant/config.py | 111 +++++++++--------- .../quantization/smoothquant/cutlass_gemm.py | 82 +++++++------ .../quantization/smoothquant/formats.py | 21 ++-- .../layers/quantization/squeezellm.py | 14 +-- vllm/model_executor/model_loader.py | 3 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/baichuan.py | 11 +- vllm/model_executor/models/deepseek.py | 6 +- vllm/model_executor/models/gpt2.py | 14 ++- vllm/model_executor/models/llama.py | 10 +- vllm/model_executor/models/phi.py | 12 +- vllm/model_executor/models/qwen2.py | 25 ++-- vllm/model_executor/models/starcoder2.py | 13 +- vllm/worker/model_runner.py | 1 - 24 files changed, 266 insertions(+), 264 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index a708dbde4f50..d5fddba233e3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -259,10 +259,11 @@ def main(args: argparse.Namespace): "output length from the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], - default=None) + parser.add_argument( + '--quantization', + '-q', + choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], + default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, diff --git a/examples/offline_profile.py b/examples/offline_profile.py index ab728e01ce54..b62ffb209854 100644 --- a/examples/offline_profile.py +++ b/examples/offline_profile.py @@ -28,12 +28,14 @@ class ProfileContext: tensor_parallel_size: int allow_cuda_graphs: bool -def get_dtype(dtype:str): + +def get_dtype(dtype: str): if dtype == "torch.float": return torch.float else: return dtype + def run_profile(context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]): print("Run profile with:") @@ -45,17 +47,17 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], # Sparsity is in the future # Create LLM - llm = LLM( - model=context.model, - tokenizer=context.tokenizer if context.tokenizer is not None else context.model, - revision=context.model_revision, - enforce_eager=not context.allow_cuda_graphs, - tensor_parallel_size=context.tensor_parallel_size, - gpu_memory_utilization=0.9, - max_model_len=context.max_seq_len, - quantization=context.quantization, - dtype=get_dtype(context.dtype), - max_num_batched_tokens=context.max_num_batched_tokens) + llm = LLM(model=context.model, + tokenizer=context.tokenizer + if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + dtype=get_dtype(context.dtype), + max_num_batched_tokens=context.max_num_batched_tokens) batch_size = context.batch_size prompt_len = context.prompt_len @@ -168,11 +170,10 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], type=str, required=True, help='The name or path of a HuggingFace Transformers model.') - parser.add_argument( - "--tokenizer", - type=str, - default=None, - help="path to the tokenizer") + parser.add_argument("--tokenizer", + type=str, + default=None, + help="path to the tokenizer") parser.add_argument("--model-revision", type=str, default=None) parser.add_argument( @@ -196,12 +197,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None], default=None, help="The method used to quantize the model weights, " - "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"") - parser.add_argument( - "--dtype", - type=str, - default='auto', - help="model dtype") + "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"" + ) + parser.add_argument("--dtype", + type=str, + default='auto', + help="model dtype") parser.add_argument( "--max-seq-len", type=int, diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py index 8b3dbea72ae6..5935341f1f9b 100644 --- a/examples/offline_quantized_inference.py +++ b/examples/offline_quantized_inference.py @@ -1,8 +1,8 @@ from vllm import LLM, SamplingParams import torch -hf_path="nm-testing/Nous-Hermes-Llama2-13b-smoothquant" -model_path=hf_path +hf_path = "nm-testing/Nous-Hermes-Llama2-13b-smoothquant" +model_path = hf_path # Sample prompts. prompts = [ @@ -13,18 +13,17 @@ ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.0, top_k = 1,max_tokens=20) +sampling_params = SamplingParams(temperature=0.0, top_k=1, max_tokens=20) # Create an LLM. -llm = LLM( - model="nm-testing/Nous-Hermes-Llama2-13b-smoothquant", - gpu_memory_utilization=0.9, - max_model_len=2048, - quantization="smoothquant", - dtype=torch.float, - enforce_eager=True, - tensor_parallel_size=1, - max_num_batched_tokens=7000) +llm = LLM(model="nm-testing/Nous-Hermes-Llama2-13b-smoothquant", + gpu_memory_utilization=0.9, + max_model_len=2048, + quantization="smoothquant", + dtype=torch.float, + enforce_eager=True, + tensor_parallel_size=1, + max_num_batched_tokens=7000) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/examples/simple_test.py b/examples/simple_test.py index 81cb82e928e3..949bd8e75b8e 100644 --- a/examples/simple_test.py +++ b/examples/simple_test.py @@ -39,13 +39,16 @@ model_id = MODELS[args.model] print(f"Using model_id = {model_id}") -messages=[{ - "role": "user", - "content": "What is deep learning?" -}] +messages = [{"role": "user", "content": "What is deep learning?"}] -model = LLM(model_id, enforce_eager=True, max_model_len=1024, tensor_parallel_size=args.tensor_parallel_size, dtype="float16", trust_remote_code=True) -prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +model = LLM(model_id, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=args.tensor_parallel_size, + dtype="float16", + trust_remote_code=True) +prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) out = model.generate(prompt, SamplingParams(max_tokens=50)) print(f"\n-----prompt\n{prompt}") print(f"\n-----generation\n{out[0].outputs[0].text}") diff --git a/tests/kernels/test_fusion.py b/tests/kernels/test_fusion.py index 7cebc76b248b..a11515167a76 100644 --- a/tests/kernels/test_fusion.py +++ b/tests/kernels/test_fusion.py @@ -9,6 +9,7 @@ SEEDS = [0] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/vllm/config.py b/vllm/config.py index cd48fe4f1b9d..ea53b334139f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,7 +173,9 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "marlin", "squeezellm", "smoothquant"] + supported_quantization = [ + "awq", "gptq", "marlin", "squeezellm", "smoothquant" + ] rocm_not_supported_quantization = ["awq", "marlin", "smoothquant"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 905b01e69d7e..908aa4cc997e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -29,11 +29,8 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, - layer_name: str, - input_size_per_partition: int, - output_sizes_per_partition: List[int], - input_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" @@ -46,7 +43,7 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights to the input tensor.""" raise NotImplementedError - + def maybe_update_loaded_weight_name(self, name: str) -> str: """Update the name of a loaded weight to enable generic handling of cases where serialized state_dict does not match vllm model definition. @@ -65,11 +62,9 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, - layer_name: str, - input_size_per_partition: int, - output_sizes_per_partition: List[int], - input_size: int, output_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: weight = Parameter(torch.empty(sum(output_sizes_per_partition), input_size_per_partition, @@ -78,7 +73,6 @@ def create_weights(self, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) return {"weight": weight} - def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, @@ -190,10 +184,11 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() self.output_size_per_partition = divide(self.output_size, tp_size) self.output_sizes_per_partition = [self.output_size_per_partition] - # If QKV or MergedColumn, use output size of each partition. + # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): self.output_sizes_per_partition = [ - divide(output_size, tp_size) for output_size in self.output_sizes + divide(output_size, tp_size) + for output_size in self.output_sizes ] self.skip_bias_add = skip_bias_add @@ -288,15 +283,14 @@ def __init__( self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__( - layer_name=layer_name, - input_size=input_size, - output_size=sum(output_sizes), - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - linear_method=linear_method) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -373,8 +367,7 @@ def weight_loader(self, elif param_shard_splitter is not None: logical_widths = getattr(param, "logical_widths") param_data, loaded_weight = param_shard_splitter( - param_data, loaded_weight, loaded_shard_id, logical_widths - ) + param_data, loaded_weight, loaded_shard_id, logical_widths) else: ignore_warning = getattr(param, "ignore_warning", False) @@ -383,7 +376,7 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") - + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -445,20 +438,19 @@ def __init__( output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size self.output_sizes = [ - self.num_heads * self.head_size * tp_size, # q_proj - self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__( - layer_name=layer_name, - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - linear_method=linear_method) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -513,7 +505,7 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] - + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": @@ -538,15 +530,15 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - param_data = param_data.narrow( - output_dim, shard_offset, shard_size) + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + shard_size) # If a param_shard_splitter is defined by the LinearMethod, use it. elif param_shard_splitter is not None: logical_widths = getattr(param, "logical_widths") @@ -561,10 +553,8 @@ def weight_loader(self, "QKVParallelLinear, assume the weight is the same " "for all partitions.") - assert ( - param_data.shape == loaded_weight.shape or - (len(param_data.shape) == 0 and len(loaded_weight.shape) == 0) - ) + assert (param_data.shape == loaded_weight.shape or + (len(param_data.shape) == 0 and len(loaded_weight.shape) == 0)) param_data.copy_(loaded_weight) @@ -630,7 +620,7 @@ def __init__( input_size_per_partition=self.input_size_per_partition, output_sizes_per_partition=[self.output_size], input_size=self.input_size, - output_size=self.output_size, + output_size=self.output_size, params_dtype=self.params_dtype, ) for name, weight in self.linear_weights.items(): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 7cf94ae9f44e..cc089893bd5f 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,15 +79,11 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights( - self, - layer_name: str, - input_size_per_partition: int, - output_sizes_per_partition: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - del layer_name, input_size, output_size # Unused. + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size, output_size # Unused. output_size_per_partition = sum(output_sizes_per_partition) if input_size_per_partition % self.quant_config.group_size != 0: diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 59d217567919..6883668b0a7e 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -98,7 +98,7 @@ def create_weights( output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del layer_name, input_size, output_size # Unused. + del layer_name, input_size, output_size # Unused. output_size_per_partition = sum(output_sizes_per_partition) if params_dtype != torch.float16: diff --git a/vllm/model_executor/layers/quantization/smoothquant/__init__.py b/vllm/model_executor/layers/quantization/smoothquant/__init__.py index 2f62cee49d95..4074f7379c07 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/__init__.py +++ b/vllm/model_executor/layers/quantization/smoothquant/__init__.py @@ -1,11 +1,8 @@ from vllm.model_executor.layers.quantization.smoothquant.formats import ( - SmoothQuantFormat -) + SmoothQuantFormat) from vllm.model_executor.layers.quantization.smoothquant.config import ( - SmoothQuantConfig, - SmoothQuantLinearMethod -) + SmoothQuantConfig, SmoothQuantLinearMethod) __all__ = [ "SmoothQuantFormat", diff --git a/vllm/model_executor/layers/quantization/smoothquant/config.py b/vllm/model_executor/layers/quantization/smoothquant/config.py index 788481a51ae1..13a7f5ba3742 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/config.py +++ b/vllm/model_executor/layers/quantization/smoothquant/config.py @@ -3,9 +3,8 @@ import torch from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import ( - LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.smoothquant.formats import ( @@ -14,8 +13,7 @@ SmoothQuantStaticPerTensor, ) from vllm.model_executor.layers.quantization.smoothquant.cutlass_gemm import ( - cutlass_gemm_dq -) + cutlass_gemm_dq) LAYER_KEYS = ["qkv", "out", "fc1", "fc2"] FORMAT_REGISTRY = { @@ -23,36 +21,34 @@ "per-tensor": SmoothQuantStaticPerTensor, } + def get_sq_format_cls(format_key: str) -> Type[SmoothQuantFormat]: if format_key not in FORMAT_REGISTRY: raise ValueError(f"Invalid smoothquant format: {format_key}") return FORMAT_REGISTRY[format_key] + class SmoothQuantConfig(QuantizationConfig): """Config class for SmoothQuant. Reference: https://github.com/mit-han-lab/smoothquant """ - def __init__(self, - layer_format_map: Dict[str, str]) -> None: + + def __init__(self, layer_format_map: Dict[str, str]) -> None: self.layer_format_map = layer_format_map for key, format in self.layer_format_map.items(): if key not in LAYER_KEYS: raise ValueError( - f"Found key of {key} in {self.layer_format_map}, " - f"but key must be one of {LAYER_KEYS}" - ) + f"Found key of {key} in {self.layer_format_map}, " + f"but key must be one of {LAYER_KEYS}") if format not in FORMAT_REGISTRY: raise ValueError( f"Found format of {format} in {self.layer_format_map}, " - f"but format must be one of {FORMAT_REGISTRY}" - ) + f"but format must be one of {FORMAT_REGISTRY}") for key in LAYER_KEYS: if key not in self.layer_format_map: - raise ValueError( - f"Could not find {key} in {layer_format_map}" - ) + raise ValueError(f"Could not find {key} in {layer_format_map}") def __repr__(self) -> str: return (f"SmoothQuantConfig(layer_format_map={self.layer_format_map})") @@ -82,17 +78,18 @@ def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": if format in FORMAT_REGISTRY: layer_format_map[layer_key] = format return cls(layer_format_map) - + def get_linear_method(self) -> "SmoothQuantLinearMethod": return SmoothQuantLinearMethod(self) + class SmoothQuantLinearMethod(LinearMethodBase): + def __init__(self, sq_config: SmoothQuantConfig) -> None: self.sq_config = sq_config self.sq_type = None - def maybe_update_loaded_weight_name(self, - name: str) -> str: + def maybe_update_loaded_weight_name(self, name: str) -> str: """Convert serialized name k_dequant_scale to dequant_scale. This function is called by model_cls.load_weights() during the weight @@ -108,21 +105,20 @@ def shard_id_as_int(self, shard_id: Union[str, int]) -> int: return shard_id assert isinstance(shard_id, str) - qkv_idxs = { "q": 0, "k": 1, "v": 2 } + qkv_idxs = {"q": 0, "k": 1, "v": 2} assert shard_id in qkv_idxs return qkv_idxs[shard_id] - def scales_shard_splitter(self, - param: torch.Tensor, - loaded_weight: torch.Tensor, - shard_id: Union[str, int], - logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: shard_id = self.shard_id_as_int(shard_id) - offset = sum(logical_widths[:shard_id]) + offset = sum(logical_widths[:shard_id]) size = logical_widths[shard_id] # update loaded weight with copies for broadcast. loaded_weight = loaded_weight.repeat(size) - return param[offset : offset + size], loaded_weight + return param[offset:offset + size], loaded_weight def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: """ @@ -137,7 +133,7 @@ def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: layer_name: Name of the layer we are creating the LinearMethod for. Returns sq_linear_method: SmoothQuantLinearMethod with the right SQFormat. - """ + """ # Note: AutoSmoothQuant Serialization is not very good yet. # # It looks like the following (which does not map to layer names in the model): @@ -160,24 +156,21 @@ def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: # return get_sq_format_cls(sq_format)() HACKED_REMAP_FOR_LLAMA = { - "qkv": "qkv", - "o_proj": "out", - "gate_up": - "fc1", "down": "fc2", + "qkv": "qkv", + "o_proj": "out", + "gate_up": "fc1", + "down": "fc2", } for match_key, lookup_key in HACKED_REMAP_FOR_LLAMA.items(): if match_key in layer_name: sq_format = self.sq_config.layer_format_map[lookup_key] return get_sq_format_cls(sq_format)() - + raise ValueError - - def create_weights(self, - layer_name: str, - input_size_per_partition: int, - output_sizes_per_partition: int, - input_size: int, + + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: del input_size, output_size @@ -187,8 +180,10 @@ def create_weights(self, torch.empty( sum(output_sizes_per_partition), input_size_per_partition, - device="cuda", dtype=torch.int8, - ), requires_grad=False, + device="cuda", + dtype=torch.int8, + ), + requires_grad=False, ) set_weight_attrs(weight, { "input_dim": 1, @@ -197,20 +192,22 @@ def create_weights(self, if len(output_sizes_per_partition) == 1: # Single static scale for the entire tensor. - dequant_scale = Parameter( - torch.empty((1),device='cuda', dtype=params_dtype), - requires_grad=False - ) + dequant_scale = Parameter(torch.empty((1), + device='cuda', + dtype=params_dtype), + requires_grad=False) else: # Static scale for each logical weight (e.g. 3 for QKV). - dequant_scale = Parameter( - torch.empty((sum(output_sizes_per_partition)), - device='cuda', dtype=params_dtype), - requires_grad=False - ) - set_weight_attrs(dequant_scale, - {"shard_splitter": self.scales_shard_splitter, - "logical_widths" : output_sizes_per_partition}) + dequant_scale = Parameter(torch.empty( + (sum(output_sizes_per_partition)), + device='cuda', + dtype=params_dtype), + requires_grad=False) + set_weight_attrs( + dequant_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_sizes_per_partition + }) return { "weight": weight, @@ -219,9 +216,9 @@ def create_weights(self, "sq_format": self.get_layer_format(layer_name) } - def _quantize(self, - x: torch.Tensor, - sq_format: SmoothQuantFormat) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def _quantize( + self, x: torch.Tensor, sq_format: SmoothQuantFormat + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Quantize activations. Args: @@ -257,5 +254,5 @@ def apply_weights(self, x_q, activation_scales = self._quantize(x, sq_format) # GEMM and DQ - return cutlass_gemm_dq(x_q, weight_q, x.dtype, static_scales, activation_scales) - + return cutlass_gemm_dq(x_q, weight_q, x.dtype, static_scales, + activation_scales) diff --git a/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py index 05ae38c3343e..ab0bb0808fb2 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py +++ b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py @@ -1,4 +1,3 @@ - import cutlass from cutlass import Tensor as FakeTensor import cutlass.epilogue @@ -11,7 +10,7 @@ logger = init_logger("cutlass_gemm") def setup_dequant_epilogue(plan : cutlass.op.Gemm, - dq: torch.Tensor, + dq: torch.Tensor, static_scales: Optional[torch.Tensor], activation_scales: Optional[torch.Tensor]) \ -> Tuple[cutlass.op.Gemm, Dict]: @@ -28,15 +27,16 @@ def epilog_with_scales(accum, scales): D = accum * scales return D - epilog_tensors = { - 'scales' : static_scales, - 'D' : dq - } + epilog_tensors = {'scales': static_scales, 'D': dq} epilogue_trace_tensors = { - "accum": FakeTensor(element=torch.int32, shape=dq.shape, - layout_tag=cutlass.LayoutType.RowMajor), - 'scales' : static_scales, - 'D' : dq, + "accum": + FakeTensor(element=torch.int32, + shape=dq.shape, + layout_tag=cutlass.LayoutType.RowMajor), + 'scales': + static_scales, + 'D': + dq, } epilog_fn = epilog_with_scales @@ -45,31 +45,43 @@ def epilog_with_scales(accum, scales): epilogue_trace_tensors['act_scales'] = activation_scales epilog_fn = epilog_with_scales_and_act_scales - plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, epilogue_trace_tensors) - return plan, epilog_tensors - -def cutlass_gemm_dq(x_q : torch.Tensor, - w_q : torch.Tensor, - dtype: torch.dtype, - static_scales: torch.Tensor, - activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: - - dq = torch.empty((x_q.shape[0], w_q.shape[0]), - dtype=dtype, device="cuda") - - plan = cutlass.op.Gemm(element_A=x_q.dtype, element_B=w_q.dtype, - element_C=dq.dtype, element_D=dq.dtype, - layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.ColumnMajor, - layout_C=cutlass.LayoutType.RowMajor, - element_accumulator=torch.int32, - # TODO (varun) : lets not have kernel cc here please. - kernel_cc=80) - - plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, activation_scales) - - plan.run(x_q, w_q.t(), dq, dq, alpha=1, beta=0, - visitor_args=visitor_args, print_module=False) + plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, + epilogue_trace_tensors) + return plan, epilog_tensors + + +def cutlass_gemm_dq( + x_q: torch.Tensor, + w_q: torch.Tensor, + dtype: torch.dtype, + static_scales: torch.Tensor, + activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + + dq = torch.empty((x_q.shape[0], w_q.shape[0]), dtype=dtype, device="cuda") + + plan = cutlass.op.Gemm( + element_A=x_q.dtype, + element_B=w_q.dtype, + element_C=dq.dtype, + element_D=dq.dtype, + layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.ColumnMajor, + layout_C=cutlass.LayoutType.RowMajor, + element_accumulator=torch.int32, + # TODO (varun) : lets not have kernel cc here please. + kernel_cc=80) + + plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales, + activation_scales) + + plan.run(x_q, + w_q.t(), + dq, + dq, + alpha=1, + beta=0, + visitor_args=visitor_args, + print_module=False) dq = dq.view(*x_q.shape[:-1], -1) return dq diff --git a/vllm/model_executor/layers/quantization/smoothquant/formats.py b/vllm/model_executor/layers/quantization/smoothquant/formats.py index 4155ef64ffe3..4f82b0d67a99 100644 --- a/vllm/model_executor/layers/quantization/smoothquant/formats.py +++ b/vllm/model_executor/layers/quantization/smoothquant/formats.py @@ -7,10 +7,11 @@ class SmoothQuantFormat(ABC): + @abstractmethod - def quantize_op(self, - x: torch.Tensor, - x_q: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def quantize_op( + self, x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Quantize the input and (optionally compute dequant scales). Args: @@ -24,20 +25,22 @@ def quantize_op(self, class SmoothQuantDynamicPerToken(SmoothQuantFormat): - def quantize_op(self, - x: torch.Tensor, + + def quantize_op(self, x: torch.Tensor, x_q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Notes: Returns quantized activaiton and dynamic activation scales. """ - activation_scales = torch.empty((x.numel() // x.shape[-1], 1), dtype=x.dtype, device=x.device) + activation_scales = torch.empty((x.numel() // x.shape[-1], 1), + dtype=x.dtype, + device=x.device) ops.quant(x_q, x, activation_scales) return x_q, activation_scales - + class SmoothQuantStaticPerTensor(SmoothQuantFormat): - def quantize_op(self, - x: torch.Tensor, + + def quantize_op(self, x: torch.Tensor, x_q: torch.Tensor) -> Tuple[torch.Tensor, None]: """Notes: Returns quantized activaiton and no dynamic scales. diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 893e6781089d..6b800fc12509 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,15 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights( - self, - layer_name: str, - input_size_per_partition: int, - output_sizes_per_partition: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - del layer_name, input_size # Unused. + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size # Unused. output_size_per_partition = sum(output_sizes_per_partition) if input_size_per_partition % self.quant_config.pack_factor != 0: diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index de7910c4860b..152d1ea95869 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -88,7 +88,8 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, model = model_class(model_config.hf_config, vision_language_config, linear_method) - if not hasattr(model_class, "supported_lora_modules") and lora_config: + if not hasattr(model_class, + "supported_lora_modules") and lora_config: raise ValueError( f"Model {model_class.__name__} does not support LoRA, " "but LoRA is enabled. Support for this model may " diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f488711d1158..426eb8846437 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -123,6 +123,7 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): def get_supported_smoothquant_archs() -> List[str]: return list(_SUPPORTED_SMOOTHQUANT_MODELS.keys()) + __all__ = [ "ModelRegistry", ] diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index c437f28d1fcf..35aa1f4fee29 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -266,11 +266,10 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer( - parent_name=f"model.layers.{idx}", - config=config, - position_embedding=position_embedding, - linear_method=linear_method) + BaiChuanDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + position_embedding=position_embedding, + linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -371,7 +370,7 @@ def load_weights(self, # Update name of the loaded_weight if needed by the LinearMethod. if self.linear_method: name = self.linear_method.maybe_update_loaded_weight_name(name) - + if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 973f70803b87..c3c8d6e23edb 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -288,7 +288,8 @@ def __init__( and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekMoE(parent_name=f"{parent_name}.mlp", - config=config, linear_method=linear_method) + config=config, + linear_method=linear_method) else: self.mlp = DeepseekMLP( parent_name=f"{parent_name}.mlp", @@ -348,7 +349,8 @@ def __init__( ) self.layers = nn.ModuleList([ DeepseekDecoderLayer(parent_name=f"model.layers.{idx}", - config=config, layer_idx=idx, + config=config, + layer_idx=idx, linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 48963eab4868..c18b3e2f719b 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -140,11 +140,13 @@ def __init__( self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention(parent_name=f"{parent_name}.attn", - config=config, linear_method=linear_method) + config=config, + linear_method=linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(parent_name=f"{parent_name}.mlp", intermediate_size=inner_dim, - config=config, linear_method=linear_method) + config=config, + linear_method=linear_method) def forward( self, @@ -186,9 +188,9 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block( - parent_name=f"transformer.h.{idx}", - config=config, linear_method=linear_method) + GPT2Block(parent_name=f"transformer.h.{idx}", + config=config, + linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -263,7 +265,7 @@ def load_weights(self, # Update name of the loaded_weight if needed by the LinearMethod. if self.linear_method: name = self.linear_method.maybe_update_loaded_weight_name(name) - + if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e396cd44a51f..9569af31348a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -249,12 +249,10 @@ def forward( class LlamaModel(nn.Module): - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None - ) -> None: + def __init__(self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index e3379b4880be..05ce50551934 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -168,12 +168,12 @@ def __init__(self, super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention( - parent_name=f"{parent_name}.self_attn", - config=config, linear_method=linear_method) - self.mlp = PhiMLP( - parent_name=f"{parent_name}.mlp", - config=config, linear_method=linear_method) + self.self_attn = PhiAttention(parent_name=f"{parent_name}.self_attn", + config=config, + linear_method=linear_method) + self.mlp = PhiMLP(parent_name=f"{parent_name}.mlp", + config=config, + linear_method=linear_method) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index d8ca9b4c6037..c0af1221ae67 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -86,18 +86,16 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__( - self, - parent_name: str, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - use_sliding_window: bool = False, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None - ) -> None: + def __init__(self, + parent_name: str, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + use_sliding_window: bool = False, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -249,7 +247,8 @@ def __init__( ) self.layers = nn.ModuleList([ Qwen2DecoderLayer(parent_name=f"model.layers.{idx}", - config=config, layer_idx=idx, + config=config, + layer_idx=idx, linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index e3730777014b..bd2b8ca9ad69 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -162,11 +162,13 @@ def __init__(self, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(parent_name=f"{parent_name}.self_attn", - config=config, - linear_method=linear_method) + self.self_attn = Starcoder2Attention( + parent_name=f"{parent_name}.self_attn", + config=config, + linear_method=linear_method) self.mlp = Starcoder2MLP(parent_name=f"{parent_name}.mlp", - config=config, linear_method=linear_method) + config=config, + linear_method=linear_method) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -214,7 +216,8 @@ def __init__(self, config.hidden_size) self.layers = nn.ModuleList([ Starcoder2DecoderLayer(parent_name=f"model.layers.{idx}", - config=config, linear_method=linear_method) + config=config, + linear_method=linear_method) for idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f984a8bda4d7..d967bb7b3367 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -985,7 +985,6 @@ def forward( # Return the output tensor. return self.output_buffers["hidden_states"] - def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs)