Skip to content

Commit

Permalink
W4A8 based on CUTLASS
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsamardzic committed Sep 12, 2024
1 parent 8236a87 commit 1bacd02
Show file tree
Hide file tree
Showing 7 changed files with 572 additions and 2 deletions.
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def get_extensions():
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
import cutlass_library
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
# FIXME: remove this once CUTLASS package updated to include int4/int8 MM
cutlass_include_dir = "/data/quansight/scratch/cutlass/include"

extra_link_args = []
extra_compile_args = {
"cxx": [
Expand All @@ -74,6 +80,7 @@ def get_extensions():
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
"-I" + cutlass_include_dir,
]
}

Expand Down
46 changes: 46 additions & 0 deletions test/test_s8s4_linear_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# FIXME: move this test to the appropriate test file!!!

import copy

from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

import pytest


class ToyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(128, 32)

def forward(self, x):
x = self.linear(x)
return x


class TestS8S4LinearCUTLASS(TestCase):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_s8s4_linear_cutlass_(self):
# FIXME: remove this!
torch.manual_seed(0)

input = torch.rand((64, 128)).half().cuda()
model = ToyModel().half().cuda()

output_ref = model(input)

modelq = copy.deepcopy(model)
quantize_(modelq, int8_dynamic_activation_int4_weight(group_size=128))
output = modelq(input)

assert torch.allclose(output, output_ref, rtol=1e-1, atol=0)


if __name__ == "__main__":
run_tests()
Loading

0 comments on commit 1bacd02

Please sign in to comment.