W4A8 based on CUTLASS
alexsamardzic committed Oct 15, 2024
1 parent e7b33bc commit 4dbe339
Showing 12 changed files with 704 additions and 5 deletions.
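
For orientation, here is a minimal usage sketch of the API added by this commit, modeled on the new test further down (illustrative only; it assumes the CUTLASS kernel was built during setup and a CUDA device is available):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass

# Plain float16 linear layers on the GPU (shapes are illustrative).
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128, bias=False),
).to(torch.float16).cuda()

# Rewrite the linear layers in place to use the W4A8 CUTLASS kernel
# (int8 dynamically quantized activations, int4 weights).
quantize_(model, int8_dynamic_activation_int4_weight_cutlass())

x = torch.rand(64, 128, dtype=torch.float16, device="cuda")
y = model(x)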
1 change: 1 addition & 0 deletions docs/source/api_ref_quantization.rst
@@ -18,6 +18,7 @@ torchao.quantization
Int4WeightOnlyQuantizer
quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int4_weight_cutlass
int8_dynamic_activation_int8_weight
int4_weight_only
int8_weight_only
7 changes: 7 additions & 0 deletions setup.py
@@ -65,6 +65,12 @@ def get_extensions():
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
import cutlass_library
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
# FIXME: remove this once CUTLASS package updated to include int4/int8 MM
cutlass_include_dir = "/data/quansight/scratch/cutlass/include"

extra_link_args = []
extra_compile_args = {
"cxx": [
@@ -74,6 +80,7 @@
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
"-I" + cutlass_include_dir,
]
}

2 changes: 2 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -6,6 +6,7 @@
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int4_weight_cutlass,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
float8_weight_only,
@@ -25,6 +26,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
int8_dynamic_activation_int4_weight_cutlass(),
int8_dynamic_activation_int8_weight(),
]
if do_int4:
1 change: 1 addition & 0 deletions test/quantization/test_quant_api.py
@@ -39,6 +39,7 @@
Quantizer,
TwoStepQuantizer,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int4_weight_cutlass,
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
51 changes: 51 additions & 0 deletions test/test_s8s4_linear_cutlass.py
@@ -0,0 +1,51 @@
# FIXME: move this test to the appropriate test file!!!

import copy

from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight_cutlass

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

import pytest


class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(128, 256)
self.linear2 = torch.nn.Linear(256, 128, bias=False)

def forward(self, x):
x = self.linear1(x)
x = torch.nn.functional.relu(x)
x = self.linear2(x)
return x


class TestS8S4LinearCUTLASS(TestCase):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_s8s4_linear_cutlass_(self):
# FIXME: remove this!
torch.manual_seed(0)

dtype = torch.float16 # torch.bfloat16

input = torch.rand((64, 128)).to(dtype).cuda()
model = ToyModel().to(dtype).cuda()

output_ref = model(input)

modelq = copy.deepcopy(model)
quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass())
output = modelq(input)

assert torch.allclose(output, output_ref, rtol=1e-1, atol=0)


if __name__ == "__main__":
run_tests()
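
The loose tolerance in the assertion (rtol=1e-1, atol=0) is presumably deliberate: with weights quantized to int4 and activations dynamically quantized to int8, the output can only be expected to approximate the float16 reference, not match it closely.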
5 changes: 4 additions & 1 deletion torchao/_models/llama/generate.py
@@ -208,6 +208,7 @@ def main(
quantize_,
int8_weight_only,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int4_weight_cutlass,
int4_weight_only,
fpx_weight_only,
uintx_weight_only,
@@ -221,6 +222,8 @@
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight())
if "w4a8-cutlass" in quantization:
quantize_(model, int8_dynamic_activation_int4_weight_cutlass())
if "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
@@ -459,7 +462,7 @@ def callback(x):
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
'Which quantization techniques to apply: int8dq, w4a8-cutlass, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
)
)
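
With this change, passing -q w4a8-cutlass to the generation script routes the model's linear layers through the new int8_dynamic_activation_int4_weight_cutlass() path.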