[Kernel] [Triton] [AMD] Adding Triton implementations awq_dequantize and awq_gemm to support AWQ #7386
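This PR adds Triton kernels for AWQ dequantization and GEMM, gated by a new VLLM_USE_TRITON_AWQ environment variable (see the vllm/envs.py diff below) and enabled automatically for AWQ on ROCm (see the vllm/config.py diff below). A minimal usage sketch of opting in follows; the model name is a placeholder assumption, not something this PR prescribes — any AWQ-quantized checkpoint would do.

# Sketch only: route AWQ through the Triton kernels added by this PR.
# On ROCm, vllm/config.py below enables VLLM_USE_TRITON_AWQ automatically
# when quantization="awq"; elsewhere it can be set explicitly.
import os
os.environ["VLLM_USE_TRITON_AWQ"] = "1"

from vllm import LLM

# Placeholder checkpoint: any AWQ-quantized model should work the same way.
llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")
outputs = llm.generate("Triton AWQ kernels are")
print(outputs[0].outputs[0].text)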

Merged
merged 59 commits on Aug 28, 2024
Changes from all commits
Commits (59)
ff27ffa
Add awq_dequantize_triton
rasmith Jul 26, 2024
f9b6e74
Add awq_dequantize_triton
rasmith Jul 26, 2024
7b49a76
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Jul 31, 2024
e2c3ba5
Merge branch 'vllm-project:main' into ransmith_awq_dequantize_triton
rasmith Jul 31, 2024
ec14fe9
Use any instead of all
rasmith Jul 31, 2024
fd80f7f
ruff checks
rasmith Jul 31, 2024
370c9f0
run isort
rasmith Jul 31, 2024
bdd0ab7
run yapf
rasmith Jul 31, 2024
915e0ae
Format for PR
rasmith Jul 31, 2024
3b3a563
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Aug 9, 2024
150db8c
Merge branch 'vllm-project:main' into ransmith_awq_dequantize_triton
rasmith Aug 9, 2024
00dee49
Merge branch 'main' into ransmith_awq_dequantize_triton
rasmith Aug 9, 2024
a8ef8c2
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Aug 9, 2024
2ebd212
Have working awq_gemm in Triton
rasmith Aug 9, 2024
e3073bc
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 10, 2024
5326dde
Optimizations to awq_gemm
rasmith Aug 10, 2024
fb43aa4
Small cleanup
rasmith Aug 10, 2024
43abe7a
ruff and yapf linting/formatting
rasmith Aug 12, 2024
91c6741
isort/ruff fixing
rasmith Aug 12, 2024
962ea59
add env VLLM_USE_TRITON_AWQ
rasmith Aug 14, 2024
c9df260
Add tests
rasmith Aug 16, 2024
c7b63e8
awq for rocm in config
rasmith Aug 16, 2024
5cf14db
add dimension assertions
rasmith Aug 16, 2024
23cf001
fix typo
rasmith Aug 16, 2024
f94c1b0
yappity yapf
rasmith Aug 16, 2024
5887e77
merge main
rasmith Aug 16, 2024
8594e25
Merge branch 'vllm-project:main' into ransmith_awq_gemm_triton
rasmith Aug 16, 2024
64e5251
Merge main
rasmith Aug 16, 2024
86f2ec6
warning message for AWQ on ROCm and not setting VLLM_USE_TRITON_AWQ
rasmith Aug 19, 2024
d32212a
VLLM_USE_TRITON_AWQ enabled automatically
rasmith Aug 19, 2024
6514622
parameterized unit tests
rasmith Aug 20, 2024
8a1f6f2
cleanup
rasmith Aug 20, 2024
39d44a2
ruff
rasmith Aug 20, 2024
34e06b5
yapf
rasmith Aug 20, 2024
4f3148f
yapf
rasmith Aug 20, 2024
010c80e
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 20, 2024
4895074
test cleanup
rasmith Aug 21, 2024
0e1862c
test cleanup
rasmith Aug 21, 2024
24a6b3b
yapf
rasmith Aug 21, 2024
3d2854c
merge main
rasmith Aug 22, 2024
c3b8102
Adjust threshold
rasmith Aug 22, 2024
a84c7d7
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 23, 2024
c7fbacf
simplify unit test and use assert_close
rasmith Aug 24, 2024
11860d6
clean up test
rasmith Aug 24, 2024
bea93a2
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
0c45b68
use marlin tolerance
rasmith Aug 24, 2024
bbfb4d9
update test
rasmith Aug 24, 2024
13bb612
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
226e7fb
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
c4e3fd1
Merge branch 'vllm-project:main' into ransmith_awq_gemm_triton
rasmith Aug 25, 2024
62612ee
Support more group sizes
rasmith Aug 26, 2024
5d91e78
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 26, 2024
ba434dc
Merge branch 'ransmith_awq_gemm_triton' of github.com:rasmith/vllm in…
rasmith Aug 26, 2024
2db93e0
assert added
rasmith Aug 26, 2024
f07c241
ruff
rasmith Aug 26, 2024
e95dfc4
ruff
rasmith Aug 26, 2024
efbd8a5
isort
rasmith Aug 26, 2024
69573dd
test update
rasmith Aug 26, 2024
d456232
update comment
rasmith Aug 26, 2024
169 changes: 169 additions & 0 deletions tests/kernels/test_awq_triton.py
@@ -0,0 +1,169 @@
"""Tests for the AWQ Triton kernel.

Run `pytest tests/kernels/test_awq_triton.py`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)

device = "cuda"


def reverse_awq_order(t: torch.Tensor):
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor,
                         group_size: int) -> torch.Tensor:

    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None],
                                         shifts[None, None, :]).to(torch.int8)

    iweights = iweights.view(iweights.shape[0], -1)

    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
                                      shifts[None, None, :]).to(torch.int8)
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

    iweights = reverse_awq_order(iweights)

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):

    if group_size == -1:
        group_size = qweight_rows

    qweight_dtype = torch.int32
    scales_rows = qweight_rows // group_size
    scales_cols = qweight_cols * 8
    scales_dtype = torch.float16
    zeros_rows = scales_rows
    zeros_cols = qweight_cols
    zeros_dtype = torch.int32

    torch.manual_seed(0)

    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            dtype=qweight_dtype,
                            device=device)
    scales = torch.rand(scales_rows,
                        scales_cols,
                        dtype=scales_dtype,
                        device=device)
    zeros = torch.randint(0,
                          torch.iinfo(torch.int32).max,
                          (zeros_rows, zeros_cols),
                          dtype=zeros_dtype,
                          device=device)

    iweights_triton = awq_dequantize_triton(qweight, scales, zeros)

    assert (not torch.any(torch.isinf(iweights_triton))
            and not torch.any(torch.isnan(iweights_triton)))

    iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)

    torch.testing.assert_close(iweights_triton, iweights_torch)


# input   - [N, K]
# qweight - [K, M // 8]
# qzeros  - [K // G, M // 8]
# scales  - [K // G, M]
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):

    if group_size == -1:
        group_size = K

    split_k_iters = splitK

    input_rows = N
    input_cols = K
    input_dtype = torch.float32
    qweight_rows = input_cols
    qweight_cols = M // 8
    scales_rows = qweight_rows // group_size
    scales_cols = M
    scales_dtype = torch.float32
    qzeros_rows = scales_rows
    qzeros_cols = qweight_cols

    torch.manual_seed(0)

    input = torch.rand((input_rows, input_cols),
                       dtype=input_dtype,
                       device=device)
    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            device=device)
    qzeros = torch.randint(0,
                           torch.iinfo(torch.int32).max,
                           (qzeros_rows, qzeros_cols),
                           device=device)
    scales = torch.rand((scales_rows, scales_cols),
                        dtype=scales_dtype,
                        device=device)

    output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
                                    split_k_iters)

    assert (not torch.any(torch.isinf(output_triton))
            and not torch.any(torch.isnan(output_triton)))

    dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)

    output_torch = torch.matmul(input, dequantized_weights)

    assert (not torch.any(torch.isinf(output_torch))
            and not torch.any(torch.isnan(output_torch)))

    torch.testing.assert_close(output_triton.cpu(),
                               output_torch.cpu(),
                               atol=1e-1,
                               rtol=1e-1)
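The helpers above hinge on AWQ's 4-bit packing: each int32 in qweight/qzeros holds eight 4-bit values in a permuted nibble order, which is why reverse_awq_order applies the permutation [0, 4, 1, 5, 2, 6, 3, 7] before masking with 0xF. Below is a small standalone sketch (illustrative values only, not part of the PR) that packs one int32 using the inverse permutation [0, 2, 4, 6, 1, 3, 5, 7] derived from that reverse order, then unpacks it the same way awq_dequantize_torch does.

# Standalone sketch of the AWQ int4 packing undone by reverse_awq_order above.
import torch

logical = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6], dtype=torch.int32)
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]     # logical index stored in each nibble (inverse of the reverse order)
reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]  # AWQ_REVERSE_ORDER from the test above

# Pack eight 4-bit values into a single int32.
packed = torch.tensor(0, dtype=torch.int32)
for nibble, src in enumerate(pack_order):
    packed = packed | ((logical[src] & 0xF) << (4 * nibble))

# Unpack the way awq_dequantize_torch does: shift out each nibble, mask, reorder.
shifts = torch.arange(0, 32, 4)
nibbles = (packed >> shifts) & 0xF
unpacked = nibbles[reverse_order].to(torch.int32)
assert torch.equal(unpacked, logical)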
9 changes: 9 additions & 0 deletions vllm/_custom_ops.py
@@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_gemm_triton)
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


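With this dispatch, existing callers of ops.awq_dequantize (and ops.awq_gemm, which follows the same pattern) pick up the Triton kernels whenever VLLM_USE_TRITON_AWQ is set, with no signature change. A hedged sketch of calling the dequantize dispatcher directly: shapes follow the comments in the test file above, and the trailing split_k_iters/thx/thy values are illustrative placeholders that the Triton path ignores.

# Sketch (assumptions noted in comments): exercising ops.awq_dequantize above.
# With VLLM_USE_TRITON_AWQ=1 the call routes to awq_dequantize_triton;
# otherwise it falls through to torch.ops._C.awq_dequantize.
import os
os.environ["VLLM_USE_TRITON_AWQ"] = "1"

import torch
import vllm._custom_ops as ops

# Shapes follow the test file: qweight [K, M // 8] int32, scales [K // G, M] fp16,
# qzeros [K // G, M // 8] int32.
K, M, G = 128, 32, 128
qweight = torch.randint(0, torch.iinfo(torch.int32).max, (K, M // 8),
                        dtype=torch.int32, device="cuda")
scales = torch.rand((K // G, M), dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, torch.iinfo(torch.int32).max, (K // G, M // 8),
                       dtype=torch.int32, device="cuda")

# The last three arguments (split_k_iters, thx, thy) are unused by the Triton path.
weights = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
print(weights.shape)  # expected: torch.Size([128, 32])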
8 changes: 7 additions & 1 deletion vllm/config.py
@@ -264,7 +264,7 @@ def _parse_quant_hf_config(self):

    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
        rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
        optimized_quantization_methods = [
            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
@@ -319,6 +319,12 @@ def _verify_quantization(self) -> None:
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)
        if (self.quantization == "awq" and is_hip()
                and not envs.VLLM_USE_TRITON_AWQ):
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, enabling VLLM_USE_TRITON_AWQ.")
            envs.VLLM_USE_TRITON_AWQ = True

    def _verify_cuda_graph(self) -> None:
        if self.max_seq_len_to_capture is None:
4 changes: 4 additions & 0 deletions vllm/envs.py
@@ -400,6 +400,10 @@ def get_default_config_root():
    "VLLM_TORCH_PROFILER_DIR":
    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

    # If set, vLLM will use Triton implementations of AWQ.
    "VLLM_USE_TRITON_AWQ":
    lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
}

# end-env-vars-definition