5 changes: 4 additions & 1 deletion pytest.ini
@@ -7,4 +7,7 @@ addopts = -rP

log_cli = True
log_cli_level = INFO
log_file = logs/pytest.log
log_file = logs/pytest.log
markers =
    benchmark: mark test as benchmark
    slow: mark test as slow
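With benchmark and slow now registered in pytest.ini, tests can carry these markers without unknown-marker warnings and runs can filter on them with -m. A minimal sketch of how a test might opt in (the test name and workload below are made up for illustration, not part of this change):

import time

import pytest


@pytest.mark.benchmark
@pytest.mark.slow
def test_example_throughput():
    # placeholder workload; a real benchmark would time a bitsandbytes kernel
    start = time.perf_counter()
    total = sum(range(1_000_000))
    assert total > 0 and time.perf_counter() - start < 5.0

Such tests could then be deselected in a quick run with pytest -m "not slow", or selected on their own with pytest -m benchmark.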
4 changes: 4 additions & 0 deletions tests/conftest.py
@@ -5,6 +5,10 @@
def pytest_runtest_call(item):
    try:
        item.runtest()
    except NotImplementedError as nie:
        if "NO_CUBLASLT" in str(nie):
            pytest.skip("CUBLASLT not available")
        raise
    except AssertionError as ae:
        if str(ae) == "Torch not compiled with CUDA enabled":
            pytest.skip("Torch not compiled with CUDA enabled")
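The extended pytest_runtest_call hook turns two hardware-dependent errors into skips instead of failures: a NotImplementedError whose message contains "NO_CUBLASLT", and the AssertionError PyTorch raises on CPU-only builds. A hedged sketch of a test that would now be skipped rather than fail on a machine without CUDA support (the test itself is hypothetical, not part of the diff):

import torch


def test_requires_cuda():
    # On a CPU-only PyTorch build this line raises
    # AssertionError("Torch not compiled with CUDA enabled"),
    # which the hook above converts into pytest.skip(...).
    x = torch.ones(4, device="cuda")
    assert x.sum().item() == 4.0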
51 changes: 51 additions & 0 deletions tests/helpers.py
@@ -0,0 +1,51 @@
from itertools import product
import random
from typing import Any

import torch

test_dims_rng = random.Random(42)


def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
    return [test_dims_rng.randint(min, max) for _ in range(n)]


def format_with_label(label: str, value: Any) -> str:
    if isinstance(value, bool):
        formatted = "T" if value else "F"
    elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
        formatted = "".join("T" if b else "F" for b in value)
    else:
        formatted = str(value)
    return f"{label}={formatted}"


def id_formatter(label: str):
    """
    Return a function that formats the value given to it with the given label.
    """
    return lambda value: format_with_label(label, value)


DTYPE_NAMES = {
    torch.bfloat16: "bf16",
    torch.bool: "bool",
    torch.float16: "fp16",
    torch.float32: "fp32",
    torch.float64: "fp64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.int8: "int8",
}


def describe_dtype(dtype: torch.dtype) -> str:
    return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]


TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(
    product(TRUE_FALSE, repeat=3)
)  # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
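A hedged usage sketch of the new helpers (the test below is illustrative only, not part of the diff): get_test_dims draws dimensions from a seeded RNG so the parametrization is stable across runs, while id_formatter and describe_dtype keep the generated test ids short and filterable with -k.

import pytest
import torch

from tests.helpers import TRUE_FALSE, describe_dtype, get_test_dims, id_formatter


@pytest.mark.parametrize("dim", get_test_dims(16, 64, n=2), ids=id_formatter("dim"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_example(dim, dtype, transpose):
    # generated ids read like "transpose=T-fp16-dim=<n>", so failures stay legible
    x = torch.randn(dim, dim, dtype=dtype)
    y = x.T if transpose else x
    assert y.shape == (dim, dim)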
212 changes: 56 additions & 156 deletions tests/test_autograd.py
@@ -1,50 +1,35 @@
from itertools import product
from typing import Tuple

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(
    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
)
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
from tests.helpers import (
    BOOLEAN_TRIPLES,
    BOOLEAN_TUPLES,
    TRUE_FALSE,
    describe_dtype,
    get_test_dims,
    id_formatter,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

TRANSPOSE_VALS = [(False, True), (False, False)]


@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
    if dim2 > 0:
        dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
    for i in range(25):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
@@ -228,71 +213,17 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            assert (idx == 0).sum().item() < n * 0.02


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
str_funcs = ["matmullt", 'switchback_bnb']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.bfloat16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
        has_bias
    )
)
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
    values,
    ids=names,
)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
    dim1,
    dim2,
@@ -313,7 +244,7 @@ def test_matmullt(
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):
    for i in range(3):

        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
@@ -429,45 +360,25 @@ def test_matmullt(
            torch.testing.assert_close(gradBias1, gradBias2)


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

dim2.append(0)

funcs = [(torch.matmul, bnb.matmul_4bit)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
compress_statistics = [False, True]
has_fp16_weights = [True, False]
has_bias = [True, False]
quant_type = ['fp4', 'nf4']
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"])
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type"))
def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if has_bias == False:
        req_grad = list(req_grad)
        req_grad[2] = False

    for i in range(k):
    for i in range(3):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
@@ -530,32 +441,21 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
            torch.testing.assert_close(gradBias1, gradBias2)


funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)

transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global'])
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    req_grad = list(req_grad)
    req_grad[2] = False

    for i in range(k):
    for i in range(3):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
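The net effect in tests/test_autograd.py is that the module-level product(...) lists and hand-built id strings are gone; stacking one @pytest.mark.parametrize per argument lets pytest itself take the cross product and compose the ids. A small sketch of the equivalence (test names and values are illustrative only):

from itertools import product

import pytest

# old style: build every combination and its id string by hand
values = list(product([16, 32], [True, False]))
names = ["dim_{}_flag_{}".format(*v) for v in values]


@pytest.mark.parametrize("dim, flag", values, ids=names)
def test_old_style(dim, flag):
    assert dim > 0


# new style: one decorator per argument; pytest forms the cross product
@pytest.mark.parametrize("dim", [16, 32], ids=lambda d: f"dim={d}")
@pytest.mark.parametrize("flag", [True, False], ids=lambda f: f"flag={'T' if f else 'F'}")
def test_new_style(dim, flag):
    assert dim > 0

The stacked form also makes one-off additions local: the extra dim2 == 0 case, for example, is expressed directly as [*get_test_dims(32, 96, n=1), 0] instead of rebuilding the whole values/str_values/names triple.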