diff --git a/CMakeLists.txt b/CMakeLists.txt index 184515118128..ebea46103699 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/smoothquant/fused_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d6bf18c82e46..d5fddba233e3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -259,10 +259,11 @@ def main(args: argparse.Namespace): "output length from the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=['awq', 'gptq', 'squeezellm', None], - default=None) + parser.add_argument( + '--quantization', + '-q', + choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], + default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 64f86381d9db..1efd714a26e1 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -5,3 +5,4 @@ #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" #include "dtype_fp8.cuh" +#include "dtype_int8.cuh" diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb..51407f35e2d0 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) { return c; } +// for compiling, the above function seems to be useless +inline __device__ Float4_ add(Float4_ a, Float4_ b) { + Float4_ c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + // 
Vector multiplication. template<> inline __device__ float mul(float a, float b) { diff --git a/csrc/attention/dtype_int8.cuh b/csrc/attention/dtype_int8.cuh new file mode 100644 index 000000000000..91e6ec40b038 --- /dev/null +++ b/csrc/attention/dtype_int8.cuh @@ -0,0 +1,49 @@ +#pragma once + +#include +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +namespace vllm { +// define int8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = int8_t; +}; + +template<> +struct Vec { + using Type = int16_t; +}; + +template<> +struct Vec { + using Type = int32_t; +}; + +template<> +struct Vec { + using Type = int64_t; +}; + +template<> +struct FloatVec { + using Type = float; +}; + +template<> +struct FloatVec { + using Type = float2; +}; + +template<> +struct FloatVec { + using Type = Float4_; +}; + +template<> +struct FloatVec { + using Type = Float8_; +}; +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4b..20ea511a8529 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -9,12 +9,20 @@ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index 41ecc1e89371..76ba6188c7a8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -131,6 +131,27 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); +void dequant( + torch::Tensor& out, + torch::Tensor& input, + float scale); + +void dequant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale, + float weight_dequant_scale); + +void quant( + torch::Tensor& out, + torch::Tensor& input, + float scale); + +void quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc16211..983455878023 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -49,6 +49,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "fused_add_rms_norm", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); + ops.def( + "quant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + float>(&quant), + "Quant."); + ops.def( + "quant", + py::overload_cast< + torch::Tensor&, + torch::Tensor&, + torch::Tensor&>( + &quant), + "Per-token quant."); // Rotary embedding ops.def( diff --git a/csrc/quantization/smoothquant/fused_kernels.cu b/csrc/quantization/smoothquant/fused_kernels.cu new file mode 100644 index 000000000000..1d23a5a0653c --- /dev/null +++ b/csrc/quantization/smoothquant/fused_kernels.cu @@ -0,0 +1,91 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" +#include "quant_utils.cuh" + +namespace vllm { + +template +__global__ void quant_kernel( + const scalar_t* __restrict__ input, + int8_t* __restrict__ out, + scale_type scale, + const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + if constexpr (use_per_token_quant) { + float amax_val = 0.0f; + const float zero = 0.0f; + + for 
(int i = tid; i < hidden_size; i += blockDim.x) { + float val = (float)input[token_idx * hidden_size + i]; + val = val > zero ? val : -val; + if (val > amax_val) + amax_val = val; + } + + __shared__ float s_amax; + const float block_amax_val = blockReduceMax(amax_val); + if (tid == 0) { + s_amax = block_amax_val; + scale[token_idx] = block_amax_val / 127.0f; + } + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale); + } + } else { + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } + } +} +} // namespace vllm + +void quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale, + hidden_size); + }); +} + +void quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale) { // [num_tokens] + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), + out.data_ptr(), + scale.data_ptr(), + hidden_size); + }); +} \ No 
newline at end of file diff --git a/csrc/quantization/smoothquant/quant_utils.cuh b/csrc/quantization/smoothquant/quant_utils.cuh new file mode 100644 index 000000000000..dcbf3da3fcb7 --- /dev/null +++ b/csrc/quantization/smoothquant/quant_utils.cuh @@ -0,0 +1,243 @@ +// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +#pragma once + +#include +#include +#include +#include +#include "../../attention/attention_dtypes.h" +#include "../../attention/dtype_float32.cuh" +using namespace vllm; + +// this function is for function matching, delete it after writing customized dispatch functions +inline __device__ int8_t quant(double a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ int8_t quant(float a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ short quant(float2 a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + return int16; +} + +inline __device__ int32_t quant(float4 a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale))); + int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale))); + return int32; +} + +// float16 to int8 +inline __device__ int8_t quant(uint16_t a, const float scale, const float zp) +{ + int8_t int8; + float b = half_to_float(a); + int8 = round(max(-128.f, min(127.f, (b - zp) / scale))); + 
return int8; +} + +// float16x2 to int8x2 +inline __device__ int16_t quant(uint32_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = half2_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale))); + return int16; +} + +// float16x4 to int8x4 +inline __device__ int32_t quant(uint2 a, const float scale, const float zp) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + return int32; +} + +// float16x8 to int8x8 +inline __device__ int64_t quant(uint4 a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + int16[2] = quant(a.z, scale, zp); + int16[3] = quant(a.w, scale, zp); + return int64; +} + +// int8 to float32, then `vec_conversion` to target format +inline __device__ float dequant(int8_t a, const float scale, const float zp) +{ + float b = a * scale + zp; + return b; +} + +// int8x2 to float32x2 +inline __device__ float2 dequant(int16_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + int16 = a; + + float2 b; + b.x = int8[0] * scale + zp; + b.y = int8[1] * scale + zp; + return b; +} + +// int8x4 to float32x4 +inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int32 = a; + + Float4_ b; + b.x.x = (int8[0] * scale) + zp; + b.x.y = (int8[1] * scale) + zp; + b.y.x = (int8[2] * scale) + zp; + b.y.y = (int8[3] * scale) + zp; + return b; +} + +inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + int64 = a; + + Float8_ b; + b.x = dequant(int16[0], scale, zp); + b.y = dequant(int16[1], scale, zp); + b.z = dequant(int16[2], scale, zp); + 
b.w = dequant(int16[3], scale, zp); + return b; +} + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { + return __float22bfloat162_rn(a); +} + +template<> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { + bf16_4_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + return b; +} + +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { + bf16_8_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + b.z = vec_conversion<__nv_bfloat162, float2>(a.z); + b.w = vec_conversion<__nv_bfloat162, float2>(a.w); + return b; +} + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} \ No newline at end of file diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bb5171f854d5..4065efc8efa2 100644 --- a/csrc/reduction_utils.cuh +++ 
b/csrc/reduction_utils.cuh @@ -62,4 +62,29 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + return val; +} +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} } // namespace vllm diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..6b548d5e8921 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_profile.py b/examples/offline_profile.py new file mode 100644 index 000000000000..b62ffb209854 --- /dev/null +++ b/examples/offline_profile.py @@ -0,0 +1,251 @@ +import argparse +import torch +import sys +import json +import inspect + +from dataclasses import dataclass, asdict +from typing import Optional +from vllm import LLM, SamplingParams +from vllm.profiler import nm_profile + +BATCH_SIZE_DEFAULT = 1 +PROMPT_LEN_DEFAULT = 256 +MAX_SEQ_LEN_DEFAULT = 1024 + + +@dataclass +class ProfileContext: + model: str + tokenizer: str + model_revision: str + quantization: str + max_seq_len: int + max_num_batched_tokens: int + prompt_len: int + batch_size: int + dtype: str + tensor_parallel_size: int + allow_cuda_graphs: bool + + +def get_dtype(dtype: str): + if dtype == "torch.float": + return torch.float + else: + return dtype + + +def run_profile(context: ProfileContext, csv_output: Optional[str], + json_output: Optional[str]): + print("Run profile with:") + for key, value in asdict(context).items(): + print(f" {key} = {value}") + + # Create sampling params + sampling_params = SamplingParams(temperature=0.0, top_k=1, max_tokens=8) + + # Sparsity is in the future + # Create LLM + llm = LLM(model=context.model, + tokenizer=context.tokenizer + if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_seq_len, + quantization=context.quantization, + dtype=get_dtype(context.dtype), + max_num_batched_tokens=context.max_num_batched_tokens) + + batch_size = context.batch_size + prompt_len = context.prompt_len + + scheduler_config = llm.llm_engine.scheduler_config + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + if batch_size * prompt_len > 
max_num_batched_tokens: + print(f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max_num_batched_tokens") + sys.exit(-1) + if batch_size >= max_num_seqs: + print( + f"ERROR: chosen batch_size ({batch_size}) is larger than " + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " + f"single profile step, please choose a smaller batch size") + sys.exit(-1) + + for i in range(batch_size): + llm.llm_engine.add_request( + request_id=f"seq{i}", + prompt=None, + prompt_token_ids=torch.randint( + 128, # 128 to skip over special tokens + llm.llm_engine.model_config.get_vocab_size() // 2, + size=(prompt_len, )).tolist(), + sampling_params=sampling_params) + + with nm_profile() as prefill_prof: + llm.llm_engine.step() # First step is prefill + + with nm_profile() as decode_prof: + llm.llm_engine.step() + + prefill_results = prefill_prof.results + decode_results = decode_prof.results + + print("=" * 80) + print(f"= Prefill Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_model_table() + print() + print("=" * 80) + print(f"= Decode Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results.print_model_table() + print() + print("=" * 80) + print(f"= Prefill Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + prefill_results.print_summary_table() + print() + print("=" * 80) + print(f"= Decode Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results.print_summary_table() + + if csv_output: + csv_filename_base = csv_output.rstrip(".csv") + 
prefill_results.export_model_stats_table_csv( + csv_filename_base + "_prefill_model_table.csv") + prefill_results.export_summary_stats_table_csv( + csv_filename_base + "_prefill_summary_table.csv") + decode_results.export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results.export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") + + if json_output: + cuda_devices = [ + torch.cuda.get_device_properties(dev_idx) + for dev_idx in range(torch.cuda.device_count()) + ] + + json_dict = { + "context": { + "python_version": f"{sys.version}", + "torch_version": f"{torch.__version__}", + "torch_cuda_version": f"{torch.version.cuda}", + "cuda_devices": f"{cuda_devices}", + **asdict(context) + }, + "prefill": prefill_results.convert_stats_to_dict(), + "decode": decode_results.convert_stats_to_dict() + } + + with open(json_output.rstrip(".json") + ".json", "w+") as f: + json.dump(json_dict, f, indent=2) + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model", + type=str, + required=True, + help='The name or path of a HuggingFace Transformers model.') + parser.add_argument("--tokenizer", + type=str, + default=None, + help="path to the tokenizer") + + parser.add_argument("--model-revision", type=str, default=None) + parser.add_argument( + "--csv", + type=str, + default=None, + help="Export the results as multiple csv file. This should be the root " + "filename, will create _prefill_model_table.csv, " + "_prefill_summary_table.csv, " + "_decode_model_table.csv, and " + "_decode_summary_table.csv") + parser.add_argument( + "--json", + type=str, + default=None, + help="Export the results as a json file. 
This should be the filename") + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None], + default=None, + help="The method used to quantize the model weights, " + "options are \"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"" + ) + parser.add_argument("--dtype", + type=str, + default='auto', + help="model dtype") + parser.add_argument( + "--max-seq-len", + type=int, + default=MAX_SEQ_LEN_DEFAULT, + help=f"Maximum length of a sequence (including prompt and output), " + f"default={MAX_SEQ_LEN_DEFAULT}") + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=None, + help="Maximum number of tokens to be processed in a single iteration. " + " Should be greater than batch-size * prompt-len so the prefill can " + " run in a single iteration.") + parser.add_argument( + "--prompt-len", + type=int, + default=PROMPT_LEN_DEFAULT, + help=f"Length of the random prompt to use when profiling, all batched " + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument("--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}") + parser.add_argument("--tensor-parallel-size", + "-tp", + type=int, + default=1, + help="Number of GPUs to use i.e. 
tensor parallelism, " + "default=1") + parser.add_argument( + "--allow-cuda-graphs", + action='store_true', + help="Enables cuda graphs to be used, well remove a lot of the module " + "level info in the profiler results since almost everything runs in " + "the graph where we do not have access to an informative stack trace") + + args = parser.parse_args() + + context = ProfileContext( + **{ + k: v + for k, v in vars(args).items() + if k in inspect.signature(ProfileContext).parameters + }) + run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/examples/offline_quantized_inference.py b/examples/offline_quantized_inference.py new file mode 100644 index 000000000000..5935341f1f9b --- /dev/null +++ b/examples/offline_quantized_inference.py @@ -0,0 +1,35 @@ +from vllm import LLM, SamplingParams +import torch + +hf_path = "nm-testing/Nous-Hermes-Llama2-13b-smoothquant" +model_path = hf_path + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.0, top_k=1, max_tokens=20) + +# Create an LLM. +llm = LLM(model="nm-testing/Nous-Hermes-Llama2-13b-smoothquant", + gpu_memory_utilization=0.9, + max_model_len=2048, + quantization="smoothquant", + dtype=torch.float, + enforce_eager=True, + tensor_parallel_size=1, + max_num_batched_tokens=7000) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. 
+for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/simple_test.py b/examples/simple_test.py new file mode 100644 index 000000000000..949bd8e75b8e --- /dev/null +++ b/examples/simple_test.py @@ -0,0 +1,54 @@ +import argparse +from vllm import LLM, SamplingParams + +MODELS = { + "tinyllama-fp16": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "tinyllama-marlin": "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", + "tinyllama-gptq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + "tinyllama-awq": "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", + "gemma-fp16": "google/gemma-1.1-2b-it", + "gemma-awq": "TechxGenus/gemma-1.1-2b-it-AWQ", + "gemma-gptq": "TechxGenus/gemma-1.1-2b-it-GPTQ", + "phi-2-fp16": "abacaj/phi-2-super", + "phi-2-marlin": "neuralmagic/phi-2-super-marlin", + "deepseek-fp16": "deepseek-ai/deepseek-coder-1.3b-instruct", + "deepseek-gptq": "TheBloke/deepseek-coder-1.3b-instruct-GPTQ", + "deepseek-awq": "TheBloke/deepseek-coder-1.3b-instruct-AWQ", + "deepseek-moe-fp16": "deepseek-ai/deepseek-moe-16b-chat", + "baichuan-fp16": "baichuan-inc/Baichuan2-7B-Chat", + "baichuan-gptq": "csdc-atl/Baichuan2-7B-Chat-GPTQ-Int4", + "qwen-fp16": "Qwen/Qwen1.5-1.8B", + "qwen-gptq": "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4", + "qwen-awq": "Qwen/Qwen1.5-1.8B-Chat-AWQ", + "gpt2-fp16": "openai-community/gpt2", + "gpt2-gptq": "etyacke/GPT2-GPTQ-int4", + "starcoder2-fp16": "bigcode/starcoder2-3b", + "starcoder2-gptq": "TechxGenus/starcoder2-3b-GPTQ", + "starcoder2-awq": "TechxGenus/starcoder2-3b-AWQ", +} + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str) +parser.add_argument("--tensor-parallel-size", type=int, default=1) +args = parser.parse_args() + +if args.model not in MODELS: + print(f"Got model id of {args.model}; Must be in {list(MODELS.keys())}") + raise ValueError +else: + model_id = MODELS[args.model] + print(f"Using model_id = {model_id}") + 
+messages = [{"role": "user", "content": "What is deep learning?"}] + +model = LLM(model_id, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=args.tensor_parallel_size, + dtype="float16", + trust_remote_code=True) +prompt = model.llm_engine.tokenizer.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) +out = model.generate(prompt, SamplingParams(max_tokens=50)) +print(f"\n-----prompt\n{prompt}") +print(f"\n-----generation\n{out[0].outputs[0].text}") diff --git a/experiments.sh b/experiments.sh new file mode 100755 index 000000000000..a45229d31499 --- /dev/null +++ b/experiments.sh @@ -0,0 +1,142 @@ +#! /bin/bash + +set -e +set -u +#set -x + +# global args +model_path=NousResearch/Nous-Hermes-Llama2-13b +quant_model_path=nm-testing/Nous-Hermes-Llama2-13b-smoothquant + +# model generation args +enforce_eager=True +max_seq_len=2048 +max_num_batched_tokens=7000 +tensor_parallel_size=1 + +# experiment args +prefill_prompt_len=512 +decode_batch_sizes=(1 2 8 16 32 64 128) + +run_prefill() { + model=$1 + desc=$2 + dtype=$3 + output_directory=$4 + + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/prefill_${prefill_prompt_len}_llama13_${desc}-${now} + + echo "Running prefill ${prefill_prompt_len} model ${model} desc ${desc} dtype ${dtype} store at ${out_base}" + python3 examples/offline_profile.py --model $model \ + --batch-size 1 \ + --prompt-len $prefill_prompt_len \ + --quantization smoothquant \ + --max-seq-len $max_seq_len \ + --dtype $dtype \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 +} + +run_decode() { + model=$1 + desc=$2 + dtype=$3 + output_directory=$4 + + for bs in "${decode_batch_sizes[@]}" + do + now=`date +"%Y-%m-%d-%I-%M-%S"` + out_base=${output_directory}/decode_bs_${bs}_llama13_${desc}-${now} + + echo "Running decode bs ${bs} model ${model} desc ${desc} dtype 
${dtype} store at ${out_base}" + python3 examples/offline_profile.py --model $model \ + --batch-size $bs \ + --prompt-len 1 \ + --quantization smoothquant \ + --max-seq-len $max_seq_len \ + --dtype $dtype + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size $tensor_parallel_size \ + --json $out_base \ + --csv $out_base > ${out_base}_stdout.txt 2>&1 + done +} + +## Arg parser and invocation + +usage() { + echo`` + echo "Run profiler" + echo + echo "usage: ${0} " + echo + echo " -t - pass in quant or base" + echo " -d - description" + echo " -m - model data type" + echo " -n - pass in num_benchmark_iterations" + echo " -o - out directory" + echo +} + +exp_type="" # should be either quant or base +num_benchmark_iterations=1 +output_directory="./" +desc="" +dtype="" + +while getopts ':t:n:o:d:m:h:' OPT; do + case "${OPT}" in + t) + exp_type="${OPTARG}" + ;; + n) + num_benchmark_iterations=${OPTARG} + ;; + o) + output_directory="${OPTARG}" + ;; + d) + desc="${OPTARG}" + ;; + m) + dtype="${OPTARG}" + ;; + h) + usage + exit 1 + ;; + esac +done + +if [ "$exp_type" != "quant" -a "$exp_type" != "base" ]; +then + echo "Invalid arg $exp_type" + usage + exit 1 +fi + +if [[ "${output_directory}" == "" || "${desc}" == "" || "${dtype}" == "" ]]; +then + echo "Either output_directory or desc is not set" + usage + exit 1 +fi + + +for i in $(seq 1 $num_benchmark_iterations); +do + echo "Running benchmark iteration ${i} ..." 
+ if [[ "${exp_type}" == "quant" ]]; + then + run_prefill $quant_model_path "${exp_type}-${desc}" $dtype $output_directory + #run_decode $quant_model_path "${exp_type}-${desc}" $dtype $output_directory + fi + if [[ "${exp_type}" == "base" ]]; + then + run_prefill $model_path "${exp_type}-${desc}" $dtype $output_directory + #run_decode $model_path "${exp_type}-${desc}" $dtype $output_directory + fi +done diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py new file mode 100644 index 000000000000..728a5e386089 --- /dev/null +++ b/neuralmagic/tools/profiler/print_table.py @@ -0,0 +1,77 @@ +import argparse +import json + +from vllm.profiler.nm_profile import SummaryStatsEntry, ModelStatsEntry +from vllm.profiler.utils import indent_string, TablePrinter +from typing import Dict + + +def flatten_entries(entry_cls, profile_dict: Dict): + entries_and_depth = [] + + def get_entries(node, curr_depth=0): + entries_and_depth.append((entry_cls(**node["entry"]), curr_depth)) + + for child in node["children"]: + get_entries( + child, + curr_depth=curr_depth + 1, + ) + + for root in profile_dict: + get_entries(root) + + return entries_and_depth + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--json-trace", + type=str, + required=True, + help="json trace file output by " + "examples/offline_profile.py") + parser.add_argument("--phase", + type=str, + choices=["prefill", "decode"], + required=True, + help="The phase to print the table for.") + parser.add_argument("--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the " + "layerwise model table") + + args = parser.parse_args() + + with open(args.json_trace, "r") as f: + profile_data = json.load(f) + + if args.table == "summary": + entries_and_depths = flatten_entries( + SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + 
cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # ident entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py new file mode 100644 index 000000000000..9d0f7f328543 --- /dev/null +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -0,0 +1,209 @@ +import argparse +import json +import pandas as pd +import matplotlib.pyplot as plt + + +def trim_string_back(string: str, width: int): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." 
+ return string + + +def abbreviate_known_names(name: str): + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument( + "--output", + type=str, + required=False, + help="Output figure file, should be a image file such as pdf, " + "jpeg, png, etc., defaults to .pdf") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top_k", + type=int, + default=9, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--ignore_sampler", + action='store_true', + help="Ignore everything under the \"Sampler\" module") + + args = parser.parse_args() + + ignore_sampler = args.ignore_sampler + make_names_unique = False + top_k = args.top_k + + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + if ignore_sampler: + print("WARNING: ignoring Sampler time so the pct_cuda_time will not " + "add up to 100%") + + json_trace = args.json_trace + output = args.output if args.output else json_trace.strip(".json") + ".pdf" + + with open(json_trace, "r") as f: + profile_data = json.load(f) + + prefill_entries_and_traces = [] + decode_entries_and_traces = [] + + def largest_dist_from_leaf(node, depth=0): + if len(node["children"]) == 0: + return 
depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + def get_entries_at_depth(depth, + entries_and_traces, + node, + curr_depth=0, + trace=()): + if ignore_sampler and node["entry"]["name"] == "Sampler": + return + + if (depth >= 0 and depth == curr_depth) or ( + depth < 0 + and largest_dist_from_leaf(node) == (abs(depth) - 1)): + entries_and_traces.append((node["entry"], trace)) + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + for root in profile_data["prefill"]["summary_stats"]: + get_entries_at_depth(depth, prefill_entries_and_traces, root) + for root in profile_data["decode"]["summary_stats"]: + get_entries_at_depth(depth, decode_entries_and_traces, root) + + def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [ + (entry, trace) for entry, trace in entries_and_traces + if entry["name"] == name + ] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave the names as they + # are; they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + if make_names_unique: + attempt_to_make_names_unique(prefill_entries_and_traces) + 
attempt_to_make_names_unique(decode_entries_and_traces) + + def keep_only_top_entries(df, metric, top_k=9): + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + + prefill_df = pd.DataFrame( + [entry for entry, _ in prefill_entries_and_traces]) + prefill_df["phase"] = "prefill" + decode_df = pd.DataFrame([entry for entry, _ in decode_entries_and_traces]) + decode_df["phase"] = "decode" + + if top_k: + keep_only_top_entries(prefill_df, "cuda_time_us", top_k) + keep_only_top_entries(decode_df, "cuda_time_us", top_k) + + df = pd.concat([prefill_df, decode_df]) + df["cuda_time_ms"] = df["cuda_time_us"] / 1000 + + fig, axes = plt.subplots(2, figsize=(5, 8), sharex=True) + + def plot_metric(metric: str, ax, add_totals=False): + pivoted_df = df.pivot_table(index="phase", + columns="name", + values=metric, + aggfunc="sum") + pivoted_df.plot.bar(stacked=True, legend=False, ax=ax) + ax.set_ylabel(metric) + + if add_totals: + ax.bar_label(ax.containers[-1]) + + plot_metric("cuda_time_ms", ax=axes[0], add_totals=True) + plot_metric("pct_cuda_time", ax=axes[1]) + + handles, labels = plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(0.93, 0.5)) + shorten_plot_legend_strings(legend, 50) + + context = profile_data["context"] + plt.suptitle( + f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + context['sparsity'] if context['sparsity'] else ''}" + ) + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c0..f3adcb519ed9 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,3 +8,4 @@ vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 triton >= 2.1.0 +nvidia-cutlass == 3.5.0 diff --git 
a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1..f78949c24718 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,8 @@ aiohttp # Multimodal pillow + +# Profiling +matplotlib +pandas +pyarrow diff --git a/tests/kernels/test_fusion.py b/tests/kernels/test_fusion.py new file mode 100644 index 000000000000..a11515167a76 --- /dev/null +++ b/tests/kernels/test_fusion.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp(-128, 127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.quant(out2, x, scale) + assert torch.allclose(out1, out2, atol=1) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_per_token_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + scale1 = torch.max(x, dim=1)[0].to(torch.float32) / 127.0 + out1 = (x / scale1.view(-1, 1)).round().clamp(-128, 
127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale2 = torch.empty(num_tokens, dtype=torch.float32, device="cuda") + ops.quant(out2, x, scale2) + assert torch.allclose(out1, out2, atol=1) diff --git a/vllm/config.py b/vllm/config.py index 6762a75f25f2..ea53b334139f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,8 +173,10 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "marlin"] - rocm_not_supported_quantization = ["awq", "marlin"] + supported_quantization = [ + "awq", "gptq", "marlin", "squeezellm", "smoothquant" + ] + rocm_not_supported_quantization = ["awq", "marlin", "smoothquant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -868,6 +870,7 @@ def get_image_input_enum_type( _STR_DTYPE_TO_TORCH_DTYPE = { + "int8": torch.int8, "half": torch.float16, "float16": torch.float16, "float": torch.float32, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a6197942645e..3e6742eea2af 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -270,17 +270,18 @@ def add_cli_args( action='store_true', help='disable logging statistics') # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'gptq', 'squeezellm', None], - default=EngineArgs.quantization, - help='Method used to quantize the weights. If ' - 'None, we first check the `quantization_config` ' - 'attribute in the model config file. If that is ' - 'None, we assume the model weights are not ' - 'quantized and use `dtype` to determine the data ' - 'type of the weights.') + parser.add_argument( + '--quantization', + '-q', + type=str, + choices=['awq', 'gptq', 'squeezellm', 'smoothquant', None], + default=None, + help='Method used to quantize the weights. 
If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..908aa4cc997e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -29,8 +29,8 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" @@ -44,6 +44,12 @@ def apply_weights(self, """Apply the weights to the input tensor.""" raise NotImplementedError + def maybe_update_loaded_weight_name(self, name: str) -> str: + """Update the name of a loaded weight to enable generic handling of + cases where serialized state_dict does not match vllm model definition. + """ + return name + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. 
@@ -56,11 +62,11 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_sizes_per_partition), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -83,6 +89,7 @@ class ReplicatedLinear(torch.nn.Module): """Replicated linear layer. Args: + layer_name: name of the layer in the state dict. input_size: input dimension of the linear layer. output_size: output dimension of the linear layer. bias: If true, add bias. @@ -93,6 +100,7 @@ class ReplicatedLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -103,6 +111,7 @@ def __init__( super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add @@ -113,8 +122,8 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) + self.layer_name, self.input_size, [self.output_size], + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -139,6 +148,7 @@ class ColumnParallelLinear(torch.nn.Module): its second dimension as A = [A_1, ..., A_p]. Args: + layer_name: name of the layer in the state dict. input_size: first dimension of matrix A. 
output_size: second dimension of matrix A. bias: If true, add bias. @@ -154,6 +164,7 @@ class ColumnParallelLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -165,12 +176,21 @@ def __init__( super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_sizes_per_partition = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_sizes_per_partition = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + self.skip_bias_add = skip_bias_add if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -179,8 +199,13 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) + layer_name=self.layer_name, + input_size_per_partition=self.input_size, + output_sizes_per_partition=self.output_sizes_per_partition, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + ) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -246,6 +271,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): def __init__( self, + layer_name: str, input_size: int, output_sizes: List[int], bias: bool = True, @@ -257,8 +283,14 @@ def __init__( self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert 
all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -266,6 +298,19 @@ def weight_loader(self, loaded_shard_id: Optional[int] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -318,6 +363,12 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. 
+ elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -325,6 +376,7 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -340,6 +392,7 @@ class QKVParallelLinear(ColumnParallelLinear): be replicated while the query heads are partitioned. Args: + layer_name: name of the layer in the state dict. hidden_size: input hidden state size of the transformer. head_size: size of each attention head. total_num_heads: total number of attention query heads. @@ -355,6 +408,7 @@ class QKVParallelLinear(ColumnParallelLinear): def __init__( self, + layer_name: str, hidden_size: int, head_size: int, total_num_heads: int, @@ -383,8 +437,20 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(layer_name=layer_name, + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + linear_method=linear_method) def weight_loader(self, param: Parameter, @@ -392,6 +458,18 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim 
is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) if loaded_shard_id is None: # Loaded weight is already packed. @@ -427,6 +505,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -459,6 +539,12 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths") + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -466,7 +552,9 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") - assert param_data.shape == loaded_weight.shape + + assert (param_data.shape == loaded_weight.shape or + (len(param_data.shape) == 0 and len(loaded_weight.shape) == 0)) param_data.copy_(loaded_weight) @@ -483,6 +571,7 @@ class RowParallelLinear(torch.nn.Module): | A_p | - - Arguments: + layer_name: name of the layer in the state dict. input_size: first dimension of matrix A. output_size: second dimension of matrix A. bias: If true, add bias. Note that bias is not parallelized. 
@@ -498,6 +587,7 @@ class RowParallelLinear(torch.nn.Module): def __init__( self, + layer_name: str, input_size: int, output_size: int, bias: bool = True, @@ -509,6 +599,7 @@ def __init__( ): super().__init__() # Keep input parameters + self.layer_name = layer_name self.input_size = input_size self.output_size = output_size self.input_is_parallel = input_is_parallel @@ -525,8 +616,13 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) + layer_name=self.layer_name, + input_size_per_partition=self.input_size_per_partition, + output_sizes_per_partition=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + ) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -555,6 +651,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # TODO: canon + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ad988d48755b..301bcf48de6b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -6,11 +6,13 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.smoothquant import SmoothQuantConfig 
_QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "smoothquant": SmoothQuantConfig, "marlin": MarlinConfig, } diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..cc089893bd5f 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,13 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size, output_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be95..868e09252bb2 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -51,14 +51,12 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: + def get_linear_method(self, name) -> LinearMethodBase: """Get the linear method to use for the quantized linear layer.""" raise NotImplementedError - @abstractmethod def get_scaled_act_names(self) -> List[str]: """Returns the activation function names that should be post-scaled. - For now, this is only used by AWQ. 
""" - raise NotImplementedError + return [] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..8c3492ae67d8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,13 +89,16 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del output_size, layer_name # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf..6883668b0a7e 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,13 +91,15 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer_name: str, input_size_per_partition: int, - output_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + del layer_name, input_size, output_size # Unused. 
+ output_size_per_partition = sum(output_sizes_per_partition) if params_dtype != torch.float16: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/smoothquant/__init__.py b/vllm/model_executor/layers/quantization/smoothquant/__init__.py new file mode 100644 index 000000000000..4074f7379c07 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/__init__.py @@ -0,0 +1,11 @@ +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat) + +from vllm.model_executor.layers.quantization.smoothquant.config import ( + SmoothQuantConfig, SmoothQuantLinearMethod) + +__all__ = [ + "SmoothQuantFormat", + "SmoothQuantConfig", + "SmoothQuantLinearMethod", +] diff --git a/vllm/model_executor/layers/quantization/smoothquant/config.py b/vllm/model_executor/layers/quantization/smoothquant/config.py new file mode 100644 index 000000000000..13a7f5ba3742 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/config.py @@ -0,0 +1,258 @@ +from typing import Any, Dict, List, Tuple, Type, Optional, Union + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.smoothquant.formats import ( + SmoothQuantFormat, + SmoothQuantDynamicPerToken, + SmoothQuantStaticPerTensor, +) +from vllm.model_executor.layers.quantization.smoothquant.cutlass_gemm import ( + cutlass_gemm_dq) + +LAYER_KEYS = ["qkv", "out", "fc1", "fc2"] +FORMAT_REGISTRY = { + "per-token": SmoothQuantDynamicPerToken, + "per-tensor": SmoothQuantStaticPerTensor, +} + + +def get_sq_format_cls(format_key: str) -> Type[SmoothQuantFormat]: + if format_key not in FORMAT_REGISTRY: + raise ValueError(f"Invalid smoothquant format: {format_key}") + return FORMAT_REGISTRY[format_key] + + +class SmoothQuantConfig(QuantizationConfig): + 
"""Config class for SmoothQuant. + + Reference: https://github.com/mit-han-lab/smoothquant + """ + + def __init__(self, layer_format_map: Dict[str, str]) -> None: + self.layer_format_map = layer_format_map + + for key, format in self.layer_format_map.items(): + if key not in LAYER_KEYS: + raise ValueError( + f"Found key of {key} in {self.layer_format_map}, " + f"but key must be one of {LAYER_KEYS}") + if format not in FORMAT_REGISTRY: + raise ValueError( + f"Found format of {format} in {self.layer_format_map}, " + f"but format must be one of {FORMAT_REGISTRY}") + for key in LAYER_KEYS: + if key not in self.layer_format_map: + raise ValueError(f"Could not find {key} in {layer_format_map}") + + def __repr__(self) -> str: + return (f"SmoothQuantConfig(layer_format_map={self.layer_format_map})") + + def get_name(self) -> str: + return "smoothquant" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + # TODO: check if we support fp16 / bf16 as well. + return [torch.float] + + def get_min_capability(self) -> int: + # TODO: check if this is right. + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + layer_format_map: Dict[str, str] = {} + for layer_key, format in config.items(): + if format in FORMAT_REGISTRY: + layer_format_map[layer_key] = format + return cls(layer_format_map) + + def get_linear_method(self) -> "SmoothQuantLinearMethod": + return SmoothQuantLinearMethod(self) + + +class SmoothQuantLinearMethod(LinearMethodBase): + + def __init__(self, sq_config: SmoothQuantConfig) -> None: + self.sq_config = sq_config + self.sq_type = None + + def maybe_update_loaded_weight_name(self, name: str) -> str: + """Convert serialized name k_dequant_scale to dequant_scale. 
+ + This function is called by model_cls.load_weights() during the weight + loading process to match on disk state dict to vllm state dict. + """ + if "dequant_scale" in name: + suffix = name.split('.')[-1] + name = name.replace(suffix, "dequant_scale") + return name + + def shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self.shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def get_layer_format(self, layer_name: str) -> SmoothQuantFormat: + """ + Gets the SmoothQuantFormat for a specific layer. + + SmoothQuantLinearMethod uses SmoothQuantType to support non-uniform quantization + (where each layer has a different format). To determine the SmoothQuantFormat + for a layer, we match the layer_name to the layer_keys=["qkv","out","fc1","fc2"] + and use layer_format_map to determine the SQFormat. + + Args: + layer_name: Name of the layer we are creating the LinearMethod for. + Returns: + sq_linear_method: SmoothQuantLinearMethod with the right SQFormat. + """ + # Note: AutoSmoothQuant Serialization is not very good yet. + # + # It looks like the following (which does not map to layer names in the model): + # { + # "qkv": "per-tensor", + # "out": "per-token", + # "fc1": "per-tensor", + # "fc2": "per-token" + # } + # + # So, this is a hack for llama now. 
But with the SparseMLConfig, we can make robust, + # where we actually use the layer_name in the model to look up what the format is + # based on the config. + # + # What it would actually look like: + # layer_config is None + # for supported_key in SUPPORTED_LAYER_KEYS: + # if supported_key in layer_name: + # sq_format = self.layer_mapping[lookup_key] + # return get_sq_format_cls(sq_format)() + + HACKED_REMAP_FOR_LLAMA = { + "qkv": "qkv", + "o_proj": "out", + "gate_up": "fc1", + "down": "fc2", + } + + for match_key, lookup_key in HACKED_REMAP_FOR_LLAMA.items(): + if match_key in layer_name: + sq_format = self.sq_config.layer_format_map[lookup_key] + return get_sq_format_cls(sq_format)() + + raise ValueError + + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + del input_size, output_size + + # Statically Quantized Weights. + weight = Parameter( + torch.empty( + sum(output_sizes_per_partition), + input_size_per_partition, + device="cuda", + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + + if len(output_sizes_per_partition) == 1: + # Single static scale for the entire tensor. + dequant_scale = Parameter(torch.empty((1), + device='cuda', + dtype=params_dtype), + requires_grad=False) + else: + # Static scale for each logical weight (e.g. 3 for QKV). 
+            dequant_scale = Parameter(torch.empty(
+                (sum(output_sizes_per_partition)),
+                device='cuda',
+                dtype=params_dtype),
+                requires_grad=False)
+            set_weight_attrs(
+                dequant_scale, {
+                    "shard_splitter": self.scales_shard_splitter,
+                    "logical_widths": output_sizes_per_partition
+                })
+
+        return {
+            "weight": weight,
+            "dequant_scale": dequant_scale,
+            "logical_widths": output_sizes_per_partition,
+            "sq_format": self.get_layer_format(layer_name)
+        }
+
+    def _quantize(
+            self, x: torch.Tensor, sq_format: SmoothQuantFormat
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Quantize activations.
+
+        Args:
+            x: Activation at floating point precision.
+        Returns:
+            x_q: Quantized activation at INT8
+            activation_scales: Optional dynamic scales for each token.
+        """
+        x_q = torch.empty_like(x, dtype=torch.int8, device="cuda")
+        x_q, activation_scales = sq_format.quantize_op(x, x_q)
+        return x_q, activation_scales
+
+    def apply_weights(self,
+                      weights: Dict[str, torch.Tensor],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Forward method. Computes Q --> GEMM --> DQ.
+
+        Args:
+            weights: Dictionary of weights, scales, and metadata.
+            x: Input in floating point precision.
+            bias: Optional bias.
+        Returns:
+            a_dq: Dequantized activation at floating point precision.
+ """ + if bias is not None: + raise NotImplementedError + weight_q = weights["weight"] + static_scales = weights["dequant_scale"] + sq_format = weights["sq_format"] + + # Q + x_q, activation_scales = self._quantize(x, sq_format) + + # GEMM and DQ + return cutlass_gemm_dq(x_q, weight_q, x.dtype, static_scales, + activation_scales) diff --git a/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py new file mode 100644 index 000000000000..ab0bb0808fb2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant/cutlass_gemm.py @@ -0,0 +1,87 @@ +import cutlass +from cutlass import Tensor as FakeTensor +import cutlass.epilogue + +import torch +from typing import Optional, Tuple, Dict + +from vllm.logger import init_logger + +logger = init_logger("cutlass_gemm") + +def setup_dequant_epilogue(plan : cutlass.op.Gemm, + dq: torch.Tensor, + static_scales: Optional[torch.Tensor], + activation_scales: Optional[torch.Tensor]) \ + -> Tuple[cutlass.op.Gemm, Dict]: + + if all([static_scales is None, activation_scales is None]): + return plan, None + assert static_scales is not None + + def epilog_with_scales_and_act_scales(accum, scales, act_scales): + D = accum * scales * act_scales + return D + + def epilog_with_scales(accum, scales): + D = accum * scales + return D + + epilog_tensors = {'scales': static_scales, 'D': dq} + epilogue_trace_tensors = { + "accum": + FakeTensor(element=torch.int32, + shape=dq.shape, + layout_tag=cutlass.LayoutType.RowMajor), + 'scales': + static_scales, + 'D': + dq, + } + epilog_fn = epilog_with_scales + + if activation_scales is not None: + epilog_tensors['act_scales'] = activation_scales + epilogue_trace_tensors['act_scales'] = activation_scales + epilog_fn = epilog_with_scales_and_act_scales + + plan.epilogue_visitor = cutlass.epilogue.trace(epilog_fn, + epilogue_trace_tensors) + return plan, epilog_tensors + + +def cutlass_gemm_dq( + x_q: torch.Tensor, 
+        w_q: torch.Tensor,
+        dtype: torch.dtype,
+        static_scales: torch.Tensor,
+        activation_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    dq = torch.empty((x_q.shape[0], w_q.shape[0]), dtype=dtype, device="cuda")
+
+    plan = cutlass.op.Gemm(
+        element_A=x_q.dtype,
+        element_B=w_q.dtype,
+        element_C=dq.dtype,
+        element_D=dq.dtype,
+        layout_A=cutlass.LayoutType.RowMajor,
+        layout_B=cutlass.LayoutType.ColumnMajor,
+        layout_C=cutlass.LayoutType.RowMajor,
+        element_accumulator=torch.int32,
+        # TODO (varun) : lets not have kernel cc here please.
+        kernel_cc=80)
+
+    plan, visitor_args = setup_dequant_epilogue(plan, dq, static_scales,
+                                                activation_scales)
+
+    plan.run(x_q,
+             w_q.t(),
+             dq,
+             dq,
+             alpha=1,
+             beta=0,
+             visitor_args=visitor_args,
+             print_module=False)
+
+    dq = dq.view(*x_q.shape[:-1], -1)
+    return dq
diff --git a/vllm/model_executor/layers/quantization/smoothquant/formats.py b/vllm/model_executor/layers/quantization/smoothquant/formats.py
new file mode 100644
index 000000000000..4f82b0d67a99
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/smoothquant/formats.py
@@ -0,0 +1,49 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+
+import torch
+
+from vllm._C import ops
+
+
+class SmoothQuantFormat(ABC):
+
+    @abstractmethod
+    def quantize_op(
+            self, x: torch.Tensor,
+            x_q: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Quantize the input and (optionally compute dequant scales).
+
+        Args:
+            x: Input data in floating point format.
+            x_q: Buffer for quantized inputs.
+        Returns:
+            x_q: Quantized input.
+            activation_scales: Optional dynamic scales for the activations.
+        """
+        raise NotImplementedError
+
+
+class SmoothQuantDynamicPerToken(SmoothQuantFormat):
+
+    def quantize_op(self, x: torch.Tensor,
+                    x_q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Notes:
+            Returns quantized activation and dynamic activation scales.
+ """ + activation_scales = torch.empty((x.numel() // x.shape[-1], 1), + dtype=x.dtype, + device=x.device) + ops.quant(x_q, x, activation_scales) + return x_q, activation_scales + + +class SmoothQuantStaticPerTensor(SmoothQuantFormat): + + def quantize_op(self, x: torch.Tensor, + x_q: torch.Tensor) -> Tuple[torch.Tensor, None]: + """Notes: + Returns quantized activaiton and no dynamic scales. + """ + ops.quant(x_q, x, 1.0) + return x_q, None diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1..6b800fc12509 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,13 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + def create_weights(self, layer_name: str, input_size_per_partition: int, + output_sizes_per_partition: List[int], input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: + del layer_name, input_size # Unused. + output_size_per_partition = sum(output_sizes_per_partition) + if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0..152d1ea95869 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -55,6 +55,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # Get the (maybe quantized) linear method. 
linear_method = None + quant_config = None if model_config.quantization is not None: quant_config = get_quant_config(model_config) capability = torch.cuda.get_device_capability() @@ -80,18 +81,21 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, if hasattr(model_class, "supported_lora_modules"): model = model_class(model_config.hf_config, linear_method, lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") else: if model_class not in _VISION_MODEL_CLASSES: model = model_class(model_config.hf_config, linear_method) else: model = model_class(model_config.hf_config, vision_language_config, linear_method) + + if not hasattr(model_class, + "supported_lora_modules") and lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17fc97056804..426eb8846437 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -74,6 +74,13 @@ "Sliding window attention is not yet supported in ROCm's flash attention", } +# Models supported by smoothquant +_SUPPORTED_SMOOTHQUANT_MODELS = { + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), +} + class ModelRegistry: @@ -112,6 +119,10 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def get_supported_smoothquant_archs() -> List[str]: + return list(_SUPPORTED_SMOOTHQUANT_MODELS.keys()) + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a697..35aa1f4fee29 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -75,6 +75,7 @@ class BaiChuanMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -82,13 +83,17 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -106,6 +111,7 @@ class BaiChuanAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, position_embedding: str, @@ -128,16 +134,18 @@ def __init__( # pylint: disable=invalid-name self.W_pack = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_heads, + layer_name=f"{parent_name}.W_pack", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -183,6 +191,7 @@ def forward( class BaiChuanDecoderLayer(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, position_embedding: str, linear_method: Optional[LinearMethodBase] = None): @@ -192,6 +201,7 @@ def __init__(self, max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = BaiChuanAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, position_embedding=position_embedding, @@ -200,6 +210,7 @@ def __init__(self, linear_method=linear_method, ) self.mlp = BaiChuanMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -255,8 +266,11 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) - for _ in range(config.num_hidden_layers) + BaiChuanDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + position_embedding=position_embedding, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -353,6 +367,10 @@ def load_weights(self, params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4eba..c3c8d6e23edb 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -55,6 +55,7 @@ class DeepseekMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -63,14 +64,18 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method, - reduce_results=reduce_results) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + linear_method=linear_method, + reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -87,6 +92,7 @@ class DeepseekMoE(nn.Module): def __init__( self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None, ): @@ -102,7 +108,8 @@ def __init__( f"the number of experts {self.n_routed_experts}.") self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, + DeepseekMLP(parent_name=f"{parent_name}.experts", + hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, @@ -111,8 +118,9 @@ def __init__( ]) self.pack_params() - self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, + self.gate = ReplicatedLinear(layer_name=f"{parent_name}.gate", + input_size=config.hidden_size, + output_size=self.n_routed_experts, bias=False, linear_method=None) @@ -120,6 +128,7 @@ def __init__( intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) self.shared_experts = DeepseekMLP( + parent_name=f"{parent_name}.shared_experts", hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, @@ -173,6 +182,7 @@ class DeepseekAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -205,17 +215,19 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -251,6 +263,7 @@ 
class DeepseekDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: PretrainedConfig, layer_idx: int, linear_method: Optional[LinearMethodBase] = None, @@ -262,6 +275,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -273,9 +287,12 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + self.mlp = DeepseekMoE(parent_name=f"{parent_name}.mlp", + config=config, + linear_method=linear_method) else: self.mlp = DeepseekMLP( + parent_name=f"{parent_name}.mlp", hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -331,10 +348,11 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, + DeepseekDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + layer_idx=idx, linear_method=linear_method) - for layer_idx in range(config.num_hidden_layers) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -417,6 +435,10 @@ def load_weights(self, load_format, revision, fall_back_to_pt=False): + # Update name of the loaded_weight if needed by the LinearMethod. 
+ if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3..9476df293f84 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -75,6 +75,7 @@ class GemmaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: Optional[str] = None, @@ -83,13 +84,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -102,6 +107,7 @@ def forward(self, x): class GemmaAttention(nn.Module): def __init__(self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -132,16 +138,18 @@ def __init__(self, self.rope_theta = rope_theta self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads 
* self.head_dim, + output_size=hidden_size, bias=False, linear_method=linear_method, ) @@ -177,12 +185,14 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: GemmaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size self.self_attn = GemmaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -192,6 +202,7 @@ def __init__( linear_method=linear_method, ) self.mlp = GemmaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -247,8 +258,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + GemmaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -364,6 +377,10 @@ def load_weights(self, loaded_params = set() for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. 
+ if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be..c18b3e2f719b 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -45,6 +45,7 @@ class GPT2Attention(nn.Module): def __init__( self, + parent_name: str, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, ): @@ -59,15 +60,17 @@ def __init__( self.scale = self.head_dim**-0.5 self.c_attn = QKVParallelLinear( - self.hidden_size, - self.head_dim, - total_num_heads, + layer_name=f"{parent_name}.c_attn", + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=total_num_heads, bias=True, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - self.hidden_size, - self.hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=self.hidden_size, + output_size=self.hidden_size, bias=True, linear_method=linear_method, ) @@ -90,6 +93,7 @@ class GPT2MLP(nn.Module): def __init__( self, + parent_name: str, intermediate_size: int, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, @@ -97,14 +101,16 @@ def __init__( super().__init__() hidden_size = config.hidden_size self.c_fc = ColumnParallelLinear( - hidden_size, - intermediate_size, + layer_name=f"{parent_name}.c_fc", + input_size=hidden_size, + output_size=intermediate_size, bias=True, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=True, linear_method=linear_method, ) @@ -123,6 +129,7 @@ class GPT2Block(nn.Module): def __init__( self, + parent_name: str, config: GPT2Config, linear_method: Optional[LinearMethodBase] = None, ): @@ -132,9 +139,14 @@ def __init__( hidden_size) 
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, linear_method) + self.attn = GPT2Attention(parent_name=f"{parent_name}.attn", + config=config, + linear_method=linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, linear_method) + self.mlp = GPT2MLP(parent_name=f"{parent_name}.mlp", + intermediate_size=inner_dim, + config=config, + linear_method=linear_method) def forward( self, @@ -176,8 +188,10 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, linear_method) - for _ in range(config.num_hidden_layers) + GPT2Block(parent_name=f"transformer.h.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -248,6 +262,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ef19c41e67ae..9569af31348a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,6 +54,7 @@ class LlamaMLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -61,13 +62,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -84,6 +89,7 @@ class LlamaAttention(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -127,16 +133,18 @@ def __init__( self.kv_scale = 1.0 self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=bias, linear_method=linear_method, ) @@ -174,6 +182,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, ) -> None: @@ -185,6 +194,7 @@ def __init__( 8192) sliding_window = getattr(config, "sliding_window", None) self.self_attn = LlamaAttention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", @@ -197,6 +207,7 @@ def __init__( sliding_window=sliding_window, ) self.mlp = LlamaMLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -238,12 +249,10 @@ def forward( class LlamaModel(nn.Module): - def __init__( - self, - config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: + def __init__(self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -257,8 +266,10 @@ 
def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -386,8 +397,13 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7..05ce50551934 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -63,6 +63,7 @@ class PhiAttention(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -78,15 +79,17 @@ def __init__(self, # pylint: disable=C0103 self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_size, - self.total_num_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=self.hidden_size, + head_size=self.head_size, + total_num_heads=self.total_num_heads, bias=True, linear_method=linear_method, ) self.dense = RowParallelLinear( - self.hidden_size, - self.hidden_size, + layer_name=f"{parent_name}.dense", + input_size=self.hidden_size, + output_size=self.hidden_size, linear_method=linear_method, ) @@ -126,6 +129,7 @@ def forward( class PhiMLP(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -134,13 
+138,15 @@ def __init__(self, n_inner = n_inner if n_inner is not None else 4 * config.hidden_size self.fc1 = ColumnParallelLinear( - config.hidden_size, - n_inner, + layer_name=f"{parent_name}.fc1", + input_size=config.hidden_size, + output_size=n_inner, linear_method=linear_method, ) self.fc2 = RowParallelLinear( - n_inner, - config.hidden_size, + layer_name=f"{parent_name}.fc2", + input_size=n_inner, + output_size=config.hidden_size, linear_method=linear_method, ) quant_config = getattr(linear_method, "quant_config", None) @@ -156,13 +162,18 @@ def forward(self, hidden_states): class PhiLayer(nn.Module): def __init__(self, + parent_name: str, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, linear_method) - self.mlp = PhiMLP(config, linear_method) + self.self_attn = PhiAttention(parent_name=f"{parent_name}.self_attn", + config=config, + linear_method=linear_method) + self.mlp = PhiMLP(parent_name=f"{parent_name}.mlp", + config=config, + linear_method=linear_method) def forward( self, @@ -195,8 +206,10 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, linear_method) - for _ in range(config.num_hidden_layers) + PhiLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -281,6 +294,10 @@ def load_weights(self, for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. 
+ if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b..c0af1221ae67 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -53,6 +53,7 @@ class Qwen2MLP(nn.Module): def __init__( self, + parent_name: str, hidden_size: int, intermediate_size: int, hidden_act: str, @@ -60,13 +61,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + layer_name=f"{parent_name}.gate_up_proj", + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear( + layer_name=f"{parent_name}.down_proj", + input_size=intermediate_size, + output_size=hidden_size, bias=False, linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -82,6 +87,7 @@ def forward(self, x): class Qwen2Attention(nn.Module): def __init__(self, + parent_name: str, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -114,16 +120,18 @@ def __init__(self, self.sliding_window = sliding_window if use_sliding_window else None self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=True, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=False, linear_method=linear_method, ) @@ -159,6 +167,7 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, + parent_name: str, config: Qwen2Config, layer_idx: int, linear_method: Optional[LinearMethodBase] = None, @@ -170,6 +179,7 @@ def __init__( use_sliding_window = (config.use_sliding_window and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( + parent_name=f"{parent_name}.self_attn", hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, @@ -179,6 +189,7 @@ def __init__( linear_method=linear_method, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( + parent_name=f"{parent_name}.mlp", hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -235,8 +246,11 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, linear_method) - for layer_idx in range(config.num_hidden_layers) + Qwen2DecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + layer_idx=idx, + linear_method=linear_method) + for idx in 
range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,6 +362,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6e..bd2b8ca9ad69 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -46,6 +46,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() @@ -76,16 +77,18 @@ def __init__(self, self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + layer_name=f"{parent_name}.qkv_proj", + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=self.use_bias, linear_method=linear_method, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, + layer_name=f"{parent_name}.o_proj", + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=self.use_bias, linear_method=linear_method, ) @@ -122,18 +125,21 @@ def forward( class Starcoder2MLP(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.c_fc = ColumnParallelLinear( - config.hidden_size, - 
config.intermediate_size, + layer_name=f"{parent_name}.c_fc", + input_size=config.hidden_size, + output_size=config.intermediate_size, bias=config.use_bias, linear_method=linear_method, ) self.c_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, + layer_name=f"{parent_name}.c_proj", + input_size=config.intermediate_size, + output_size=config.hidden_size, bias=config.use_bias, linear_method=linear_method, ) @@ -151,13 +157,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Starcoder2DecoderLayer(nn.Module): def __init__(self, + parent_name: str, config: Starcoder2Config, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, - linear_method=linear_method) - self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.self_attn = Starcoder2Attention( + parent_name=f"{parent_name}.self_attn", + config=config, + linear_method=linear_method) + self.mlp = Starcoder2MLP(parent_name=f"{parent_name}.mlp", + config=config, + linear_method=linear_method) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -204,8 +215,10 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, linear_method=linear_method) - for _ in range(config.num_hidden_layers) + Starcoder2DecoderLayer(parent_name=f"model.layers.{idx}", + config=config, + linear_method=linear_method) + for idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -232,6 +245,7 @@ def __init__(self, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config + self.linear_method = linear_method self.model = Starcoder2Model(config, linear_method=linear_method) self.vocab_size = 
config.vocab_size self.unpadded_vocab_size = config.vocab_size @@ -290,6 +304,10 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Update name of the loaded_weight if needed by the LinearMethod. + if self.linear_method: + name = self.linear_method.maybe_update_loaded_weight_name(name) + if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py new file mode 100644 index 000000000000..93ec4a800e60 --- /dev/null +++ b/vllm/profiler/__init__.py @@ -0,0 +1,5 @@ +from .nm_profile import nm_profile + +__all__ = [ + "nm_profile", +] diff --git a/vllm/profiler/nm_profile.py b/vllm/profiler/nm_profile.py new file mode 100644 index 000000000000..1912a300c02b --- /dev/null +++ b/vllm/profiler/nm_profile.py @@ -0,0 +1,346 @@ +import pandas as pd +import copy + +from collections import defaultdict +from dataclasses import dataclass, field, asdict +from vllm.profiler.utils import (indent_string, TablePrinter, event_has_module, + event_is_torch_op, event_module_repr, + event_torch_op_stack_trace) +from typing import Dict, List, Union, Optional, Tuple, Callable, TypeAlias +from torch.profiler import profile, ProfilerActivity +from torch.autograd.profiler import FunctionEvent +from torch._C._autograd import _ProfilerResult, _KinetoEvent, DeviceType +from torch._C._profiler import _EventType, _ProfilerEvent, _ExperimentalConfig + + +@dataclass +class _ModuleTreeNode: + event: _ProfilerEvent + parent: Optional['_ModuleTreeNode'] = None + children: List['_ModuleTreeNode'] = field(default_factory=list) + trace: str = "" + + @property + def is_leaf(self): + return (self.event.children is None or len(self.event.children) == 0) + + @property + def is_torch_op(self): + return event_is_torch_op(self.event) + + @property + def is_cuda(self): + return (self.event.tag == _EventType.Kineto + 
and self.event.typed[1].device_type == DeviceType.CUDA) + + +@dataclass +class SummaryStatsEntry: + name: str + cuda_time_us: float + pct_cuda_time: float + invocations: int + + +@dataclass +class ModelStatsEntry: + name: str + cpu_time_us: float + cuda_time_us: float + pct_cuda_time: float + trace: str + + +StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] + + +@dataclass +class _StatsTreeNode: + entry: StatsEntry + children: List[StatsEntry] + parent: Optional[StatsEntry] + + +@dataclass +class NMProfileResults(profile): + _kineto_results: _ProfilerResult + _kineto_event_correlation_map: Dict[int, + List[_KinetoEvent]] = field(init=False) + _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) + _module_tree: List[_ModuleTreeNode] = field(init=False) + _model_stats_tree: List[_StatsTreeNode] = field(init=False) + _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + + def __post_init__(self): + self._build_correlation_map() + self._build_module_tree() + self._build_stats_trees() + + def print_model_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + if column_widths: + _column_widths.update(**column_widths) + filtered_model_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._model_stats_tree) + if row.cuda_time_us > 0 or row.cpu_time_us > 0 + ] + TablePrinter(ModelStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_model_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def print_summary_table(self, column_widths: Dict[str, int] = None): + _column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + if column_widths: + _column_widths.update(**column_widths) + filtered_summary_table = [(depth, row) + for depth, row in self._flatten_stats_tree( + self._summary_stats_tree) + if row.cuda_time_us > 0] + 
TablePrinter(SummaryStatsEntry, _column_widths).print_table( + self._indent_row_names_based_on_depth( + filtered_summary_table, + indent_style=lambda indent: "|" + "-" * indent + " ")) + + def export_model_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._model_stats_tree) + ]) + df.to_csv(filename) + + def export_summary_stats_table_csv(self, filename: str): + df = pd.DataFrame([ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ]) + df.to_csv(filename) + + def convert_stats_to_dict(self) -> str: + return { + "summary_stats": + self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": + self._convert_stats_tree_to_dict(self._model_stats_tree) + } + + @staticmethod + def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, + StatsEntry]], + indent_style: Union[Callable[[int], + str], + str] = " "): + indented_rows = [] + for depth, row in depths_rows: + if row.cuda_time_us == 0: + continue + indented_row = copy.deepcopy(row) + indented_row.name = indent_string(indented_row.name, depth, + indent_style) + indented_rows.append(indented_row) + return indented_rows + + def _build_correlation_map(self): + self._kineto_event_correlation_map = defaultdict(list) + for event in self._kineto_results.events(): + self._kineto_event_correlation_map[event.correlation_id()].append( + event) + + def _build_module_tree(self): + self._module_tree = [] + event_tree = self._kineto_results.experimental_event_tree() + + def _df_traversal(event: _ProfilerEvent, + curr_node: Optional[_ModuleTreeNode] = None): + if event_has_module(event): + node = _ModuleTreeNode(event=event, parent=curr_node) + if curr_node: + curr_node.children.append(node) + else: + self._module_tree.append(node) + curr_node = node + + is_leaf = (event.children is None or len(event.children) == 0) + if is_leaf and curr_node: + node = _ModuleTreeNode( + event=event, + parent=curr_node, + 
trace=event_torch_op_stack_trace( + event, until=lambda x: event_has_module(x))) + curr_node.children.append(node) + curr_node = node + + for child in event.children: + _df_traversal(child, curr_node) + + for root in event_tree: + _df_traversal(root) + + def _get_kineto_gpu_event(self, node: _ModuleTreeNode): + if node.event.tag != _EventType.Kineto: + return None + correlated_kineto_events = self._kineto_event_correlation_map.get( + node.event.correlation_id, []) + iterator = (x for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA) + return next(iterator, None) + + def _cumulative_cuda_time(self, node: _ModuleTreeNode): + + def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): + if node.is_leaf and (gpu_kineto_event := + self._get_kineto_gpu_event(node)): + return gpu_kineto_event.duration_us() + else: + cumulative_cuda_time = 0 + for child in node.children: + cumulative_cuda_time += _cumulative_cuda_time_recursive( + child) + return cumulative_cuda_time + + return _cumulative_cuda_time_recursive(node) + + def _total_cuda_time(self): + return sum( + [self._cumulative_cuda_time(root) for root in self._module_tree]) + + def _build_stats_trees(self): + summary_dict: Dict[str, self.StatsTreeNode] = {} + total_cuda_time = self._total_cuda_time() + + def pct_cuda_time(cuda_time_us): + return (cuda_time_us / total_cuda_time) * 100 + + def build_summary_stats_tree_df( + node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None, + summary_trace: Tuple[str] = ()): + + if event_has_module(node.event): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + else: + return None + + summary_trace = summary_trace + (name, ) + if summary_trace in summary_dict: + entry = summary_dict[summary_trace].entry + entry.cuda_time_us += cuda_time_us + entry.invocations += 1 + 
entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) + else: + new_node = _StatsTreeNode(entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1), + children=[], + parent=parent) + if parent: + parent.children.append(new_node) + summary_dict[summary_trace] = new_node + + for child in node.children: + build_summary_stats_tree_df(child, summary_dict[summary_trace], + summary_trace) + + return summary_dict[summary_trace] + + self._summary_stats_tree = [] + for root in self._module_tree: + self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + + def build_model_stats_tree_df(node: _ModuleTreeNode, + parent: Optional[_StatsTreeNode] = None): + if event_has_module(node.event, ): + name = event_module_repr(node.event) + cuda_time_us = self._cumulative_cuda_time(node) + cpu_time_us = node.event.duration_time_ns / 1000 + trace = "" + elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + name = gpu_kineto_event.name() + cuda_time_us = gpu_kineto_event.duration_us() + cpu_time_us = 0 + trace = node.trace + else: + return None + + new_node = _StatsTreeNode(entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace), + parent=parent, + children=[]) + if parent: + parent.children.append(new_node) + + for child in node.children: + build_model_stats_tree_df(child, new_node) + + return new_node + + self._model_stats_tree = [] + for root in self._module_tree: + self._model_stats_tree.append(build_model_stats_tree_df(root)) + + def _flatten_stats_tree( + self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: + entries: List[Tuple[int, StatsEntry]] = [] + + def df_traversal(node: _StatsTreeNode, depth=0): + entries.append((depth, node.entry)) + for child in node.children: + df_traversal(child, depth=depth + 1) + + for root in tree: + df_traversal(root) + + return entries + + def 
_convert_stats_tree_to_dict(self, + tree: List[_StatsTreeNode]) -> List[Dict]: + root_dicts: List[Dict] = [] + + def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): + curr_json_list.append({ + "entry": asdict(node.entry), + "children": [] + }) + for child in node.children: + df_traversal(child, curr_json_list[-1]["children"]) + + for root in tree: + df_traversal(root, root_dicts) + + return root_dicts + + +class nm_profile(profile): + + def __init__(self): + super().__init__( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True, + with_modules=True, + experimental_config=_ExperimentalConfig(verbose=True)) + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.results = NMProfileResults(self.profiler.kineto_results) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py new file mode 100644 index 000000000000..f8ead593d178 --- /dev/null +++ b/vllm/profiler/utils.py @@ -0,0 +1,146 @@ +import dataclasses + +from typing import Callable, Dict, Type, List, Union + +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata + +# +# String / Print Manipulation +# + + +def trim_string_front(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[offset:] + if len(string) > 3: + string = "..." + string[3:] + return string + + +def trim_string_back(string, width): + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." 
+ return string + + +class TablePrinter: + + def __init__(self, row_cls: Type[dataclasses.dataclass], + column_widths: Dict[str, int]): + self.row_cls = row_cls + self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] + self.column_widths = column_widths + assert set(self.column_widths.keys()) == set(self.fieldnames) + + def print_table(self, rows: List[dataclasses.dataclass]): + self._print_header() + self._print_line() + for row in rows: + self._print_row(row) + + def _print_header(self): + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + print(trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n") + + def _print_row(self, row): + assert isinstance(row, self.row_cls) + + for i, f in enumerate(self.fieldnames): + last = (i == len(self.fieldnames) - 1) + col_width = self.column_widths[f] + val = getattr(row, f) + + val_str = "" + if isinstance(val, str): + val_str = trim_string_back(val, col_width).ljust(col_width) + elif type(val) in [float, int]: + val_str = f"{float(val):>.2f}".rjust(col_width) + else: + val_str = f"{val}".rjust(col_width) + print(val_str, end=" | " if not last else "\n") + + def _print_line(self): + total_col_width = 0 + for column_width in self.column_widths.values(): + total_col_width += column_width + print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) + + +def indent_string(string: str, + indent: int, + indent_style: Union[Callable[[int], str], str] = " ") -> str: + if indent: + if isinstance(indent_style, str): + return indent_style * indent + string + else: + return indent_style(indent) + string + else: + return string + + +# +# _ProfilerEvent utils +# + + +def event_has_module(event: _ProfilerEvent) -> bool: + event_type, typed_event = event.typed + if event_type == _EventType.PyCall: + return typed_event.module is not None + return False + + +def event_is_torch_op(event: _ProfilerEvent) -> bool: + return event.tag 
== _EventType.TorchOp + + +def event_arg_repr(arg) -> str: + if arg is None or type(arg) in [float, int, bool, str]: + return f"{arg}" + elif isinstance(arg, list): + return f"[{', '.join([event_arg_repr(x) for x in arg])}]" + elif isinstance(arg, tuple): + return f"({', '.join([event_arg_repr(x) for x in arg])})" + else: + assert isinstance(arg, + _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ', '.join([str(x) for x in arg.sizes]) + return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" + + +def event_torch_op_repr(event: _ProfilerEvent) -> str: + assert event.tag == _EventType.TorchOp + args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + return f"{event.name}({args_str})".replace("aten::", "") + + +def event_module_repr(event: _ProfilerEvent) -> str: + assert event_has_module(event) + module = event.typed[1].module + if module.parameters and len(module.parameters) > 0: + args_str = ', '.join( + [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + return f"{module.cls_name}({args_str})" + else: + return module.cls_name + + +def event_torch_op_stack_trace(curr_event: _ProfilerEvent, + until: Callable[[_ProfilerEvent], bool]) -> str: + trace = "" + curr_event = curr_event.parent + while curr_event and not until(curr_event): + if event_is_torch_op(curr_event): + if len(trace) > 0: + trace += " <- " + trace += event_torch_op_repr(curr_event) + curr_event = curr_event.parent + + return trace diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a..d967bb7b3367 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -900,9 +900,10 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() -class CUDAGraphRunner: +class CUDAGraphRunner(nn.Module): def __init__(self, model: nn.Module): + super().__init__() self.model = model self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {}