diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py index fd0dd7d58..b23f63562 100644 --- a/benchmarking/switchback/make_plot_with_jsonl.py +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -1,11 +1,13 @@ + import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import pandas as pd -cmap = plt.get_cmap("cool") +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': -if __name__ == "__main__": - fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) gs = gridspec.GridSpec(1, 2) dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] @@ -17,28 +19,25 @@ ax = fig.add_subplot(gs[0, 0]) # TODO: change this to what you want. - rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) + rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) df = rdf[rdf.batch_size == batch_size_for_plot1] # first plot the time occupied by different operations for k, marker, ls, color, name in [ - ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), - ( - "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", - "o", - "-", - "C4", - "SwitchBack int8 (sum of parts)", - ), - ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), - ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), - ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), - ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), - ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), - ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), - ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), - ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"), - ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), + + ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), + ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), + ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), + + ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), + ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), + + ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), + ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), + ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'), + ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'), ]: xs = [] ys = [] @@ -48,46 +47,40 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split("+"): + for k_ in k.split('+'): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split("+"): + for k_ in k.split('+'): y_ += df_[k_].values[0] ys.append(y_ * 0.5) - ax.plot( - xs, - ys, - color=color, - label=name, - marker=marker, - markersize=5 if marker == "s" else 5, - linestyle=ls, - linewidth=2 if "+" in k else 1.0, - ) - ax.set_xlabel("dim", fontsize=13) - ax.set_ylabel("time (ms)", fontsize=13) + ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) + + + ax.set_xlabel('dim', fontsize=13) + ax.set_ylabel('time (ms)', fontsize=13) ax.grid() - ax.set_xscale("log") + ax.set_xscale('log') if logscale_plot1: - ax.set_yscale("log") + ax.set_yscale('log') - ax.tick_params(axis="x", labelsize=11) - ax.tick_params(axis="y", labelsize=11) + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) - leg.get_texts()[0].set_fontweight("bold") - leg.get_texts()[1].set_fontweight("bold") + leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight('bold') + leg.get_texts()[1].set_fontweight('bold') plt.subplots_adjust(left=0.1) - ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) + ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) + ax = fig.add_subplot(gs[0, 1]) @@ -95,15 +88,10 @@ for j, batch_size in enumerate(batch_sizes_for_plot2): all_xs, all_ys = [], [] for k, marker, ls, color, name in [ - ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), - ( - "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", - "o", - "-", - "C4", - "SwitchBack int8 (total time)", - ), + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), ]: + xs, ys = [], [] df = rdf[rdf.batch_size == batch_size] for embed_dim in dims_to_consider: @@ -111,11 +99,11 @@ df_ = df_[df_.dim_out == embed_dim * 4] xs.append(embed_dim) y_ = 0 - for k_ in k.split("+"): + for k_ in k.split('+'): y_ += df_[k_].values[0] df_ = df[df.dim_in == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim] - for k_ in k.split("+"): + for k_ in k.split('+'): y_ += df_[k_].values[0] ys.append(y_ * 0.5) all_xs.append(xs) @@ -123,29 +111,25 @@ color = cmap(j * 0.25) real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] - markers = ["^", "v", "P", "o"] - ax.plot( - all_xs[0], - real_ys, - color=color, - label=f"batch * sequence length = {batch_size}", - marker=markers[j], - markersize=5 if marker == "s" else 5, - ) + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) ax.legend() - ax.set_xlabel("dim", fontsize=13) - ax.set_xscale("log") + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') ax.grid() - ax.set_ylabel(r"% speedup", fontsize=13) + ax.set_ylabel(r'% speedup', fontsize=13) - ax.tick_params(axis="x", labelsize=11) - ax.tick_params(axis="y", labelsize=11) + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) ax.set_xticks(dims_to_xtick) ax.set_xticklabels(dims_to_xtick) ax.set_xticks([], minor=True) - ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) + ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + - plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") + plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py index eaba0e9cd..c4f3cd4c6 100644 --- a/benchmarking/switchback/speed_benchmark.py +++ b/benchmarking/switchback/speed_benchmark.py @@ -20,15 +20,15 @@ # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. - def get_time(k, fn, info_dict): + for _ in range(repeat // 2): - fn() + fn() torch.cuda.synchronize() start = time.time() for _ in range(repeat): - fn() + fn() torch.cuda.synchronize() end = time.time() @@ -36,15 +36,16 @@ def get_time(k, fn, info_dict): print(f"time {k}: {ms:.3f} ms") info_dict[k] = ms - -if __name__ == "__main__": +if __name__ == '__main__': torch.manual_seed(0) wm = 4 for dim in [1024, 1280, 1408, 1664, 2048, 4096]: # note "batch_size" is actually "batch_size * embed_dim", which is why it's large - for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]: + for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: + # switch switches dim_in and dim_out for switch in [False, True]: + # hparams repeat = 64 batch_size = batch_size @@ -72,86 +73,35 @@ def get_time(k, fn, info_dict): state_w_rowwise = w.max(dim=1)[0] state_w_global = w.max() - info = { - "repeat": repeat, - "batch_size": batch_size, - "dim_out": dim_out, - "dim_in": dim_in, - "wm": wm, - "switch": switch, - } - - get_time("standard_fwd", lambda: x.matmul(w.t()), info) - get_time("standard_gw", lambda: g.t().matmul(x), info) - get_time("standard_gx", lambda: g.matmul(w), info) - get_time( - "rowwise_fwd", - lambda: int8_matmul_rowwise_dequantize( - x_int8, - w_int8.t(), - state_x_rowwise, - state_w_columnwise, - None, - ), - info, - ) - get_time( - "rowwise_bwd", - lambda: int8_matmul_rowwise_dequantize( - g_int8, - wt_int8.t(), - state_x_rowwise, - state_w_rowwise, - None, - ), - info, - ) - get_time( - "global_fwd", - lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), - info, - ) - get_time( - "global_bwd", - lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), - info, - ) - get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info) - get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info) - get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info) - get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info) - get_time("w_quantize_global", lambda: quantize_global(w), info) - get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info) - - time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"] - time_rowwise = ( - info["x_quantize_rowwise"] - + info["g_quantize_rowwise"] - + info["w_quantize_colwise_transpose"] - + info["w_quantize_rowwise"] - + info["standard_gw"] - + info["rowwise_fwd"] - + info["rowwise_bwd"] - ) - time_global = ( - info["x_quantize_rowwise"] - + info["g_quantize_rowwise"] - + info["w_quantize_global"] - + info["w_quantize_global_transpose"] - + info["standard_gw"] - + info["global_fwd"] - + info["global_bwd"] - ) - - print("TOTAL STANDARD", time_standard) - print("TOTAL ROWWISE", time_rowwise) - print("TOTAL GLOBAL", time_global) - - print("speedup", -100 * (time_global - time_standard) / time_standard) - - info["time_standard"] = time_standard - info["time_rowwise"] = time_rowwise - info["time_global"] = time_global + info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} + + get_time('standard_fwd', lambda : x.matmul(w.t()), info) + get_time('standard_gw', lambda : g.t().matmul(x), info) + get_time('standard_gx', lambda : g.matmul(w), info) + get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) + get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) + get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) + get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) + get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) + get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) + get_time('w_quantize_global', lambda : quantize_global(w), info) + get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) + + time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] + time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] + time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + print('TOTAL STANDARD', time_standard) + print('TOTAL ROWWISE', time_rowwise) + print('TOTAL GLOBAL', time_global) + + print('speedup', -100*(time_global - time_standard)/time_standard) + + info['time_standard'] = time_standard + info['time_rowwise'] = time_rowwise + info['time_global'] = time_global info_json = json.dumps(info) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 78c99355b..3b83a8d6d 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import research, utils +from . import cuda_setup, research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, @@ -12,8 +12,11 @@ matmul_cublas, mm_cublas, ) +from .cextension import COMPILED_WITH_CUDA from .nn import modules -from .optim import adam + +if COMPILED_WITH_CUDA: + from .optim import adam __pdoc__ = { "libbitsandbytes": False, @@ -21,4 +24,6 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.44.0.dev" +__version__ = "0.43.0" + +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index e716b6f3f..61b42e78f 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,4 +1,108 @@ -if __name__ == "__main__": - from bitsandbytes.diagnostics.main import main +import glob +import os +import sys +from warnings import warn + +import torch + +HEADER_WIDTH = 60 + + +def find_dynamic_library(folder, filename): + for ext in ("so", "dll", "dylib"): + yield from glob.glob(os.path.join(folder, "**", filename + ext)) + + +def generate_bug_report_information(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + print('') + + path_sources = [ + ("ANACONDA CUDA PATHS", os.environ.get("CONDA_PREFIX")), + ("/usr/local CUDA PATHS", "/usr/local"), + ("CUDA PATHS", os.environ.get("CUDA_PATH")), + ("WORKING DIRECTORY CUDA PATHS", os.getcwd()), + ] + try: + ld_library_path = os.environ.get("LD_LIBRARY_PATH") + if ld_library_path: + for path in set(ld_library_path.strip().split(os.pathsep)): + path_sources.append((f"LD_LIBRARY_PATH {path} CUDA PATHS", path)) + except Exception as e: + print(f"Could not parse LD_LIBRARY_PATH: {e}") + + for name, path in path_sources: + if path and os.path.isdir(path): + print_header(name) + print(list(find_dynamic_library(path, '*cuda*'))) + print("") + + +def print_header( + txt: str, width: int = HEADER_WIDTH, filler: str = "+" +) -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_debug_info() -> None: + from . import PACKAGE_GITHUB_URL + print( + "\nAbove we output some debug information. Please provide this info when " + f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" + ) + +def main(): + generate_bug_report_information() + + from . import COMPILED_WITH_CUDA + from .cuda_setup.main import get_compute_capabilities + + print_header("OTHER") + print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") + print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n") + + try: + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + + p1 = p.data.sum().item() + + adam = Adam([p]) + + out = a * p + loss = out.sum() + loss.backward() + adam.step() + + p2 = p.data.sum().item() + + assert p1 != p2 + print("SUCCESS!") + print("Installation was successful!") + except ImportError: + print() + warn( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!" + ) + print_debug_info() + except Exception as e: + print(e) + print_debug_info() + sys.exit(1) + + +if __name__ == "__main__": main() diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e9821cd36..6cbb6efd9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -14,18 +14,16 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) - # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py + """ This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. """ - - class GlobalOutlierPooler: _instance = None @@ -85,7 +83,6 @@ def get_inverse_transform_indices( break # if all indices fit in i bytes, stop early return permuted_tile_indices - def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: """ Undo a tiled permutation such as turing or ampere layout @@ -162,12 +159,20 @@ def backward(ctx, grad_output): ) if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) + qA, S2 = F.vectorwise_quant( + A.view(-1, A.shape[2]), dim=0, quant_type=quant_type + ) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, S2.t(), S1, grad_output.dtype, quant_type + ) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -196,7 +201,9 @@ def backward(ctx, grad_output): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( @@ -220,7 +227,7 @@ def supports_igemmlt(device: torch.device) -> bool: if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) - nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series + nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores return True @@ -239,7 +246,6 @@ def get_tile_inds(format, device): with torch.no_grad(): return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) - @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None @@ -504,6 +510,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] else: return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + # 1. Dequantize # 2. MatmulnN output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) @@ -525,7 +532,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad + req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad A, B = ctx.tensors grad_A, grad_B, grad_bias = None, None, None @@ -535,9 +542,8 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) # not supported by PyTorch. TODO: create work-around - # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: - grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None @@ -548,7 +554,7 @@ def matmul( out: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None, + bias=None ): state = state or MatmulLtState() if threshold > 0.0: @@ -556,19 +562,11 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_4bit( - A: torch.Tensor, - B: torch.Tensor, - quant_state: F.QuantState, - out: Optional[torch.Tensor] = None, - bias=None, -): +def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: - warn( - f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", - ) + warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') return MatMul4Bit.apply(A, B, out, bias, quant_state) else: out = F.gemv_4bit(A, B.t(), out, state=quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c8ae7358d..858365f02 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,124 +1,39 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - import ctypes as ct -import logging -import os -from pathlib import Path +from warnings import warn import torch -from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs - -logger = logging.getLogger(__name__) - - -def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: - """ - Get the disk path to the CUDA BNB native library specified by the - given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable. - - The library is not guaranteed to exist at the returned path. - """ - library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" - if not cuda_specs.has_cublaslt: - # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt - library_name += "_nocublaslt" - library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" - - override_value = os.environ.get("BNB_CUDA_VERSION") - if override_value: - library_name_stem, _, library_name_ext = library_name.rpartition(".") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`; - # let's remove any trailing numbers: - library_name_stem = library_name_stem.rstrip("0123456789") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`; - # let's tack the new version number and the original extension back on. - library_name = f"{library_name_stem}{override_value}.{library_name_ext}" - logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" - "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" - "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" - "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLibrary: - binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" - cuda_specs = get_cuda_specs() - if cuda_specs: - cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) - if cuda_binary_path.exists(): - binary_path = cuda_binary_path - else: - logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) - logger.debug(f"Loading bitsandbytes native library from: {binary_path}") - dll = ct.cdll.LoadLibrary(str(binary_path)) - - if hasattr(dll, "get_context"): # only a CUDA-built library exposes this - return CudaBNBNativeLibrary(dll) - - logger.warning( - "The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", - ) - return BNBNativeLibrary(dll) +from bitsandbytes.cuda_setup.main import CUDASetup +setup = CUDASetup.get_instance() +if setup.initialized != True: + setup.run_cuda_setup() +lib = setup.lib try: - lib = get_native_library() -except Exception as e: - lib = None - logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) - if torch.cuda.is_available(): - logger.warning( - """ -CUDA Setup failed despite CUDA being available. Please run the following command to get more information: - -python -m bitsandbytes - -Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them -to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes -and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues -""", - ) + if lib is None and torch.cuda.is_available(): + CUDASetup.get_instance().generate_instructions() + CUDASetup.get_instance().print_log_stack() + raise RuntimeError(''' + CUDA Setup failed despite GPU being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') + _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False + lib.get_context.restype = ct.c_void_p + lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p + COMPILED_WITH_CUDA = True +except AttributeError as ex: + warn("The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") + COMPILED_WITH_CUDA = False + print(str(ex)) + + +# print the setup details after checking for errors so we do not print twice +#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': + #setup.print_log_stack() diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py new file mode 100644 index 000000000..4b2549653 --- /dev/null +++ b/bitsandbytes/cuda_setup/env_vars.py @@ -0,0 +1,53 @@ +import os +from typing import Dict + + +def to_be_ignored(env_var: str, value: str) -> bool: + ignorable = { + "PWD", # PWD: this is how the shell keeps track of the current working dir + "OLDPWD", + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "MAIL", # something related to emails + "SHELL", # binary for currently invoked shell + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "PATH", # this is for finding binaries, not libraries + "LESSOPEN", # related to the `less` command + "LESSCLOSE", + "_", # current Python interpreter + } + return env_var in ignorable + + +def might_contain_a_path(candidate: str) -> bool: + return os.sep in candidate + + +def is_active_conda_env(env_var: str) -> bool: + return "CONDA_PREFIX" == env_var + + +def is_other_conda_env_var(env_var: str) -> bool: + return "CONDA" in env_var + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return is_active_conda_env(env_var) or ( + might_contain_a_path(value) and not + is_other_conda_env_var(env_var) and not + to_be_ignored(env_var, value) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return { + env_var: value + for env_var, value in os.environ.items() + if is_relevant_candidate_env_var(env_var, value) + } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py new file mode 100644 index 000000000..b351f7f03 --- /dev/null +++ b/bitsandbytes/cuda_setup/main.py @@ -0,0 +1,393 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiply) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + +import ctypes as ct +import errno +import os +from pathlib import Path +import platform +from typing import Set, Union +from warnings import warn + +import torch + +from .env_vars import get_potentially_lib_path_containing_env_vars + +DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so") +if platform.system() == "Windows": # Windows + CUDA_RUNTIME_LIBS = ["cudart64_110.dll", "cudart64_12.dll"] +else: # Linux or other + # these are the most common libs names + # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead + # we have libcudart.so.11.0 which causes a lot of errors before + # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt + CUDA_RUNTIME_LIBS = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0", "libcudart.so.12.1", "libcudart.so.12.2"] + + +class CUDASetup: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True + if not self.cuda_available: + self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.') + self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') + self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') + self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') + self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') + self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') + self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)') + return + + if self.cudart_path is None: + self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') + self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') + self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') + self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') + self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') + self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') + self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh') + self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') + self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') + + return + + make_cmd = f'CUDA_VERSION={self.cuda_version_string}' + if len(self.cuda_version_string) < 3: + make_cmd += ' make cuda92' + elif self.cuda_version_string == '110': + make_cmd += ' make cuda110' + elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: + make_cmd += ' make cuda11x' + elif self.cuda_version_string[:2] == '12' and 1 >= int(self.cuda_version_string[2]) >= 0: + make_cmd += ' make cuda12x' + elif self.cuda_version_string == '100': + self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') + self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') + return + + + has_cublaslt = is_cublasLt_compatible(self.cc) + if not has_cublaslt: + make_cmd += '_nomatmul' + + self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') + self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git') + self.add_log_entry('cd bitsandbytes') + self.add_log_entry(make_cmd) + self.add_log_entry('python setup.py install') + + def initialize(self): + if not getattr(self, 'initialized', False): + self.has_printed = False + self.lib = None + self.initialized = False + self.error = False + + def manual_override(self): + if not torch.cuda.is_available(): + return + override_value = os.environ.get('BNB_CUDA_VERSION') + if not override_value: + return + + binary_name_stem, _, binary_name_ext = self.binary_name.rpartition(".") + # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda118`; + # let's remove any trailing numbers: + binary_name_stem = binary_name_stem.rstrip("0123456789") + # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda`; + # let's tack the new version number and the original extension back on. + self.binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}" + + warn( + f'\n\n{"=" * 80}\n' + 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' + 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' + 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' + 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' + 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: + return {Path(ld_path) for ld_path in paths_list_candidate.split(os.pathsep) if ld_path} + + +def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: + existent_directories: Set[Path] = set() + for path in candidate_paths: + try: + if path.exists(): + existent_directories.add(path) + except PermissionError: + # Handle the PermissionError first as it is a subtype of OSError + # https://docs.python.org/3/library/exceptions.html#exception-hierarchy + pass + except OSError as exc: + if exc.errno != errno.ENAMETOOLONG: + raise exc + + non_existent_directories: Set[Path] = candidate_paths - existent_directories + if non_existent_directories: + CUDASetup.get_instance().add_log_entry( + f"The following directories listed in your path were found to be non-existent: {non_existent_directories}", + is_warning=False, + ) + + return existent_directories + + +def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: + paths = set() + for libname in CUDA_RUNTIME_LIBS: + for path in candidate_paths: + try: + if (path / libname).is_file(): + paths.add(path / libname) + except PermissionError: + pass + return paths + + +def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: + """ + Searches a given environmental var for the CUDA runtime library, + i.e. `libcudart.so`. + """ + return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) + + +def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: + return get_cuda_runtime_lib_paths( + resolve_paths_list(paths_list_candidate) + ) + + +def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: + if len(results_paths) > 1: + warning_msg = ( + f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " + "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," + "but this might mismatch with the CUDA version that is needed for bitsandbytes." + "To override this behavior set the BNB_CUDA_VERSION= environmental variable" + "For example, if you want to use the CUDA version 122" + "BNB_CUDA_VERSION=122 python ..." + "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" + "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." + "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.2") + CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) + + +def determine_cuda_runtime_lib_path() -> Union[Path, None]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + cuda_runtime_libs = set() + if "CONDA_PREFIX" in candidate_env_vars: + conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" + + conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) + warn_in_case_of_duplicates(conda_cuda_libs) + + if conda_cuda_libs: + cuda_runtime_libs.update(conda_cuda_libs) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + + if "LD_LIBRARY_PATH" in candidate_env_vars: + lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) + + if lib_ld_cuda_libs: + cuda_runtime_libs.update(lib_ld_cuda_libs) + warn_in_case_of_duplicates(lib_ld_cuda_libs) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + + remaining_candidate_env_vars = { + env_var: value for env_var, value in candidate_env_vars.items() + if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} + } + + cuda_runtime_libs = set() + for env_var, value in remaining_candidate_env_vars.items(): + cuda_runtime_libs.update(find_cuda_lib_in(value)) + + if len(cuda_runtime_libs) == 0: + CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') + cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) + + warn_in_case_of_duplicates(cuda_runtime_libs) + + cuda_setup = CUDASetup.get_instance() + cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}') + + return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None + + +# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION +def get_cuda_version(): + major, minor = map(int, torch.version.cuda.split(".")) + + if major < 11: + CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + + return f'{major}{minor}' + +def get_compute_capabilities(): + ccs = [] + for i in range(torch.cuda.device_count()): + cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i)) + ccs.append(f"{cc_major}.{cc_minor}") + + ccs.sort(key=lambda v: tuple(map(int, str(v).split(".")))) + + return ccs + + +def evaluate_cuda_setup(): + cuda_setup = CUDASetup.get_instance() + if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': + cuda_setup.add_log_entry('') + cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35) + cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), + ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) + cuda_setup.add_log_entry('='*80) + + if not torch.cuda.is_available(): + return f'libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}', None, None, None + + cudart_path = determine_cuda_runtime_lib_path() + cc = get_compute_capabilities()[-1] # we take the highest capability + cuda_version_string = get_cuda_version() + + cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") + cuda_setup.add_log_entry( + "CUDA SETUP: To manually override the PyTorch CUDA version please see:" + "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md" + ) + + + # 7.5 is the minimum CC vor cublaslt + has_cublaslt = is_cublasLt_compatible(cc) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + # we use ls -l instead of nvcc to determine the cuda version + # since most installations will have the libcudart.so installed, but not the compiler + + binary_name = f"libbitsandbytes_cuda{cuda_version_string}" + if not has_cublaslt: + # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt + binary_name += "_nocublaslt" + + binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}" + + return binary_name, cudart_path, cc, cuda_version_string diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index bb6a04892..f0de962e1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -14,17 +14,16 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import COMPILED_WITH_CUDA, lib # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) - name2qmap = {} -if lib and lib.compiled_with_cuda: +if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adam": ( @@ -128,6 +127,7 @@ def prefetch_all(self, to_cpu=False): prefetch_tensor(t, to_cpu) + class CUBLAS_Context: _instance = None @@ -169,7 +169,6 @@ def get_instance(cls): cls._instance.initialize() return cls._instance - dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -177,11 +176,10 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -FIRST_CUDA_DEVICE = torch.device("cuda", index=0) - +FIRST_CUDA_DEVICE = torch.device('cuda', index=0) def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): - num_bytes = dtype2bytes[dtype] * prod(shape) + num_bytes = dtype2bytes[dtype]*prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -190,35 +188,31 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): out.page_deviceid = device.index return out - def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, "Only paged tensors can be prefetched!" + assert A.is_paged, 'Only paged tensors can be prefetched!' if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype] * A.numel() + num_bytes = dtype2bytes[A.dtype]*A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) - def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f"c{func_name}_fp32", None) + func = getattr(lib, f'c{func_name}_fp32', None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f"c{func_name}_uint8", None) + func = getattr(lib, f'c{func_name}_uint8', None) cvalue = ct.c_uint8(value) - if func is None: - raise NotImplementedError(f"Function not implemented: {func_name}") + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') - is_managed = getattr(A, "is_managed", False) + is_managed = getattr(A, 'is_managed', False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: - prefetch_tensor(B) + if B is not None: prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: @@ -228,36 +222,28 @@ def elementwise_func(func_name, A, B, value, prefetch=True): # operation occurred. So we synchronize. torch.cuda.synchronize() - -def fill(A, value, device=None, prefetch=True): - elementwise_func("fill", A, None, value) - - -def arange(A, device=None): - elementwise_func("arange", A, None, 0) - - -def _mul(A, B, device=None): - elementwise_func("_mul", A, B, 0) +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = -1.0 if signed else 0.0 + sign = (-1.0 if signed else 0.0) total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = 2**total_bits if not signed else 2**total_bits - 1 + total_values = (2**total_bits if not signed else 2**total_bits-1) values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel() // 2 # noqa: E741 - return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + l = values.numel()//2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -265,17 +251,18 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): from scipy.stats import norm except ImportError as ie: raise ImportError( - "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.", + "Scipy is required for `create_normal_map`. " + "Install `bitsandbytes` with the `[test]` extra." ) from ie if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -288,37 +275,38 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values - def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e + p == total_bits - has_sign + assert e+p == total_bits-has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): + for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) + values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - # for ev in evalues: - bias = 2 ** (exponent_bits - 1) - for evalue in range(2 ** (exponent_bits)): + #for ev in evalues: + bias = 2**(exponent_bits-1) + for evalue in range(2**(exponent_bits)): for bit_pattern in lst: - value = 1 if evalue != 0 else 0 + value = (1 if evalue != 0 else 0) for i, pval in enumerate(list(bit_pattern)): - value += pval * (2 ** -(i + 1)) + value += pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value * 2**-(bias) + value = value*2**-(bias) else: # normals - value = value * 2 ** -(evalue - bias - 1) + value = value*2**-(evalue-bias-1) values.append(value) if signed: values.append(-value) + assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -332,6 +320,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) return code + def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -356,11 +345,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int( - 2 ** (i + non_sign_bits - max_exponent_bits) + 1 - if signed - else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, - ) + fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -386,9 +371,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) - def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) + q = estimate_quantiles(A, num_quantiles=2**total_bits-1) q = q.tolist() q.append(0) @@ -399,13 +383,11 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q / q.abs().max() + q = q/q.abs().max() return q - def get_special_format_str(): - if not torch.cuda.is_available(): - return "col_turing" + if not torch.cuda.is_available(): return 'col_turing' major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -414,24 +396,20 @@ def get_special_format_str(): return "col_turing" + def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: - continue # NULL pointers are fine - is_paged = getattr(t, "is_paged", False) - on_gpu &= t.device.type == "cuda" or is_paged + if t is None: continue # NULL pointers are fine + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError( - f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", - ) + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') if len(gpu_ids) > 1: - raise TypeError( - f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", - ) + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') return on_gpu @@ -469,13 +447,15 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): if not hasattr(lib, name): print(name) raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" ) else: return getattr(lib, name) -def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): +def get_transform_buffer( + shape, dtype, device, to_order, from_order="row", transpose=False +): # init_func = torch.empty init_func = torch.zeros dims = len(shape) @@ -528,7 +508,9 @@ def nvidia_transform( else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1] + ) else: new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) @@ -552,13 +534,8 @@ def nvidia_transform( return out, new_state -def estimate_quantiles( - A: Tensor, - out: Optional[torch.Tensor] = None, - offset: float = 1 / 512, - num_quantiles=256, -) -> Tensor: - """ +def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: + ''' Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -585,21 +562,14 @@ def estimate_quantiles( ------- torch.Tensor: The 256 quantiles in float32 datatype. - """ - if A.numel() < 256: - raise NotImplementedError( - f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.", - ) - if num_quantiles > 256: - raise NotImplementedError( - f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}", - ) - if num_quantiles < 256 and offset == 1 / (512): + ''' + if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') + if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") + if num_quantiles < 256 and offset == 1/(512): # override default arguments - offset = 1 / (2 * num_quantiles) + offset = 1/(2*num_quantiles) - if out is None: - out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: @@ -611,7 +581,7 @@ def estimate_quantiles( post_call(device) if num_quantiles < 256: - step = round(256 / num_quantiles) + step = round(256/num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -620,35 +590,12 @@ def estimate_quantiles( class QuantState: """container for quantization state components to work with Params4bit and similar classes""" - - valid_quant_types = ("fp4", "nf4") + valid_quant_types = ('fp4', 'nf4') valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = [ - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "quant_state", - "quant_type", - "blocksize", - "dtype", - "shape", - "nested_blocksize", - "nested_dtype", - "nested_offset", - ] - - def __init__( - self, - absmax, - shape=None, - code=None, - blocksize=None, - quant_type=None, - dtype=None, - offset=None, - state2=None, - ): + valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', + 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] + + def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): self.absmax = absmax self.shape = shape self.code = code @@ -667,20 +614,13 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [ - self.absmax, - self.shape, - self.dtype, - self.blocksize, - [self.offset, self.state2], - self.quant_type, - ] + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] else: list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': """ unpacks components of state_dict into QuantState where necessary, convert into strings, torch.dtype, ints, etc. @@ -692,39 +632,37 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState # unpacking tensor with non-tensor components qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and "quant_type" not in qs_dict: + if not len(qs_key) and 'quant_type' not in qs_dict: raise ValueError("Expected packed or unpacked quant_state items, found neither") elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError( - f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", - ) + raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: first_qs_key = qs_key[0] qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if "nested_absmax" in qs_dict: - offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + if 'nested_absmax' in qs_dict: + offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) state2 = cls( - absmax=qs_dict["nested_absmax"].to(device), - blocksize=qs_dict["nested_blocksize"], - code=qs_dict["nested_quant_map"].to(device), - dtype=getattr(torch, qs_dict["nested_dtype"]), + absmax=qs_dict['nested_absmax'].to(device), + blocksize=qs_dict['nested_blocksize'], + code=qs_dict['nested_quant_map'].to(device), + dtype=getattr(torch, qs_dict['nested_dtype']), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict["quant_type"], - absmax=qs_dict["absmax"].to(device), - blocksize=qs_dict["blocksize"], - code=qs_dict["quant_map"].to(device), - dtype=getattr(torch, qs_dict["dtype"]), - shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, + quant_type=qs_dict['quant_type'], + absmax=qs_dict['absmax'].to(device), + blocksize=qs_dict['blocksize'], + code=qs_dict['quant_map'].to(device), + dtype=getattr(torch, qs_dict['dtype']), + shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, offset=offset, state2=state2, ) @@ -736,23 +674,21 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - "quant_type": self.quant_type, - "absmax": self.absmax, - "blocksize": self.blocksize, - "quant_map": self.code, - "dtype": str(self.dtype).strip("torch."), - "shape": tuple(self.shape), + 'quant_type': self.quant_type, + 'absmax': self.absmax, + 'blocksize': self.blocksize, + 'quant_map': self.code, + 'dtype': str(self.dtype).strip('torch.'), + 'shape': tuple(self.shape), } if self.nested: - qs_dict.update( - { - "nested_absmax": self.state2.absmax, - "nested_blocksize": self.state2.blocksize, - "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - "nested_dtype": str(self.state2.dtype).strip("torch."), - "nested_offset": self.offset.item(), - }, - ) + qs_dict.update({ + 'nested_absmax': self.state2.absmax, + 'nested_blocksize': self.state2.blocksize, + 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + 'nested_dtype': str(self.state2.dtype).strip('torch.'), + 'nested_offset': self.offset.item(), + }) if not packed: return qs_dict @@ -775,22 +711,14 @@ def __eq__(self, other): return False return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) - and self.shape == other.shape - and torch.allclose(self.code, other.code, atol=1e-6) - and self.dtype == other.dtype - and self.blocksize == other.blocksize - and self.quant_type == other.quant_type - and ( - self.offset == other.offset - if self.offset is not None and other.offset is not None - else self.offset is other.offset - ) - and ( - self.state2 == other.state2 - if self.state2 is not None and other.state2 is not None - else self.state2 is other.state2 - ) + torch.allclose(self.absmax, other.absmax, atol=1e-6) and + self.shape == other.shape and + torch.allclose(self.code, other.code, atol=1e-6) and + self.dtype == other.dtype and + self.blocksize == other.blocksize and + self.quant_type == other.quant_type and + (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and + (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2) ) @@ -828,6 +756,7 @@ def quantize_blockwise( The quantization state to undo the quantization. """ + if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -842,66 +771,31 @@ def quantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != "cpu": + if A.device.type != 'cpu': assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - cblocksize, - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) + lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) + quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) else: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) @@ -915,7 +809,7 @@ def dequantize_blockwise( code: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False, + nested=False ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -949,76 +843,43 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != "cpu": + if A.device.type != 'cpu': device = pre_call(A.device) code = quant_state.code.to(A.device) if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", - ) + raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(quant_state.absmax), - get_ptr(out), - ct.c_longlong(quant_state.blocksize), - ct.c_longlong(A.numel()), - ) + lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) return out - def get_4bit_type(typename, device=None, blocksize=64): - if device is None: - device = "cuda" + if device is None: device = 'cuda' data = None - if typename == "nf4": - """ Implements the NF4 data type. + if typename == 'nf4': + ''' Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. @@ -1027,26 +888,12 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - """ - data = [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, - ] - elif typename == "fp4": + ''' + data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, + -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, + 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, + 0.7229568362236023, 1.0] + elif typename == 'fp4': # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -1057,35 +904,20 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] - elif typename == "int4": + elif typename == 'int4': data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == "af4": + elif typename == 'af4': # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [ - -1.0, - -0.69441008, - -0.51243739, - -0.3736951, - -0.25607552, - -0.14982478, - -0.04934812, - 0.0, - 0.04273164, - 0.12934483, - 0.21961274, - 0.31675666, - 0.42563882, - 0.55496234, - 0.72424863, - 1.0, - ][::-1] + data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, + -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, + 0.42563882, 0.55496234, 0.72424863, 1.][::-1] else: - raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.") + raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') if data is None: - raise NotImplementedError(f"Typename {typename} not supported") + raise NotImplementedError(f'Typename {typename} not supported') data = Tensor(data) data /= data.abs().max() @@ -1094,26 +926,11 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4( - A: Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_storage=torch.uint8, -): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) +def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) - -def quantize_nf4( - A: Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_storage=torch.uint8, -): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) +def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) def quantize_4bit( @@ -1122,7 +939,7 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type="fp4", + quant_type='fp4', quant_storage=torch.uint8, ) -> Tuple[Tensor, QuantState]: """ @@ -1150,10 +967,10 @@ def quantize_4bit( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if A.device.type != "cuda": - raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') n = A.numel() input_shape = A.shape @@ -1163,9 +980,10 @@ def quantize_4bit( blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -1173,62 +991,20 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) elif A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + if quant_type == 'fp4': + lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - lib.cquantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) + lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -1240,57 +1016,19 @@ def quantize_4bit( absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState( - absmax=qabsmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - offset=offset, - state2=state2, - ) + state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) else: - state = QuantState( - absmax=absmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - ) + state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) return out, state +def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') -def dequantize_fp4( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, -) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") +def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') - -def dequantize_nf4( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, -) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") - - -def dequantize_4bit( - A: Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type="fp4", -) -> Tensor: +def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1318,31 +1056,23 @@ def dequantize_4bit( Dequantized tensor. """ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", - ) - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) + quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) else: absmax = quant_state.absmax + if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1352,71 +1082,27 @@ def dequantize_4bit( device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: - lib.cdequantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - ) + lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = True if A.shape[0] == 1 else False - if is_transposed: - return out.t() - else: - return out + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out def quantize( @@ -1431,8 +1117,7 @@ def quantize( code = code.to(A.device) absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: - absmax = absmax.float() + if absmax.dtype != torch.float32: absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1459,7 +1144,7 @@ def dequantize( def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - """ + ''' Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1478,10 +1163,9 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ------- torch.Tensor: Quantized 8-bit tensor. - """ + ''' prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1489,7 +1173,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: - """ + ''' Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1508,10 +1192,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = ------- torch.Tensor: 32-bit output tensor. - """ + ''' prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) + if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1578,17 +1261,16 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) + optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1608,8 +1290,7 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + ct.c_int32(g.numel())) post_call(prev_device) @@ -1618,18 +1299,18 @@ def optimizer_update_8bit( g: Tensor, p: Tensor, state1: Tensor, - state2: Optional[torch.Tensor], + state2: Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Optional[torch.Tensor], + qmap2: Tensor, max1: Tensor, - max2: Optional[torch.Tensor], + max2: Tensor, new_max1: Tensor, - new_max2: Optional[torch.Tensor], + new_max2: Tensor, weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: Optional[torch.Tensor] = None, @@ -1741,7 +1422,7 @@ def optimizer_update_8bit( ) else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) post_call(prev_device) @@ -1751,20 +1432,21 @@ def optimizer_update_8bit_blockwise( g: Tensor, p: Tensor, state1: Tensor, - state2: Optional[torch.Tensor], + state2: Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Optional[torch.Tensor], + qmap2: Tensor, absmax1: Tensor, - absmax2: Optional[torch.Tensor], + absmax2: Tensor, weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: + optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1772,15 +1454,12 @@ def optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) post_call(prev_device) @@ -1808,8 +1487,9 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) - -def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): +def percentile_clipping( + grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 +): """Applies percentile clipping grad: torch.Tensor @@ -1851,7 +1531,9 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): +def histogram_scatter_add_2d( + histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor +): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 @@ -1868,12 +1550,12 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) - def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): - torch.cuda.init() + if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}") + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) sA = A.shape sB = B.shape @@ -1914,7 +1596,12 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]: + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): correct = True else: if len(sA) == 2 and len(sB) == 2: @@ -1947,29 +1634,26 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 if not correct: raise ValueError( - f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.", + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." ) return sout - def gemv_4bit( A: Tensor, B: Tensor, out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None, + state=None ): prev_device = pre_call(A.device) - # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") + raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') if A.numel() != A.shape[-1]: - raise ValueError( - 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', - ) + raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') Bshape = state.shape bout = Bshape[0] @@ -1989,7 +1673,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1] + 1) // 2 + ldb = (A.shape[-1]+1)//2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -2000,61 +1684,21 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - ) + lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') post_call(prev_device) return out - def igemm( A: Tensor, B: Tensor, @@ -2120,7 +1764,7 @@ def igemm( assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}", + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" ) transposed_A = True @@ -2139,20 +1783,8 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ) + lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) return out @@ -2164,7 +1796,9 @@ def batched_igemm( transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}") + raise ValueError( + f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" + ) sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) @@ -2231,24 +1865,9 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm( - ptr, - ct.c_bool(transposed_B), - ct.c_bool(transposed_A), - ct.c_int32(m), - ct.c_int32(n), - ct.c_int32(k), - get_ptr(B), - get_ptr(A), - get_ptr(out), - ct.c_int32(lda), - ct.c_int32(ldb), - ct.c_int32(ldc), - ct.c_long(strideA), - ct.c_long(strideB), - ct.c_long(strideC), - ct.c_uint32(num_batch), - ) + lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), + ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out @@ -2257,14 +1876,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" + assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -2273,9 +1892,13 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -2317,33 +1940,49 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == "col_turing": + if formatB == 'col_turing': if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") if has_error: - print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") - raise Exception("cublasLt ran into an error!") + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') + raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None +): assert A.dtype == torch.int32 - if bias is not None: - assert bias.dtype == torch.float16 + if bias is not None: assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -2351,11 +1990,19 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -2369,23 +2016,15 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16( - ptrA, - ptrRowStats, - ptrColStats, - ptrOut, - ptrNewRowStats, - ptrNewColStats, - ptrBias, - numRows, - numCols, - ) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) post_call(prev_device) return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): assert A.dtype == torch.float16 device = A.device @@ -2398,12 +2037,18 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -2477,10 +2122,14 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) + return CSRSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values + ) def coo2csc(cooA): @@ -2489,10 +2138,14 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -2502,7 +2155,9 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2515,7 +2170,9 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2533,7 +2190,9 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -2592,16 +2251,12 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: - new_state = (state[0], to_order) # (shape, order) + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2612,7 +2267,7 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == "col32": + if to_order == 'col32': if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2633,7 +2288,7 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') post_call(prev_device) @@ -2642,7 +2297,9 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2669,28 +2326,16 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo( - ptr, - ptrRowidx, - ptrColidx, - ptrValues, - cnnz, - crowsA, - ccolsA, - ccolsB, - cldb, - ptrB, - cldc, - ptrC, - ct.c_bool(transposed_B), - ) + lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) return out def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: - out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) + out = torch.zeros( + (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype + ) nnz = cooA.nnz prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz @@ -2708,7 +2353,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." + assert ( + max_count[0] <= 32 + ), f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -2796,7 +2443,9 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( + x, dim=dim, keepdim=True + ) dyna[dyna == 0] = 1 qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) @@ -2904,7 +2553,9 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2914,7 +2565,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == "col_turing": + if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2922,7 +2573,6 @@ def extract_outliers(A, SA, idx): return out - def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ec14e5940..f7b96205b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -44,7 +44,6 @@ class StableEmbedding(torch.nn.Embedding): reset_parameters(): Reset embedding parameters using Xavier uniform initialization. forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. """ - def __init__( self, num_embeddings: int, @@ -90,7 +89,9 @@ def __init__( dtype, ) self.norm = torch.nn.LayerNorm(embedding_dim, device=device) - GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -129,7 +130,6 @@ class Embedding(torch.nn.Embedding): """ Embedding class to store and retrieve word embeddings from their indices. """ - def __init__( self, num_embeddings: int, @@ -170,9 +170,11 @@ def __init__( scale_grad_by_freq, sparse, _weight, - device=device, + device=device + ) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} ) - GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -206,16 +208,16 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): def __new__( - cls, - data: Optional[torch.Tensor] = None, - requires_grad=False, # quantized weights should be frozen by default - quant_state: Optional[QuantState] = None, - blocksize: int = 64, - compress_statistics: bool = True, - quant_type: str = "fp4", - quant_storage: torch.dtype = torch.uint8, - module: Optional["Linear4bit"] = None, - bnb_quantized: bool = False, + cls, + data: Optional[torch.Tensor] = None, + requires_grad=False, # quantized weights should be frozen by default + quant_state: Optional[QuantState] = None, + blocksize: int = 64, + compress_statistics: bool = True, + quant_type: str = 'fp4', + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -248,7 +250,7 @@ def __setstate__(self, state): self.bnb_quantized = state["bnb_quantized"] self.module = state["module"] - def __deepcopy__(self, memo): + def __deepcopy__(self,memo): new_instance = type(self).__new__(type(self)) state = self.__getstate__() new_instance.__setstate__(state) @@ -263,14 +265,7 @@ def __copy__(self): return new_instance @classmethod - def from_prequantized( - cls, - data: torch.Tensor, - quantized_stats: Dict[str, Any], - requires_grad: bool = False, - device="cuda", - **kwargs, - ) -> "Params4bit": + def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) self.requires_grad = requires_grad self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) @@ -297,39 +292,33 @@ def _quantize(self, device): return self def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): - return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) @overload - def to( - self: T, - device: Optional[Union[int, device]] = ..., - dtype: Optional[Union[dtype, str]] = ..., - non_blocking: bool = ..., - ) -> T: ... + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if (device is not None and device.type == "cuda" and not self.bnb_quantized): return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) - new_param = Params4bit( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, - quant_state=self.quant_state, - blocksize=self.blocksize, - compress_statistics=self.compress_statistics, - quant_type=self.quant_type, - ) + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) return new_param @@ -366,18 +355,7 @@ class Linear4bit(nn.Linear): quantized_model = quantized_model.to(0) # Quantization happens here ``` """ - - def __init__( - self, - input_features, - output_features, - bias=True, - compute_dtype=None, - compress_statistics=True, - quant_type="fp4", - quant_storage=torch.uint8, - device=None, - ): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): """ Initialize Linear4bit class. @@ -390,14 +368,7 @@ def __init__( Whether the linear class uses the bias term as well. """ super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit( - self.weight.data, - requires_grad=False, - compress_statistics=compress_statistics, - quant_type=quant_type, - quant_storage=quant_storage, - module=self, - ) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False @@ -414,15 +385,11 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn( - "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.", - ) - warnings.filterwarnings("ignore", message=".*inference.") + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') + warnings.filterwarnings('ignore', message='.*inference.') if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn( - "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.", - ) - warnings.filterwarnings("ignore", message=".*inference or training") + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') + warnings.filterwarnings('ignore', message='.*inference or training') def _save_to_state_dict(self, destination, prefix, keep_vars): """ @@ -440,8 +407,8 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, "quant_state", None) is None: - if getattr(self, "quant_state", None) is not None: + if getattr(self.weight, 'quant_state', None) is None: + if getattr(self, 'quant_state', None) is not None: # the quant state got lost when the parameter got converted. This happens for example for fsdp # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 @@ -449,9 +416,7 @@ def forward(self, x: torch.Tensor): self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight.quant_state = self.quant_state else: - print( - "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", - ) + print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -472,17 +437,7 @@ class LinearFP4(Linear4bit): """ Implements the FP4 data type. """ - - def __init__( - self, - input_features, - output_features, - bias=True, - compute_dtype=None, - compress_statistics=True, - quant_storage=torch.uint8, - device=None, - ): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): """ Args: input_features (`str`): @@ -492,40 +447,21 @@ def __init__( bias (`bool`, defaults to `True`): Whether the linear class uses the bias term as well. """ - super().__init__( - input_features, - output_features, - bias, - compute_dtype, - compress_statistics, - "fp4", - quant_storage, - device, - ) + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) class LinearNF4(Linear4bit): - """Implements the NF4 data type. - - Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that - is normalized into the range [-1, 1]. + ''' Implements the NF4 data type. - For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) + Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that + is normalized into the range [-1, 1]. - Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in - the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - """ + For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) - def __init__( - self, - input_features, - output_features, - bias=True, - compute_dtype=None, - compress_statistics=True, - quant_storage=torch.uint8, - device=None, - ): + Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in + the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. + ''' + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): """ Args: input_features (`str`): @@ -535,16 +471,7 @@ def __init__( bias (`bool`, defaults to `True`): Whether the linear class uses the bias term as well. """ - super().__init__( - input_features, - output_features, - bias, - compute_dtype, - compress_statistics, - "nf4", - quant_storage, - device, - ) + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) class Int8Params(torch.nn.Parameter): @@ -587,22 +514,33 @@ def to( device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ..., - ) -> T: ... + ) -> T: + ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) - if device is not None and device.type == "cuda" and self.data.device.type == "cpu": + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): return self.cuda(device) else: new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -655,25 +593,15 @@ class Linear8bitLt(nn.Linear): int8_model = int8_model.to(0) # Quantization happens here ``` """ - - def __init__( - self, - input_features: int, - output_features: int, - bias=True, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, - device=None, - ): + def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None, device=None): """ Initialize Linear8bitLt class. Args: - input_features (`int`): + input_features (`str`): Number of input features of the linear layer. - output_features (`int`): + output_features (`str`): Number of output features of the linear layer. bias (`bool`, defaults to `True`): Whether the linear class uses the bias term as well. @@ -719,36 +647,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[format_name] = self.state.formatB - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) unexpected_copy = list(unexpected_keys) for key in unexpected_copy: - input_name = key[len(prefix) :] + input_name = key[len(prefix):] if input_name == "SCB": if self.weight.SCB is None: # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError( - "Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. Please call module.cuda() before module.load_state_dict()", - ) + raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()") input_param = state_dict[key] self.weight.SCB.copy_(input_param) @@ -791,18 +702,18 @@ def __init__(self, input_features, output_features, bias=True, device=None): self.is_quantized = False def forward_with_outliers(self, x, outlier_idx): - raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function") + raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') def quantize_weight(self, w, outlier_idx): - raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function") + raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') def forward(self, x): if self.outlier_dim is None: tracer = OutlierTracer.get_instance() if not tracer.is_initialized(): - print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer") + print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') outlier_idx = tracer.get_outliers(self.weight) - # print(outlier_idx, tracer.get_hvalue(self.weight)) + #print(outlier_idx, tracer.get_hvalue(self.weight)) self.outlier_dim = outlier_idx if not self.is_quantized: @@ -810,7 +721,6 @@ def forward(self, x): self.weight.data.copy_(w) self.is_quantized = True - class SwitchBackLinearBnb(nn.Linear): def __init__( self, @@ -821,9 +731,11 @@ def __init__( memory_efficient_backward=False, threshold=0.0, index=None, - device=None, + device=None ): - super().__init__(input_features, output_features, bias, device) + super().__init__( + input_features, output_features, bias, device + ) self.state = bnb.MatmulLtState() self.index = index @@ -833,7 +745,9 @@ def __init__( if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index aa8494942..9c7738c59 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -22,6 +22,7 @@ class _switchback_global(torch.autograd.Function): + @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -36,7 +37,9 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) + return int8_matmul_mixed_dequantize( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -53,8 +56,7 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_global_transpose(W) grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], - -1, + *G_3D.size()[:-1], -1 ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -64,8 +66,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias - class _switchback_vectorrize(torch.autograd.Function): + @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -79,7 +81,9 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call kernel which expects rowwise quantized X and W - return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) + return int8_matmul_rowwise_dequantize( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -95,8 +99,7 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_columnwise_and_transpose(W) grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], - -1, + *G_3D.size()[:-1], -1 ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -106,8 +109,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias - class _switchback_global_mem_efficient(torch.autograd.Function): + @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -124,7 +127,9 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1) + return int8_matmul_mixed_dequantize( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D_sz[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -146,34 +151,35 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) del G W_int8 = W_int8.t().contiguous() - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1) + grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D_sz[:-1], -1 + ) return grad_X, grad_W, grad_bias - class SwitchBackLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - vector_wise_quantization: bool = False, - mem_efficient: bool = False, - ): + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient : bool = False, + ): super().__init__(in_features, out_features, bias, device, dtype) if not is_triton_available(): - raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear. - Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""") + raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') # By default, we use the global quantization. self.vector_wise_quantization = vector_wise_quantization if self.vector_wise_quantization: self._fn = _switchback_vectorrize if mem_efficient: - print("mem efficient is not supported for vector-wise quantization.") + print('mem efficient is not supported for vector-wise quantization.') exit(1) else: if mem_efficient: @@ -189,7 +195,7 @@ def prepare_for_eval(self): # if hasattr(m, "prepare_for_eval"): # m.prepare_for_eval() # model.apply(cond_prepare) - print("=> preparing for eval.") + print('=> preparing for eval.') if self.vector_wise_quantization: W_int8, state_W = quantize_rowwise(self.weight) else: @@ -213,22 +219,18 @@ def forward(self, x): X_int8, state_X = quantize_rowwise(X) if self.vector_wise_quantization: - return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( - *x.size()[:-1], - -1, - ) + return int8_matmul_rowwise_dequantize( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) else: - return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( - *x.size()[:-1], - -1, - ) - + return int8_matmul_mixed_dequantize( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) - # This is just the standard linear function. class StandardLinearFunction(torch.autograd.Function): @staticmethod @@ -258,7 +260,7 @@ def backward(ctx, grad_output_3D): return grad_input, grad_weight, grad_bias - class StandardLinear(nn.Linear): + def forward(self, x): return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index b4c95793a..6796b8e0e 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from bitsandbytes.cextension import COMPILED_WITH_CUDA + from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit from .adamw import ( diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7459dece1..c2ea87ab0 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -38,8 +38,8 @@ def __init__( The epsilon value prevents division by zero in the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -50,7 +50,9 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -105,8 +107,8 @@ def __init__( The epsilon value prevents division by zero in the optimizer. optim_bits (`int`, defaults to 8): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -117,7 +119,9 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -173,8 +177,8 @@ def __init__( The epsilon value prevents division by zero in the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -185,7 +189,9 @@ def __init__( if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 740db26ac..e534c8b8f 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,21 +14,8 @@ class Adam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ Base Adam optimizer. @@ -47,8 +34,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -58,38 +45,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 8-bit Adam optimizer. @@ -108,8 +68,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -119,38 +79,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 32-bit Adam optimizer. @@ -169,8 +102,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -180,38 +113,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedAdam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ Paged Adam optimizer. @@ -230,8 +136,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -241,38 +147,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 8-bit paged Adam optimizer. @@ -291,8 +170,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -302,38 +181,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ Paged 32-bit Adam optimizer. @@ -352,8 +204,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -363,21 +215,7 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. @@ -455,7 +293,9 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) amsgrad = group.get("amsgrad", False) assert not amsgrad @@ -472,9 +312,15 @@ def step(self, closure=None): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device) - state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device) - state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -482,19 +328,25 @@ def step(self, closure=None): state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) if amsgrad: - state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( + p_data_fp32 + ) state["step"] += 1 beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] if amsgrad: @@ -507,7 +359,10 @@ def step(self, closure=None): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -546,7 +401,9 @@ def step(self, closure=None): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f"Invalid analysis value: {self.analysis}!") + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -558,7 +415,9 @@ def step(self, closure=None): F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) p_data_fp32 += -step_size * update_fp32 @@ -566,10 +425,18 @@ def step(self, closure=None): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl") - pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl") - pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl") + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 4bf3f6436..1e2dc04de 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -6,21 +6,8 @@ class AdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ Base AdamW optimizer. @@ -39,8 +26,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -50,38 +37,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 8-bit AdamW optimizer. @@ -100,8 +60,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -111,38 +71,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 32-bit AdamW optimizer. @@ -161,8 +94,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -172,37 +105,12 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedAdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged AdamW optimizer. @@ -221,8 +129,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -232,37 +140,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged 8-bit AdamW optimizer. @@ -281,8 +163,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -292,37 +174,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedAdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged 32-bit AdamW optimizer. @@ -341,8 +197,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -352,17 +208,4 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 8d29cbbfe..ec829ee85 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -45,8 +45,8 @@ def __init__( Whether to use the AdamW variant. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -109,8 +109,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. adam_w_mode (`bool`, defaults to `True`): Whether to use the AdamW variant. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -173,8 +173,8 @@ def __init__( Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. adam_w_mode (`bool`, defaults to `True`): Whether to use the AdamW variant. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 90c3686fe..7449b805b 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -41,8 +41,8 @@ def __init__( Whether to use Nesterov momentum. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -51,7 +51,9 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError("LARS without momentum is not supported!") + raise NotImplementedError( + "LARS without momentum is not supported!" + ) super().__init__( "lars", params, @@ -98,8 +100,8 @@ def __init__( The weight decay value for the optimizer. nesterov (`bool`, defaults to `False`): Whether to use Nesterov momentum. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -108,7 +110,9 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError("LARS without momentum is not supported!") + raise NotImplementedError( + "LARS without momentum is not supported!" + ) super().__init__( "lars", params, @@ -155,8 +159,8 @@ def __init__( The weight decay value for the optimizer. nesterov (`bool`, defaults to `False`): Whether to use Nesterov momentum. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -165,7 +169,9 @@ def __init__( The maximum gradient norm. """ if momentum == 0: - raise NotImplementedError("LARS without momentum is not supported!") + raise NotImplementedError( + "LARS without momentum is not supported!" + ) super().__init__( "lars", params, @@ -198,7 +204,9 @@ def __init__( if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") if weight_decay < 0.0: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) defaults = dict( lr=lr, @@ -209,7 +217,9 @@ def __init__( max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) super().__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2e4163694..ce185f863 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -6,19 +6,7 @@ class Lion(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ Base Lion optimizer. @@ -33,8 +21,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -44,35 +32,10 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion8bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 8-bit Lion optimizer. @@ -85,8 +48,8 @@ def __init__( The beta values are the decay rates of the first and second-order moment of the optimizer. weight_decay (`float`, defaults to 0): The weight decay value for the optimizer. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -96,35 +59,10 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion32bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): """ 32-bit Lion optimizer. @@ -137,8 +75,8 @@ def __init__( The beta values are the decay rates of the first and second-order moment of the optimizer. weight_decay (`float`, defaults to 0): The weight decay value for the optimizer. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -148,35 +86,11 @@ def __init__( is_paged (`bool`, defaults to `False`): Whether the optimizer is a paged optimizer or not. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class PagedLion(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged Lion optimizer. @@ -191,8 +105,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -200,34 +114,10 @@ def __init__( block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedLion8bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged 8-bit Lion optimizer. @@ -242,8 +132,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -251,34 +141,10 @@ def __init__( block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) - + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class PagedLion32bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): """ Paged 32-bit Lion optimizer. @@ -293,8 +159,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -302,17 +168,4 @@ def __init__( block_wise (`bool`, defaults to `True`): Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ - super().__init__( - "lion", - params, - lr, - betas, - 0.0, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=True, - ) + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index f1e60e5e7..a97afb026 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -21,7 +21,6 @@ class GlobalOptimManager: """ A global optimizer manager for enabling custom optimizer configs. """ - _instance = None def __init__(self): @@ -49,9 +48,13 @@ def register_parameters(self, params): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] - def override_config(self, parameters, key=None, value=None, key_value_dict=None): + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): """ Override initial optimizer config with specific hyperparameters. @@ -129,18 +132,18 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { - "qmap1", - "qmap2", - "max1", - "max2", - "new_max1", - "new_max2", - "state1", - "state2", - "gnorm_vec", - "absmax1", - "absmax2", - "unorm_vec", + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", } if optim_bits == 8: @@ -167,12 +170,16 @@ def load_state_dict(self, state_dict): saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of parameter groups") + raise ValueError( + "loaded state dict has a different number of " + "parameter groups" + ) param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( - "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" ) # Update the state @@ -221,7 +228,9 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -231,7 +240,7 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - is_paged = getattr(v, "is_paged", False) + is_paged = getattr(v, 'is_paged', False) if not is_paged: self.state[p][k] = v.to(p.device) @@ -239,7 +248,9 @@ def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + assert isinstance(pmodule, torch.Tensor) or isinstance( + pmodule, torch.Parameter + ) found = False for gindex, group in enumerate(self.param_groups): if found: @@ -251,7 +262,9 @@ def check_overrides(self): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -274,7 +287,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True - # if self.is_paged: self.page_mng.prefetch_all() + #if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -291,6 +304,7 @@ def step(self, closure=None): # to sync to make sure all tensors are in the right state torch.cuda.synchronize() + return loss def get_config(self, gindex, pindex, group): @@ -314,7 +328,9 @@ def init_state(self, group, p, gindex, pindex): raise NotImplementedError("init_state method needs to be overridden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError("The update_step method needs to be overridden") + raise NotImplementedError( + "The update_step method needs to be overridden" + ) def get_state_buffer(self, p, dtype=torch.float32): if not self.is_paged or p.numel() < 1e5: @@ -329,12 +345,12 @@ def get_state_buffer(self, p, dtype=torch.float32): def prefetch_state(self, p): if self.is_paged: state = self.state[p] - s1 = state["state1"] - is_paged = getattr(s1, "is_paged", False) + s1 = state['state1'] + is_paged = getattr(s1, 'is_paged', False) if is_paged: - F.prefetch_tensor(state["state1"]) - if "state2" in state: - F.prefetch_tensor(state["state2"]) + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) class Optimizer2State(Optimizer8bit): @@ -353,7 +369,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False, + is_paged=False ): """ Base 2-state update optimizer class. @@ -373,8 +389,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -398,9 +414,13 @@ def __init__( betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -429,7 +449,9 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -437,15 +459,21 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -458,13 +486,25 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + state["absmax2"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -484,10 +524,7 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, - state["gnorm_vec"], - step, - config["percentile_clipping"], + grad, state["gnorm_vec"], step, config["percentile_clipping"] ) else: gnorm_scale = 1.0 @@ -531,7 +568,9 @@ def update_step(self, group, p, gindex, pindex): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, max_unorm=config["max_unorm"], ) @@ -576,7 +615,7 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False, + is_paged=False ): """ Base 1-state update optimizer class. @@ -596,8 +635,8 @@ def __init__( The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -617,9 +656,13 @@ def __init__( raise ValueError(f"Invalid epsilon value: {eps}") for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -648,7 +691,9 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -656,13 +701,17 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -672,10 +721,16 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -695,10 +750,7 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, - state["gnorm_vec"], - step, - config["percentile_clipping"], + grad, state["gnorm_vec"], step, config["percentile_clipping"] ) else: gnorm_scale = 1.0 @@ -714,7 +766,7 @@ def update_step(self, group, p, gindex, pindex): step, config["lr"], None, - config["betas"][1], + config['betas'][1], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 25611309b..ac371a66f 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -41,8 +41,8 @@ def __init__( Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -51,7 +51,9 @@ def __init__( Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: - raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -104,8 +106,8 @@ def __init__( Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -114,7 +116,9 @@ def __init__( Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. """ if alpha == 0: - raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -167,8 +171,8 @@ def __init__( Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -178,7 +182,9 @@ def __init__( """ if alpha == 0: - raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index ec18f036c..0f0b12e4b 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -38,8 +38,8 @@ def __init__( Whether to use Nesterov momentum. optim_bits (`int`, defaults to 32): The number of bits of the optimizer state. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -94,8 +94,8 @@ def __init__( The weight decay value for the optimizer. nesterov (`bool`, defaults to `False`): Whether to use Nesterov momentum. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): @@ -150,8 +150,8 @@ def __init__( The weight decay value for the optimizer. nesterov (`bool`, defaults to `False`): Whether to use Nesterov momentum. - args (`object`, defaults to `None`): - An object with additional arguments. + args (`dict`, defaults to `None`): + A dictionary with additional arguments. min_8bit_size (`int`, defaults to 4096): The minimum number of elements of the parameter tensors for 8-bit optimization. percentile_clipping (`int`, defaults to 100): diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index b194b8777..7d869e39a 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -195,9 +195,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -216,7 +216,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A.to(torch.float16), threshold=state.threshold + ) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -232,14 +234,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - # print('A shape', A.shape) + #print('A shape', A.shape) if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B if state.has_fp16_weights: - # print('B shape', B.shape) + #print('B shape', B.shape) has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: @@ -270,7 +272,12 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .to(A.dtype) + ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -313,13 +320,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B00 ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - clone_func = torch.clone if len(output_shape) == 3 else lambda x: x + + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors @@ -334,7 +342,9 @@ def backward(ctx, grad_output): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + grad_output = grad_output.reshape( + -1, grad_output.shape[-1] + ).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) @@ -347,24 +357,25 @@ def backward(ctx, grad_output): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) # print('back B shape', state.CxBt.shape) # print('back grad shape', C32grad.shape) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception("State must contain either CBt or CB matrix for backward") + raise Exception('State must contain either CBt or CB matrix for backward') return grad_A, grad_B, None, grad_bias, None - def get_block_sizes(input_matrix, weight_matrix): input_features = input_matrix.shape[-1] - output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] + output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) array = [4096, 2048, 1024, 512, 256, 128, 64, 0] bsz, bsz2 = 1024, 1024 for i, k in enumerate(array): @@ -388,8 +399,7 @@ def matmul_fp8_global( bsz: int = -1, bsz2: int = -1, ): - if bsz == -1 or bsz2 == -1: - bsz, bsz2 = get_block_sizes(A, B) + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) @@ -402,8 +412,7 @@ def matmul_fp8_mixed( bsz: int = -1, bsz2: int = -1, ): - if bsz == -1 or bsz2 == -1: - bsz, bsz2 = get_block_sizes(A, B) + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) @@ -413,7 +422,7 @@ def switchback_bnb( out: Optional[torch.Tensor] = None, state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None, + bias=None ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 57c0f3358..7fca34d23 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -28,20 +28,12 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.research.matmul_fp8_mixed( - x, - self.weight.t(), - fw_code=self.fw_code, - bw_code=self.bw_code, - bsz=self.bsz, - bsz2=self.bsz2, - ) + out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) if self.bias is not None: out += self.bias return out - class LinearFP8Global(nn.Linear): def __init__(self, input_features, output_features, bias=True): super().__init__(input_features, output_features, bias) @@ -62,14 +54,7 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8_global( - x, - self.weight.t(), - fw_code=self.fw_code, - bw_code=self.bw_code, - bsz=self.bsz, - bsz2=self.bsz2, - ) + out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) if self.bias is not None: out += self.bias diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index 26eab84f2..3d7529852 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -5,10 +5,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): - return None + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None else: + import triton import triton.language as tl @@ -16,21 +15,21 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] ) @triton.jit def _dequantize_rowwise( @@ -52,6 +51,7 @@ def _dequantize_rowwise( output = max_val * x * inv_127 tl.store(output_ptr + offsets, output, mask=row_mask) + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) @@ -60,5 +60,5 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index 583371d91..dc3047d7e 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -3,14 +3,14 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): - return None + def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None else: + import triton import triton.language as tl from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and global quantized weight # It's purpose is fused matmul then dequantize @@ -27,83 +27,58 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, - num_stages=num_stages, - num_warps=num_warps, - ), - ) + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ), - ) + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs + @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), *get_configs_io_bound(), ], - key=["M", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, - ) - @triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 }, ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) @triton.jit - def _int8_matmul_mixed_dequantize( - A, - B, - C, - bias, - state_x_ptr, - state_w_ptr, - M, - N, - K, - divfactor: tl.constexpr, - has_bias: tl.constexpr, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - ): + def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -140,13 +115,13 @@ def _int8_matmul_mixed_dequantize( b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = w_factor * (x_factor * (acc * divfactor)) + acc = (w_factor * (x_factor * (acc * divfactor))) acc = acc.to(C.dtype.element_ty) # conditionally add bias @@ -162,9 +137,10 @@ def _int8_matmul_mixed_dequantize( else: tl.atomic_add(C, acc, mask=mask) + def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): device = a.device - divfactor = 1.0 / (127.0 * 127.0) + divfactor = 1. / (127. * 127.) has_bias = 0 if bias is None else 1 # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -178,28 +154,12 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) - _int8_matmul_mixed_dequantize[grid]( - a, - b, - c, - bias, - state_x, - state_w, - M, - N, - K, - divfactor, - has_bias, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - GROUP_M=8, - ACC_TYPE=ACC_TYPE, - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index e3d192ded..4881e1468 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -3,9 +3,7 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - return None + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None else: import triton import triton.language as tl @@ -19,6 +17,7 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): def init_to_zero(name): return lambda nargs: nargs[name].zero_() + def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -27,83 +26,58 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, - num_stages=num_stages, - num_warps=num_warps, - ), - ) + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append( - triton.Config( - {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ), - ) + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs + @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), *get_configs_io_bound(), ], - key=["M", "N", "K"], - prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, - ) - @triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 }, ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) @triton.jit - def _int8_matmul_rowwise_dequantize( - A, - B, - C, - bias, - state_x_ptr, - state_w_ptr, - M, - N, - K, - divfactor, - has_bias: tl.constexpr, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - ): + def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -140,13 +114,13 @@ def _int8_matmul_rowwise_dequantize( b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = w_factor * (x_factor * (acc * divfactor)) + acc = (w_factor * (x_factor * (acc * divfactor))) acc = acc.to(C.dtype.element_ty) if has_bias: @@ -161,8 +135,9 @@ def _int8_matmul_rowwise_dequantize( else: tl.atomic_add(C, acc, mask=mask) + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1.0 / (127.0 * 127.0) + divfactor = 1. / (127. * 127.) has_bias = 0 if bias is None else 1 @@ -179,28 +154,12 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) - _int8_matmul_rowwise_dequantize[grid]( - a, - b, - c, - bias, - state_x, - state_w, - M, - N, - K, - divfactor, - has_bias, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - GROUP_M=8, - ACC_TYPE=ACC_TYPE, - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index b8eeffd0c..e7961cf53 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -5,10 +5,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def quantize_columnwise_and_transpose(x: torch.Tensor): - return None + def quantize_columnwise_and_transpose(x: torch.Tensor): return None else: + import triton import triton.language as tl @@ -16,23 +15,23 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] ) @triton.jit def _quantize_columnwise_and_transpose( @@ -40,8 +39,7 @@ def _quantize_columnwise_and_transpose( output_ptr, output_maxs, n_elements, - M: tl.constexpr, - N: tl.constexpr, + M : tl.constexpr, N : tl.constexpr, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr, ): @@ -49,12 +47,12 @@ def _quantize_columnwise_and_transpose( block_start = pid p2_arange = tl.arange(0, P2) p2_arange_mask = p2_arange < M - arange = p2_arange * N + arange = p2_arange * N offsets = block_start + arange x = tl.load(x_ptr + offsets, mask=p2_arange_mask) abs_x = tl.abs(x) max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) + output = tl.libdevice.llrint(127. * (x / max_val)) new_start = pid * M new_offsets = new_start + p2_arange @@ -70,6 +68,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) return output, output_maxs diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index f35bdd304..5cf194744 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,25 +1,24 @@ + import torch from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def quantize_global_transpose(input): - return None - - def quantize_global(x: torch.Tensor): - return None + def quantize_global_transpose(input): return None + def quantize_global(x: torch.Tensor): return None else: + import triton import triton.language as tl # global quantize @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), - triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), - ], - key=["n_elements"], + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] ) @triton.jit def _quantize_global( @@ -35,43 +34,35 @@ def _quantize_global( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1.0 / absmax - output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) _quantize_global[grid](x, absmax_inv, output, n_elements) return output, absmax + # global quantize and transpose @triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), - # ... - ], - key=["M", "N"], + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... + ], + key=['M', 'N'] ) @triton.jit - def _quantize_global_transpose( - A, - absmax_inv_ptr, - B, - stride_am, - stride_an, - stride_bn, - stride_bm, - M, - N, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - GROUP_M: tl.constexpr, - ): + def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N @@ -95,30 +86,20 @@ def _quantize_global_transpose( B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] - output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) + output = tl.libdevice.llrint(127. * (a * absmax_inv)) tl.store(B, output, mask=mask) def quantize_global_transpose(input): absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1.0 / absmax + absmax_inv = 1./ absmax M, N = input.shape - out = torch.empty(N, M, device="cuda", dtype=torch.int8) + out = torch.empty(N, M, device='cuda', dtype=torch.int8) assert out.size(0) == N and out.size(1) == M assert input.stride(0) == 1 or input.stride(1) == 1 assert out.stride(0) == 1 or out.stride(1) == 1 - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) - _quantize_global_transpose[grid]( - input, - absmax_inv, - out, - input.stride(0), - input.stride(1), - out.stride(0), - out.stride(1), - M, - N, - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) return out, absmax diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index f92ace02c..078f4aa2d 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -5,10 +5,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - - def quantize_rowwise(x: torch.Tensor): - return None + def quantize_rowwise(x: torch.Tensor): return None else: + import triton import triton.language as tl @@ -16,21 +15,21 @@ def quantize_rowwise(x: torch.Tensor): # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=["n_elements"], + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] ) @triton.jit def _quantize_rowwise( @@ -50,7 +49,7 @@ def _quantize_rowwise( abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) + output = tl.libdevice.llrint(127. * (x / max_val)) tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_maxs + pid, max_val) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0229e59e2..0582f7fc0 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -30,7 +30,7 @@ def outlier_hook(module, input): # (1) zscore test of std of hidden dimension outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) # (2) magnitude > 6 test - dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1))) + dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) outlier_idx2 = torch.where(dims > 0)[0] outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() tracer.hvalue2outlier_idx[hvalue] = outlier_idx @@ -59,14 +59,14 @@ def initialize(self, model): self.hooks.append(m.register_forward_pre_hook(outlier_hook)) def is_initialized(self): - return getattr(self, "initialized", False) + return getattr(self, 'initialized', False) def get_hvalue(self, weight): return weight.data.storage().data_ptr() def get_outliers(self, weight): if not self.is_initialized(): - print("Outlier tracer is not initialized...") + print('Outlier tracer is not initialized...') return None hvalue = self.get_hvalue(weight) if hvalue in self.hvalue2outlier_idx: @@ -80,7 +80,6 @@ def get_instance(cls): cls._instance = cls.__new__(cls) return cls._instance - def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): if rdm: return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() @@ -88,13 +87,13 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) m = weight.mean(reduction_dim) mm = m.mean() mstd = m.std() - zm = (m - mm) / mstd + zm = (m-mm)/mstd std = weight.std(reduction_dim) stdm = std.mean() stdstd = std.std() - zstd = (std - stdm) / stdstd + zstd = (std-stdm)/stdstd if topk is not None: val, idx = torch.topk(std.abs(), k=topk, dim=0) @@ -106,7 +105,10 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False) def execute_and_return(command_string: str) -> Tuple[str, str]: def _decode(subprocess_err_out_tuple): - return tuple(to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple) + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) def execute_and_return_decoded_std_streams(command_string): return _decode( @@ -114,13 +116,14 @@ def execute_and_return_decoded_std_streams(command_string): shlex.split(command_string), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - ).communicate(), + ).communicate() ) std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err + def replace_linear( model, linear_replacement, @@ -140,7 +143,7 @@ def replace_linear( List of modules names not to convert. Defaults to `lm_head`. copy_weights (`bool`): Copy the weights from the old linear module to the new one - post_processing_function (`str`): + post_processing_fun_name (`str`): A function name of the replacement linear class that is called after processing. """ @@ -160,9 +163,8 @@ def replace_linear( model._modules[name].bias = old_module.bias if post_processing_function is not None: - func = getattr(module, post_processing_function, None) - if func is not None: - func(module) + func = getattr(module, post_processing_function, None) + if func is not None: func(module) return model @@ -177,7 +179,7 @@ def pack_dict_to_tensor(source_dict): A torch tensor containing the packed data. """ json_str = json.dumps(source_dict) - json_bytes = json_str.encode("utf-8") + json_bytes = json_str.encode('utf-8') tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8) return tensor_data @@ -194,7 +196,7 @@ def unpack_tensor_to_dict(tensor_data): A Python dictionary containing the unpacked data. """ json_bytes = bytes(tensor_data.cpu().numpy()) - json_str = json_bytes.decode("utf-8") + json_str = json_bytes.decode('utf-8') unpacked_dict = json.loads(json_str) return unpacked_dict diff --git a/check_bnb_install.py b/check_bnb_install.py index 7a9dc93fc..5a7f74f89 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -2,14 +2,14 @@ import bitsandbytes as bnb -p = torch.nn.Parameter(torch.rand(10, 10).cuda()) -a = torch.rand(10, 10).cuda() +p = torch.nn.Parameter(torch.rand(10,10).cuda()) +a = torch.rand(10,10).cuda() p1 = p.data.sum().item() adam = bnb.optim.Adam([p]) -out = a * p +out = a*p loss = out.sum() loss.backward() adam.step() @@ -17,5 +17,5 @@ p2 = p.data.sum().item() assert p1 != p2 -print("SUCCESS!") -print("Installation was successful!") +print('SUCCESS!') +print('Installation was successful!') diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2184cce8c..87c4242de 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -22,12 +22,12 @@ title: FAQs - title: Explanation sections: - - local: explanations/optimizers - title: 8-bit optimizers - - local: explanations/resources + - local: resources title: Papers, resources & how to cite - title: API reference sections: + - local: reference/quantization + title: Quantization - title: Optimizers sections: - local: reference/optim/optim_overview diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 5943e7d1d..71b3d67bd 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -1,13 +1,19 @@ -# bitsandbytes +# `bitsandbytes` -bitsandbytes enables accessible large language models via k-bit quantization for PyTorch. bitsandbytes provides three main features for dramatically reducing memory consumption for inference and training: +The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 + 4-bit quantization functions. -* 8-bit optimizers uses block-wise quantization to maintain 32-bit performance at a small fraction of the memory cost. -* LLM.Int() or 8-bit quantization enables large language model inference with only half the required memory and without any performance degradation. This method is based on vector-wise quantization to quantize most features to 8-bits and separately treating outliers with 16-bit matrix multiplication. -* QLoRA or 4-bit quantization enables large language model training with several memory-saving techniques that don't compromise performance. This method quantizes a model to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allow training. +The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8bit optimizers through `bitsandbytes.optim` module. + +There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is on its way as well. + +## API documentation + +- [Quantization](quantization) +- [Integrations](integrations) +- [Optimizers](optimizers) # License -bitsandbytes is MIT licensed. +The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license. We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 49d8b4ebd..a63a6a93e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -21,7 +21,7 @@ To install from PyPI. pip install bitsandbytes ``` -## Compile from source +## Alternative: Compiling from source To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu: diff --git a/docs/source/integrations.mdx b/docs/source/integrations.mdx index 4badece49..48b4d6060 100644 --- a/docs/source/integrations.mdx +++ b/docs/source/integrations.mdx @@ -1,89 +1,31 @@ -# Integrations +# Transformers -bitsandbytes is widely integrated with many of the libraries in the Hugging Face and wider PyTorch ecosystem. This guide provides a brief overview of the integrations and how to use bitsandbytes with them. For more details, you should refer to the linked documentation for each library. +With Transformers it's very easy to load any model in 4 or 8-bit, quantizing them on the fly with `bitsandbytes` primitives. -## Transformers +Please review the [`bitsandbytes` section in the Transformers docs](https://huggingface.co/docs/transformers/main/en/quantization#bitsandbytes). -> [!TIP] -> Learn more in the bitsandbytes Transformers integration [guide](https://huggingface.co/docs/transformers/quantization#bitsandbytes). - -With Transformers, it's very easy to load any model in 4 or 8-bit and quantize them on the fly. To configure the quantization parameters, specify them in the [`~transformers.BitsAndBytesConfig`] class. - -For example, to load and quantize a model to 4-bits and use the bfloat16 data type for compute: +Details about the BitsAndBytesConfig can be found [here](https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/quantization#transformers.BitsAndBytesConfig). > [!WARNING] -> bfloat16 is the optimal compute data type if your hardware supports it. The default is float32 for backward compatibility and numerical stability, but it can often lead to numerical instabilities. bfloat16 provides the best of both worlds, numerical stability equivalent to float32, but combined with the memory footprint and significant computation speedup of a 16-bit data type. Make sure to check if your hardware supports bfloat16 and if it does, configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]! +> **Beware: bf16 is the optimal compute data type!** +> +> If your hardware supports it, `bf16` is the optimal compute dtype. The default is `float32` for backward compatibility and numerical stability. `float16` often leads to numerical instabilities, but `bfloat16` provides the benefits of both worlds: numerical stability equivalent to float32, but combined with the memory footprint and significant computation speedup of a 16-bit data type. Therefore, be sure to check if your hardware supports `bf16` and configure it using the `bnb_4bit_compute_dtype` parameter in BitsAndBytesConfig: ```py -from transformers import AutoModelForCausalLM, BitsAndBytesConfig +import torch +from transformers import BitsAndBytesConfig quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) -model_4bit = AutoModelForCausalLM.from_pretrained( - "bigscience/bloom-1b7", - device_map=device_map, - quantization_config=quantization_config, -) -``` - -### 8-bit optimizers - -You can use any of the 8-bit or paged optimizers with Transformers by passing them to the [`~transformers.Trainer`] class on initialization. All bitsandbytes optimizers are supported by passing the correct string in the [`~transformers.TrainingArguments`] `optim` parameter. For example, to load a [`~bitsandbytes.optim.PagedAdamW32bit`] optimizer: - -```py -from transformers import TrainingArguments, Trainer - -training_args = TrainingArguments( - ..., - optim="paged_adamw_32bit", -) -trainer = Trainer(model, training_args, ...) -trainer.train() -``` - -## PEFT - -> [!TIP] -> Learn more in the bitsandbytes PEFT integration [guide](https://huggingface.co/docs/peft/developer_guides/quantization#quantization). - -PEFT builds on the bitsandbytes Transformers integration, and extends it for training with a few more steps. Let's prepare the 4-bit model from the section above for training. - -Call the [`~peft.prepare_model_for_kbit_training`] method to prepare the model for training. This only works for Transformers models! - -```py -from peft import prepare_model_for_kbit_training - -model_4bit = prepare_model_for_kbit_training(model_4bit) ``` -Setup a [`~peft.LoraConfig`] to use QLoRA: - -```py -from peft import LoraConfig - -config = LoraConfig( - r=16, - lora_alpha=8, - target_modules="all-linear", - lora_dropout=0.05 - bias="none", - task_type="CAUSAL_LM" -) -``` +# PEFT +With `PEFT`, you can use QLoRA out of the box with `LoraConfig` and a 4-bit base model. -Now call the [`~peft.get_peft_model`] function on your model and config to create a trainable [`PeftModel`]. - -```py -from peft import get_peft_model - -model = get_peft_model(model_4bit, config) -``` +Please review the [bitsandbytes section in the PEFT docs](https://huggingface.co/docs/peft/developer_guides/quantization#quantize-a-model). -## Accelerate +# Accelerate -> [!TIP] -> Learn more in the bitsandbytes Accelerate integration [guide](https://huggingface.co/docs/accelerate/usage_guides/quantization). - -bitsandbytes is also easily usable from Accelerate and you can quantize any PyTorch model by passing a [`~accelerate.utils.BnbQuantizationConfig`] with your desired settings, and then calling the [`~accelerate.utils.load_and_quantize_model`] function to quantize it. +Bitsandbytes is also easily usable from within Accelerate, where you can quantize any PyTorch model simply by passing a quantization config; e.g: ```py from accelerate import init_empty_weights @@ -113,25 +55,37 @@ quantized_model = load_and_quantize_model( ) ``` -## PyTorch Lightning and Lightning Fabric +For further details, e.g. model saving, cpu-offloading andfine-tuning, please review the [`bitsandbytes` section in the Accelerate docs](https://huggingface.co/docs/accelerate/en/usage_guides/quantization). + + + +# PyTorch Lightning and Lightning Fabric + +Bitsandbytes is available from within both +- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale; +- and [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), a fast and lightweight way to scale PyTorch models without boilerplate). + +Please review the [bitsandbytes section in the PyTorch Lightning docs](https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#quantization-via-bitsandbytes). + + +# Lit-GPT -bitsandbytes is available from: +Bitsandbytes is integrated into [Lit-GPT](https://github.com/Lightning-AI/lit-gpt), a hackable implementation of state-of-the-art open-source large language models, based on Lightning Fabric, where it can be used for quantization during training, finetuning, and inference. -- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), a deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. -- [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), a fast and lightweight way to scale PyTorch models without boilerplate. +Please review the [bitsandbytes section in the Lit-GPT quantization docs](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md). -Learn more in the bitsandbytes PyTorch Lightning integration [guide](https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#quantization-via-bitsandbytes). -## Lit-GPT +# Trainer for the optimizers -bitsandbytes is integrated with [Lit-GPT](https://github.com/Lightning-AI/lit-gpt), a hackable implementation of state-of-the-art open-source large language models. Lit-GPT is based on Lightning Fabric, and it can be used for quantization during training, finetuning, and inference. +You can use any of the 8-bit and/or paged optimizers by simple passing them to the `transformers.Trainer` class on initialization.All bnb optimizers are supported by passing the correct string in `TrainingArguments`'s `optim` attribute - e.g. (`paged_adamw_32bit`). -Learn more in the bitsandbytes Lit-GPT integration [guide](https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md). +See the [official API docs for reference](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer). -## Blog posts +Here we point out to relevant doc sections in transformers / peft / Trainer + very briefly explain how these are integrated: +e.g. for transformers state that you can load any model in 8-bit / 4-bit precision, for PEFT, you can use QLoRA out of the box with `LoraConfig` + 4-bit base model, for Trainer: all bnb optimizers are supported by passing the correct string in `TrainingArguments`'s `optim` attribute - e.g. (`paged_adamw_32bit`): -To learn in more detail about some of bitsandbytes integrations, take a look at the following blog posts: +# Blog posts -- [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) -- [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration) +- [Making LLMs even more accessible with `bitsandbytes`, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) +- [A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and `bitsandbytes`](https://huggingface.co/blog/hf-bitsandbytes-integration) diff --git a/docs/source/optimizers.mdx b/docs/source/optimizers.mdx index 7d04f82b1..734cb2211 100644 --- a/docs/source/optimizers.mdx +++ b/docs/source/optimizers.mdx @@ -1,14 +1,29 @@ -# 8-bit optimizers +# Introduction: 8-bit optimizers -With 8-bit optimizers, large models can be finetuned with 75% less GPU memory without losing any accuracy compared to training with standard 32-bit optimizers. The reduced memory requirements means 8-bit optimizers are 4x faster than a standard optimizer, and no hyperparameter tuning is required. +With 8-bit optimizers, larger models can be finetuned with the same GPU memory compared to standard 32-bit optimizer training. 8-bit optimizers are a drop-in replacement for regular optimizers, with the following properties: -This guide will show you how to use 8-bit optimizers. +- Faster (e.g. 4x faster than regular Adam) +- 75% less memory, same performance +- No hyperparameter tuning needed -> [!WARNING] -> 8-bit optimizers reduce memory usage and accelerate optimization on a wide range of tasks. However, since 8-bit optimizers only reduce memory proportional to the number of parameters, models that use large amounts of activation memory, such as convolutional networks, don't really benefit from 8-bit optimizers. 8-bit optimizers are most beneficial for training or finetuning models with many parameters on highly memory-constrained GPUs. +8-bit optimizers are mostly useful to finetune large models that did not fit into memory before. They also make it easier to pretrain larger models and have great synergy with sharded data parallelism. 8-bit Adam, for example, is already used across multiple teams in Facebook. This optimizer saves a ton of memory at no accuracy hit. -8-bit optimizers are a drop-in replacement for regular optimizers which means they also accept the same arguments as a regular optimizer. For NLP models, it is recommended to use the [`~nn.StableEmbedding`] class to improve stability and results. +Generally, our 8-bit optimizers have three components: +1. **block-wise quantization** isolates outliers and distributes the error more equally over all bits, +2. **dynamic quantization** quantizes both small and large values with high precision, +3. a **stable embedding layer** improves stability during optimization for models with word embeddings. +With these components, performing an optimizer update with 8-bit states is straightforward and for GPUs, this makes 8-bit optimizers way faster than regular 32-bit optimizers. [Further details below](#research-background) + +We feature 8-bit `Adagrad`, `Adam`, `AdamW`, `LAMB`, `LARS`, `Lion`, `RMSprop` and `SGD` (momentum). + +## Caveats + +8-bit optimizers reduce the memory footprint and accelerate optimization on a wide range of tasks. However, since 8-bit optimizers reduce only the memory footprint proportional to the number of parameters, **models that use large amounts of activation memory, such as convolutional networks, have few benefits from using 8-bit optimizers**. Thus, 8-bit optimizers are most beneficial for training or finetuning models with many parameters on highly memory-constrained GPUs. + +## Usage + +It only requires a two-line code change to get started. ```diff import bitsandbytes as bnb @@ -20,29 +35,112 @@ import bitsandbytes as bnb + bnb.nn.StableEmbedding(...) ``` -By default, all parameter tensors with less than 4096 elements are kept at 32-bits even if you initialize those parameters with 8-bit optimizers. This is done because small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). +The arguments passed are the same as standard Adam. For NLP models we recommend to also use the StableEmbedding layers which improves results and helps with stable 8-bit optimization. -You can change this value with the `min_8bit_size` parameter. For example, if you want to optimize parameters to 8-bits only if the minimum size is 16384 values (it is recommended to use multiples of 4096): +Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so: ```py -import bitsandbytes as bnb - +# For parameter tensors with less than 16384 values are optimized in 32-bit +# it is recommended to use multiplies of 4096: adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) ``` -Other parameters you can configure include the learning rate (`lr`), the decay rates (`betas`), the number of bits of the optimizer state (`optim_bits`), and percentile clipping (`percentile_clipping`) which can increase stability. For example, to initialize a 32-bit [`~bitsandbytes.optim.Adam`] optimizer with 5th percentile clipping: +Some more examples of how you can replace your old optimizer with the 8-bit optimizer: -```py +```diff import bitsandbytes as bnb -adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) +- adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer ++ adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer + +# use 32-bit Adam with 5th percentile clipping ++ adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) +- adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer ``` -## Optimize unstable parameters +## Overview of supported 8-bit optimizers + +Currently, `bitsandbytes` supports the following optimizers: + +- `Adagrad`, `Adagrad8bit`, `Adagrad32bit` +- `Adam`, `Adam8bit`, `Adam32bit`, `PagedAdam`, `PagedAdam8bit`, `PagedAdam32bit` +- `AdamW`, `AdamW8bit`, `AdamW32bit`, `PagedAdamW`, `PagedAdamW8bit`, `PagedAdamW32bit` +- `LAMB`, `LAMB8bit`, `LAMB32bit` +- `LARS`, `LARS8bit`, `LARS32bit`, `PytorchLARS` +- `Lion`, `Lion8bit`, `Lion32bit`, `PagedLion`, `PagedLion8bit`, `PagedLion32bit` +- `RMSprop`, `RMSprop8bit`, `RMSprop32bit` +- `SGD`, `SGD8bit`, `SGD32bit` + +Additionally, for cases in which you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`, [as explained in greater detail below](#optim_manager). + +Find the API docs [here](#optim_api_docs) (still under construction). + +## Overview of expected gains + +
+ +
+ +See here an overview of the biggest models that can be trained based on optimizer usage: + +
+ +
+ +### Research Background + +Stateful optimizers maintain gradient statistics over time, e.g. the exponentially smoothed sum (SGD with momentum) or squared sum (Adam) of past gradient values. This state can be used to accelerate optimization compared to plain stochastic gradient descent but uses memory that might otherwise be allocated to model parameters, thereby limiting the maximum size of models trained in practice. `bitsandbytes` optimizers use 8-bit statistics, while maintaining the performance levels of using 32-bit optimizer states. + +To overcome the resulting computational, quantization and stability challenges, 8-bit optimizers have three components: + +1. **Block-wise quantization** divides input tensors into smaller blocks that are independently quantized, therein isolating outliers and distributing the error more equally over all bits. Each block is processed in parallel across cores, yielding faster optimization and high precision quantization. +2. **Dynamic quantization**, which quantizes both small and large values with high precision and +3. a **stable embedding layer** improves stability during optimization for models with word embeddings. + +With these components, performing an optimizer update with 8-bit states is straightforward. We dequantize the 8-bit optimizer states to 32-bit, perform the update and then quantize the states back to 8-bit for storage. + +We do this 8-bit to 32-bit conversion element-by-element in registers, which means no slow copies to GPU memory or additional temporary memory are needed to perform quantization and dequantization. For GPUs, this makes 8-bit optimizers much faster than regular 32-bit optimizers. + +For more details, please refer to the paper [8-bit Optimizers via Block-wise Quantization](https://arxiv.org/abs/2110.02861). + +## Stable Embedding Layer -To optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, use the [`~bitsandbytes.optim.GlobalOptimManager`] class to override the specific hyperparameters for a particular layer. You'll need to: +The Stable Embedding Layer enhances the standard word embedding layer for improved training stability in NLP tasks. It addresses the challenge of non-uniform input distributions and mitigates extreme gradient variations, ensuring smoother training processes. -1. Register the parameters while they're on the CPU. +#### Features: + +- **Initialization**: Utilizes Xavier uniform initialization to maintain consistent variance, reducing the likelihood of large gradients. +- **Normalization**: Incorporates layer normalization before adding positional embeddings, aiding in output stability. +- **Optimizer States**: Employs 32-bit optimizer states exclusively for this layer to enhance stability, while the rest of the model may use standard 16-bit precision. + +#### Benefits: + +- Designed to support more aggressive quantization strategies without compromising training stability. +- Helps in achieving stable training outcomes, particularly important for models dealing with diverse and complex language data. + +## Paged optimizers + +Paged optimizers are build on top of the [unified memory](https://developer.nvidia.com/blog/unified-memory-cuda-beginners/) feature of CUDA. This feature is not supported by PyTorch and we added it to `bitsandbytes`. + +It works like regular CPU paging, which means that it only becomes active _if one runs out of GPU memory_. Only then will the memory be transferred, page-by-page, from GPU to CPU. The memory is mapped, meaning that pages are preallocated on the CPU, but they are not updated automatically. They are only updated if the memory is accessed, or a swapping operation is launched. + +The unified memory feature is less efficient than regular asynchronous memory transfers. This means, one usually will not be able to get full PCIe memory bandwidth utilization. If one does a manual prefetch, transfer speeds can be high but still about half or worse than the full PCIe memory bandwidth (tested on 16x lanes PCIe 3.0). + +This all means performance depends highly on the particular use-case. If one evicts, say, 1 GB of memory per forward-backward-optimizer loop: One can expect about 50% of the PCIe bandwidth as time in the best case. So 1 GB for PCIe 3.0 with 16x lanes, which runs at 16 GB/s, is `1/(16*0.5) = 1/8 = 125ms` overhead per optimizer step. Other overhead can be estimated for the particular use-case given a PCIe interface, lanes, and the memory that is evicted in each iteration. + +Compared to CPU offloading, this has the advantage that there is zero overhead if all the memory fits into the device and only some overhead if some of memory needs to be evicted. For offloading, one would usually offload fixed parts of the model and need to off and onload all this memory with each iteration through the model (sometimes twice for both forward and backward pass). + +[Find more details in this discussion](https://github.com/TimDettmers/bitsandbytes/issues/962). + + +## `GlobalOptimManager`: How to override config hyperparameters for particular weights/parameters[[optim_manager]] + +If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: + +1. Register the parameter while they are still on the CPU. +2. Override the config with the new desired hyperparameters (anytime, anywhere). + +For global overrides in many different places in your code you can do: ```py import torch @@ -51,32 +149,23 @@ import bitsandbytes as bnb mng = bnb.optim.GlobalOptimManager.get_instance() model = MyModel() -mng.register_parameters(model.parameters()) -``` - -2. Override the config with the new desired hyperparameters. For example, let's override the `model.fc1.weight` layer to use 32-bit Adam. +mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU -> [!TIP] -> Check the optimizer API documentation for more information about other hyperparameters you can override. - -```py model = model.cuda() # use 8-bit optimizer states for all parameters adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) -# override the parameter model.fc1.weight now uses 32-bit Adam -mng.override_config(model.fc1.weight, "optim_bits", 32) -``` +# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam +mng.override_config(model.fc1.weight, 'optim_bits', 32) -You can also override multiple layers at once by passing them as a list and the new hyperparameters as a dictionary. For example, let's override the `model.special.weight` and `model.also_special.weight` layers to use sparse optimization and a lower learning and decay rate. - -```py +# 2b. override: the two special layers use +# sparse optimization + different learning rate + different Adam betas mng.override_config([model.special.weight, model.also_special.weight], key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) ``` +Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`. -For a specific layer, we recommend overriding locally in each module. Pass the module, the parameter, and its attribute name to the [`~bitsandbytes.optim.GlobalOptimManager`]: - +For overrides for particular layers, we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: ```py class MyModule(torch.nn.Module): def __init__(d_in, d_out): @@ -89,6 +178,13 @@ class MyModule(torch.nn.Module): ``` -## Next steps +## API Docs[[optim_api_docs]] + +... under construction ... + +Here we'll provide further auto-generated API docs soon. Please feel free to contribute doc-strings for the respective optimizers, as `bitsandbytes` is a community effort. + +### StableEmbedding[[stable-emb-api]] -For more conceptual details and explanation about 8-bit optimizers, take a look at the [8-bit optimizers](./explanations/optimizers) guide. +[[autodoc]] bitsandbytes.nn.StableEmbedding + - __init__ diff --git a/docs/source/reference/quantization.mdx b/docs/source/reference/quantization.mdx new file mode 100644 index 000000000..3880cc089 --- /dev/null +++ b/docs/source/reference/quantization.mdx @@ -0,0 +1,13 @@ +# Quantization primitives + +Below you will find the docstring of the quantization primitives exposed in bitsandbytes. + +## Linear4bit (QLoRA)[[linear4bit]] + +[[autodoc]] bitsandbytes.nn.Linear4bit + - __init__ + +## Linear8bitLt[[linear8bit]] + +[[autodoc]] bitsandbytes.nn.Linear8bitLt + - __init__ diff --git a/docs/source/resources.mdx b/docs/source/resources.mdx new file mode 100644 index 000000000..56330175a --- /dev/null +++ b/docs/source/resources.mdx @@ -0,0 +1,92 @@ +# Papers, related resources & how to cite + +The below academic work is ordered in reverse chronological order. + +## [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression (Jun 2023)](https://arxiv.org/abs/2306.03078) + +Authors: Tim Dettmers, Ruslan Svirschevski, Vage Egiazarian, Denis Kuznedelev, Elias Frantar, Saleh Ashkboos, Alexander Borzunov, Torsten Hoefler, Dan Alistarh + +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1666076553665744896) + +``` +@article{dettmers2023spqr, + title={SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression}, + author={Dettmers, Tim and Svirschevski, Ruslan and Egiazarian, Vage and Kuznedelev, Denis and Frantar, Elias and Ashkboos, Saleh and Borzunov, Alexander and Hoefler, Torsten and Alistarh, Dan}, + journal={arXiv preprint arXiv:2306.03078}, + year={2023} +} +``` + +## [QLoRA: Efficient Finetuning of Quantized LLMs (May 2023)](https://arxiv.org/abs/2305.14314) +Authors: Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=y9PHWGOa8HA&ab_channel=LondonMachineLearningMeetup) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1661379354507476994) + +``` +@article{dettmers2023qlora, + title={Qlora: Efficient finetuning of quantized llms}, + author={Dettmers, Tim and Pagnoni, Artidoro and Holtzman, Ari and Zettlemoyer, Luke}, + journal={arXiv preprint arXiv:2305.14314}, + year={2023} +} +``` + +## [The case for 4-bit precision: k-bit Inference Scaling Laws (Dec 2022)](https://arxiv.org/abs/2212.09720) +Authors: Tim Dettmers, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=odlQa6AE1gY&ab_channel=TheInsideView) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1605209171758284805) + +``` +@inproceedings{dettmers2023case, + title={The case for 4-bit precision: k-bit inference scaling laws}, + author={Dettmers, Tim and Zettlemoyer, Luke}, + booktitle={International Conference on Machine Learning}, + pages={7750--7774}, + year={2023}, + organization={PMLR} +} +``` + +## [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)](https://arxiv.org/abs/2208.07339) +Authors: Tim Dettmers, Mike Lewis, Younes Belkada, Luke Zettlemoyer + +- [LLM.int8() Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) +- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) +- [Introduction to Weight Quantization](https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c) +- [Poster](https://twitter.com/Tim_Dettmers/status/1598351301942951937) + +``` +@article{dettmers2022llm, + title={Llm. int8 (): 8-bit matrix multiplication for transformers at scale}, + author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, + journal={arXiv preprint arXiv:2208.07339}, + year={2022} +} +``` + +## [8-bit Optimizers via Block-wise Quantization (Oct 2021)](https://arxiv.org/abs/2110.02861) +Authors: Tim Dettmers, Mike Lewis, Sam Shleifer, Luke Zettlemoyer + +- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) +- [Twitter summary thread](https://twitter.com/Tim_Dettmers/status/1446472128979562499) + +``` +@article{DBLP:journals/corr/abs-2110-02861, + author = {Tim Dettmers and + Mike Lewis and + Sam Shleifer and + Luke Zettlemoyer}, + title = {8-bit Optimizers via Block-wise Quantization}, + journal = {CoRR}, + volume = {abs/2110.02861}, + year = {2021}, + url = {https://arxiv.org/abs/2110.02861}, + eprinttype = {arXiv}, + eprint = {2110.02861}, + timestamp = {Thu, 21 Oct 2021 16:20:08 +0200}, + biburl = {https://dblp.org/rec/journals/corr/abs-2110-02861.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py index 2d4c77952..c89ba8d11 100644 --- a/examples/int8_inference_huggingface.py +++ b/examples/int8_inference_huggingface.py @@ -2,18 +2,23 @@ from transformers import LlamaForCausalLM, LlamaTokenizer MAX_NEW_TOKENS = 128 -model_name = "meta-llama/Llama-2-7b-hf" +model_name = 'meta-llama/Llama-2-7b-hf' -text = "Hamburg is in which country?\n" +text = 'Hamburg is in which country?\n' tokenizer = LlamaTokenizer.from_pretrained(model_name) input_ids = tokenizer(text, return_tensors="pt").input_ids -max_memory = f"{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB" +max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' n_gpus = torch.cuda.device_count() max_memory = {i: max_memory for i in range(n_gpus)} -model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory) +model = LlamaForCausalLM.from_pretrained( + model_name, + device_map='auto', + load_in_8bit=True, + max_memory=max_memory +) generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) diff --git a/install_cuda.py b/install_cuda.py index 9e426cbd7..b41b33b39 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -19,7 +19,6 @@ "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run", } - def install_cuda(version, base_path, download_path): formatted_version = f"{version[:-1]}.{version[-1]}" folder = f"cuda-{formatted_version}" @@ -30,7 +29,7 @@ def install_cuda(version, base_path, download_path): subprocess.run(["rm", "-rf", install_path], check=True) url = cuda_versions[version] - filename = url.split("/")[-1] + filename = url.split('/')[-1] filepath = os.path.join(download_path, filename) if not os.path.exists(filepath): @@ -45,14 +44,9 @@ def install_cuda(version, base_path, download_path): # Install CUDA print(f"Installing CUDA version {version}...") install_command = [ - "bash", - filepath, - "--no-drm", - "--no-man-page", - "--override", - "--toolkitpath=" + install_path, - "--toolkit", - "--silent", + "bash", filepath, + "--no-drm", "--no-man-page", "--override", + "--toolkitpath=" + install_path, "--toolkit", "--silent" ] print(f"Running command: {' '.join(install_command)}") @@ -68,7 +62,6 @@ def install_cuda(version, base_path, download_path): print(f"CUDA version {version} installed at {install_path}") - def main(): user_base_path = os.path.expanduser("~/cuda") system_base_path = "/usr/local/cuda" @@ -100,6 +93,5 @@ def main(): print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1) - if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 609ff84fa..f74750720 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,6 @@ src = [ "tests", "benchmarking" ] -target-version = "py38" -line-length = 119 - -[tool.ruff.lint] select = [ "B", # bugbear: security warnings "E", # pycodestyle @@ -21,6 +17,7 @@ select = [ "UP", # alert you when better syntax is available in your python version "RUF", # the ruff developer's own rules ] +target-version = "py38" ignore = [ "B007", # Loop control variable not used within the loop body (TODO: enable) "B028", # Warning without stacklevel (TODO: enable) @@ -33,7 +30,7 @@ ignore = [ ] ignore-init-module-imports = true # allow to expose in __init__.py via imports -[tool.ruff.lint.extend-per-file-ignores] +[tool.ruff.extend-per-file-ignores] "**/__init__.py" = ["F401"] # allow unused imports in __init__.py "{benchmarking,tests}/**/*.py" = [ "B007", @@ -45,7 +42,7 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports "UP030", ] -[tool.ruff.lint.isort] +[tool.ruff.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true diff --git a/requirements-ci.txt b/requirements-ci.txt index e6e375ccb..46bd5b9cd 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,7 @@ # Requirements used for GitHub actions pytest==7.2.2 einops==0.6.0 +wheel==0.40.0 lion-pytorch==0.0.6 -scipy==1.10.1; python_version < "3.9" -scipy==1.11.4; python_version >= "3.9" +scipy==1.11.4 +pandas==2.2.0 diff --git a/scripts/set_platform_tag.py b/scripts/set_platform_tag.py new file mode 100644 index 000000000..ca561c880 --- /dev/null +++ b/scripts/set_platform_tag.py @@ -0,0 +1,34 @@ +import argparse +import platform +import sys + + +def get_platform_tag(architecture): + system = platform.system() + + if system == "Linux": + tag = ( + "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64" + ) + elif system == "Darwin": + tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64" + elif system == "Windows": + tag = "win_amd64" if architecture == "x86_64" else "win_arm64" + else: + sys.exit(f"Unsupported system: {system}") + + return tag + + +def main(): + parser = argparse.ArgumentParser(description="Determine platform tag.") + parser.add_argument("arch", type=str, help="Architecture (e.g., x86_64, aarch64)") + args = parser.parse_args() + + tag = get_platform_tag(args.arch) + + print(tag) # This will be captured by the GitHub Actions workflow + + +if __name__ == "__main__": + main() diff --git a/scripts/stale.py b/scripts/stale.py index a65652aeb..613f5b7cb 100644 --- a/scripts/stale.py +++ b/scripts/stale.py @@ -15,7 +15,6 @@ Script to close stale issue. Taken in part from the AllenNLP repository. https://github.com/allenai/allennlp. """ - from datetime import datetime as dt, timezone import os @@ -51,7 +50,7 @@ def main(): issue.create_comment( "This issue has been automatically marked as stale because it has not had " "recent activity. If you think this still needs to be addressed " - "please comment on this thread.\n\n", + "please comment on this thread.\n\n" ) diff --git a/setup.py b/setup.py index a51b3867c..57603a4cc 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def has_ext_modules(self): setup( name="bitsandbytes", - version="0.44.0.dev", + version="0.43.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", diff --git a/tests/conftest.py b/tests/conftest.py index 17ffd281c..7aee8c922 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,6 @@ def pytest_runtest_call(item): if str(ae) == "Torch not compiled with CUDA enabled": pytest.skip("Torch not compiled with CUDA enabled") raise - except RuntimeError as re: - # CUDA-enabled Torch build, but no CUDA-capable device found - if "Found no NVIDIA driver on your system" in str(re): - pytest.skip("No NVIDIA driver found") - raise @pytest.fixture(scope="session") diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 9da665a2d..d01e5e9db 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -20,11 +20,7 @@ @pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) -@pytest.mark.parametrize( - "funcs", - [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], - ids=["func=bmm", "func=matmul"], -) +@pytest.mark.parametrize("funcs", [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], ids=["func=bmm", "func=matmul"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) @@ -34,13 +30,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(25): + # normal multiply if funcs[0] in [torch.mm, torch.matmul]: dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]) + target = torch.randn( + size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] + ) torch.nn.init.xavier_uniform_(B) if not transpose[0] and not transpose[1]: @@ -72,7 +71,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -80,14 +81,18 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: @@ -114,7 +119,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2) + torch.testing.assert_close( + out_bnb, out_torch, atol=0.027, rtol=0.2 + ) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -126,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -134,7 +143,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -181,7 +192,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -189,7 +202,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool B.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -203,17 +218,25 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) -@pytest.mark.parametrize( - "funcs", - [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], - ids=["func=matmul", "func=switchback_bnb"], -) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], ids=["func=matmul", "func=switchback_bnb"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias): +def test_matmullt( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, + has_bias +): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") @@ -222,13 +245,18 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec req_grad[2] = False for i in range(3): + # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + A = torch.randn( + size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype + ) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + B = torch.randn( + size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype + ) target = torch.randn( size=(dim2, dim4), device="cuda", @@ -238,7 +266,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -283,7 +311,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb = torch.nn.functional.mse_loss( + out_bnb, target + ).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad @@ -293,7 +323,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -304,7 +336,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec bias.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() if dim2 > 0: @@ -318,7 +352,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) @@ -334,20 +370,8 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) -@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) -def test_matmul_4bit( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - has_bias, - compress_statistics, - quant_type, -): +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'], ids=id_formatter("quant_type")) +def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) if has_bias == False: @@ -363,15 +387,11 @@ def test_matmul_4bit( bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_4bit( - B, - compress_statistics=compress_statistics, - quant_type=quant_type, - ) + B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) @@ -390,7 +410,7 @@ def test_matmul_4bit( if n > 0: assert err < 0.115 - # assert err < 0.20 + #assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -404,7 +424,7 @@ def test_matmul_4bit( gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -415,7 +435,7 @@ def test_matmul_4bit( bias.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) @@ -428,12 +448,8 @@ def test_matmul_4bit( @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize( - "funcs", - [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], - ids=["matmul_fp8_mixed", "matmul_fp8_global"], -) -def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], ids=["matmul_fp8_mixed", 'matmul_fp8_global']) +def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) req_grad = list(req_grad) @@ -464,7 +480,7 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: assert err < 0.115 - # assert err < 0.20 + #assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -475,7 +491,7 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -483,7 +499,7 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() @@ -498,6 +514,8 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - grad_err = (gradB1 - gradB2).abs().mean() + grad_err = (gradB1-gradB2).abs().mean() assert grad_err.item() < 0.003 - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index fc79a54b0..189aa75b5 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,38 +1,21 @@ -import pytest +import os +from pathlib import Path -from bitsandbytes.cextension import get_cuda_bnb_library_path -from bitsandbytes.cuda_specs import CUDASpecs +import torch -@pytest.fixture -def cuda120_spec() -> CUDASpecs: - return CUDASpecs( - cuda_version_string="120", - highest_compute_capability=(8, 6), - cuda_version_tuple=(12, 0), - ) +# hardcoded test. Not good, but a sanity check for now +# TODO: improve this +def test_manual_override(requires_cuda): + manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) + pytorch_version = torch.version.cuda.replace('.', '') -@pytest.fixture -def cuda111_noblas_spec() -> CUDASpecs: - return CUDASpecs( - cuda_version_string="111", - highest_compute_capability=(7, 2), - cuda_version_tuple=(11, 1), - ) + assert pytorch_version != 122 # TODO: this will never be true... - -def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): - monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" - - -def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): - monkeypatch.setenv("BNB_CUDA_VERSION", "110") - assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" - assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? - - -def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): - monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" + os.environ['CUDA_HOME']='{manual_cuda_path}' + os.environ['BNB_CUDA_VERSION']='122' + #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] + import bitsandbytes as bnb + loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name + #assert loaded_lib == 'libbitsandbytes_cuda122.so' diff --git a/tests/test_functional.py b/tests/test_functional.py index b9f1a6ead..d4f65755f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -19,7 +19,9 @@ id_formatter, ) -torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) +torch.set_printoptions( + precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 +) k = 20 @@ -96,7 +98,9 @@ def teardown(): pass -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16], ids=["float", "half"] +) def test_estimate_quantiles(dtype): A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) @@ -132,6 +136,7 @@ def test_quantile_quantization(): assert diff < 0.001 + def test_dynamic_quantization(): diffs = [] reldiffs = [] @@ -144,8 +149,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs) / len(diffs)) - print(sum(reldiffs) / len(reldiffs)) + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") @@ -156,12 +161,13 @@ def test_dynamic_quantization(): assert diff < 0.004 + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - # print('') + #print('') diffs = [] reldiffs = [] for i in range(100): @@ -172,10 +178,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs) / len(diffs) - relerr = sum(reldiffs) / len(reldiffs) - # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) - # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) assert abserr < 0.011 assert relerr < 0.018 assert A2.dtype == dtype @@ -190,9 +196,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs) / len(diffs) - relerr = sum(reldiffs) / len(reldiffs) + #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) if signed: assert abserr < 0.0035 assert relerr < 0.015 @@ -200,11 +206,14 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): assert abserr < 0.00175 assert relerr < 0.012 assert A2.dtype == dtype - # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.parametrize( + "gtype", [torch.float32, torch.float16], ids=["float", "half"] +) def test_percentile_clipping(gtype): gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda") @@ -214,7 +223,9 @@ def test_percentile_clipping(gtype): for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) + gnorm1, clip2, gnorm_scale = F.percentile_clipping( + g, gnorm_vec2, step, percentile=percentile + ) assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 gnorm2 = torch.norm(g.float()) @@ -298,7 +309,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - # print("") + #print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -310,7 +321,9 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + torch.testing.assert_close( + quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 + ) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -325,8 +338,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - # print(mean(errors)) - # print(mean(relerrors)) + #print(mean(errors)) + #print(mean(relerrors)) def test_stable_embedding(): @@ -343,8 +356,16 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) - shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) + shapeA = ( + (batch_dim, hidden_dim) + if not transpose[0] + else (hidden_dim, batch_dim) + ) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -364,7 +385,11 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -385,10 +410,16 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) + A = torch.randint( + -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ).to(torch.int8) + B = torch.randint( + -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" + ).to(torch.int8) out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) + iout = torch.empty( + A.shape[2], B.shape[2], dtype=torch.int32, device=A.device + ) out = F.igemm(A, B, out=iout) torch.testing.assert_close(out.float(), out2) @@ -413,7 +444,9 @@ def min_max(x): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") + A = torch.normal( + 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ) if transpose: B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: @@ -490,7 +523,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B) elif transpose[0] and transpose[1]: - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) + out2 = torch.bmm( + A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() + ) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_close(out.float(), out2.float()) @@ -506,7 +541,7 @@ def test_vector_quant(dim1, dim2, dim3): qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @@ -530,7 +565,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + dtype + ) out, S = F.nvidia_transform(A, to_order=orderOut) @@ -542,11 +579,17 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) + n = ( + A.shape[0] + * A.shape[1] + * (A.shape[2] + (32 - (A.shape[2] % 32))) + ) assert out.numel() == n elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) + n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( + A.shape[1] + (32 - (A.shape[1] % 32)) + ) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): @@ -555,7 +598,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile + rowtile = ( + (row // 8) + (1 if row % 8 != 0 else 0) + ) * total_coltile offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 row2 = (row % 8) * 32 @@ -566,7 +611,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) + out2, S = F.nvidia_transform( + out, from_order=orderOut, to_order="row", state=S + ) torch.testing.assert_close(A, out2) @@ -579,10 +626,16 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint( + -128, 127, size=(dim1, dim2, dim3), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( + torch.int8 + ) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") @@ -592,7 +645,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_close(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( + torch.int8 + ) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) @@ -612,7 +667,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() + A = torch.normal( + 0, 0.5, size=(dim1, dim2, dim3), device="cuda" + ).half() B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -643,7 +700,6 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # C3, S = F.transform(C2, 'row', state=SC) # torch.testing.assert_close(C1, C3.float()) - @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), [ @@ -673,6 +729,7 @@ def test_bench_8bit_training(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(k): + out1 = torch.matmul(A, w1.t()) # fc1 # out2 = torch.matmul(out1, w2.t())# fc2 @@ -809,15 +866,13 @@ def test_bench_8bit_training(batch, seq, model, hidden): def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() bias = None - if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=torch.float16) + if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) formatB = F.get_special_format_str() for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) - if has_bias: - C1 += bias + if has_bias: C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -828,8 +883,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - if has_bias: - C4 += bias + if has_bias: C4 += bias # TODO: is something wrong here? If so, the problem goes deeper # n = C1.numel() @@ -863,7 +917,9 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=threshold + ) A_blocked = einops.rearrange( torch.abs(A), @@ -883,7 +939,9 @@ def test_colrow_absmax(dim1, dim2, dims): torch.testing.assert_close(row_stats1_trunc, row_stats2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=0.0 + ) torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) @@ -905,16 +963,24 @@ def test_double_quant(dim1, dim2): torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() - num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + num_not_close_rows = ( + (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + ) + num_not_close_cols = ( + (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + ) # allow for 1:500 error due to rounding differences min_error = 1 / 500 if num_not_close_cols > (min_error * n): - print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}") + print( + f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" + ) assert False if num_not_close_rows > (min_error * n): - print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}") + print( + f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" + ) assert False torch.testing.assert_close(Srow.flatten().float(), statsA) @@ -925,12 +991,13 @@ def test_double_quant(dim1, dim2): ("dim1", "dim4", "inner"), ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) in zip( + for (dim1, dim4, inner) + in zip( get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4), get_test_dims(1, 4 * 1024, n=4), ) - ), + ) ) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): @@ -970,12 +1037,13 @@ def test_integrated_igemmlt(dim1, dim4, inner): ("dim1", "dim4", "inner"), ( pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) in zip( + for (dim1, dim4, inner) + in zip( get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6), get_test_dims(1, 4 * 1024, n=6), ) - ), + ) ) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): @@ -999,7 +1067,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1080,7 +1150,9 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1105,9 +1177,13 @@ def test_row_scale_bench(dim1, dim4, inner): def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( + dtype + ) elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) + A = torch.randint( + 10, 99, size=(dim1, dim2, dim3), device="cuda" + ).to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1148,17 +1224,23 @@ def test_coo_double_quant(dim1, dim2): idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) - A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + A2[ + coo_tensor.rowidx.long(), coo_tensor.colidx.long() + ] = coo_tensor.values torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_close( + A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 + ) @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) @@ -1179,7 +1261,9 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) A2 = A * idx if transposed_B: @@ -1219,7 +1303,9 @@ def test_spmm_bench(): print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1253,7 +1339,9 @@ def test_integrated_sparse_decomp(dim1, dim2): out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) @@ -1308,7 +1396,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) @@ -1323,7 +1413,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): std = out1.std() out1 /= std out2 /= std - assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) + assert_all_approx_close( + out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count + ) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) @@ -1351,7 +1443,9 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] @@ -1369,7 +1463,9 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] @@ -1403,7 +1499,9 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) @@ -1484,7 +1582,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): @pytest.mark.parametrize( ("batch", "seq", "model", "hidden"), - [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], + [pytest.param(1, 1, 6656, 4*6656, id="batch=1, seq=1, model=6656, hidden=26k")], ) @pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): @@ -1507,8 +1605,8 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half() - # linearMixedBit.eval() + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) + #linearMixedBit.eval() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() @@ -1525,123 +1623,121 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( - f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", - ) + print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - # torch.cuda.synchronize() - # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - # torch.cuda.synchronize() - # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + #torch.cuda.synchronize() + #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() - print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) torch.cuda.synchronize() - print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # bnb.matmul(A, B) - # torch.cuda.synchronize() - # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # bnb.matmul(A, B, threshold=6.0) - # torch.cuda.synchronize() - # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - # C32A, SA = F.transform(CA, "col32") - # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - # CxB, SB = F.transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + #C32A, SA = F.transform(CA, "col32") + #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + #CxB, SB = F.transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - # torch.cuda.synchronize() - # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # BA, statsB = F.vectorwise_quant(B, dim=1) - # CxB, SB = F.nvidia_transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1) + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1) # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - # torch.cuda.synchronize() - # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - # CxB, SB = F.nvidia_transform(CB, to_order=formatB) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # out = Cout * statsB * statsA * (1.0 / (127 * 127)) - # torch.cuda.synchronize() - # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linear8bit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #linear8bit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # linear8bit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linearMixedBit(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #linearMixedBit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # linearMixedBit(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linear8bit_train(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #linear8bit_train(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # linear8bit_train(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - # linear8bit_train_thresh(A) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #linear8bit_train_thresh(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # linear8bit_train(A) - # torch.cuda.synchronize() - # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") def test_zeropoint(): def quant_zp(x): @@ -1682,8 +1778,8 @@ def quant_zp(x): C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) - # print(ca.min(), ca.max()) - # print((ca - cza).min(), (ca - cza).max()) + #print(ca.min(), ca.max()) + #print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 @@ -1712,14 +1808,14 @@ def quant_zp(x): C7 -= zpa * zpb * A.shape[1] C7 /= qa * qb - # print("") + #print("") # print(C0.flatten()[:10]) - # print(C1.flatten()[:10]) - # print(C2.flatten()[:10]) - # print(C3.flatten()[:10]) - # print(C5.flatten()[:10]) - # print(C6.flatten()[:10]) - # print(C7.flatten()[:10]) + #print(C1.flatten()[:10]) + #print(C2.flatten()[:10]) + #print(C3.flatten()[:10]) + #print(C5.flatten()[:10]) + #print(C6.flatten()[:10]) + #print(C7.flatten()[:10]) err1 = torch.abs(C1 - C2).mean().item() err2 = torch.abs(C1 - C3).mean().item() err3 = torch.abs(C1 - C4).mean().item() @@ -1756,15 +1852,16 @@ def test_extract_outliers(): torch.testing.assert_close(outliers1, outliers2) + def test_blockwise_cpu_large(): diffs = [] reldiffs = [] batch = 128 seq = 128 - for hidden in [128]: # , 14336]: + for hidden in [128]:#, 14336]: for blocksize in [4096, 16384]: for i in range(2): - A1 = torch.randn(batch, seq, hidden, device="cpu") + A1 = torch.randn(batch, seq, hidden, device='cpu') t0 = time.time() C, S = F.quantize_blockwise(A1, blocksize=blocksize) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) @@ -1778,9 +1875,10 @@ def test_blockwise_cpu_large(): # print(sum(reldiffs)/len(reldiffs)) + def test_fp8_quant(): for e_bits in range(1, 7): - p_bits = 7 - e_bits + p_bits = 7-e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() abserr = [] @@ -1790,12 +1888,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) + reldiff = diff/torch.abs(A1+1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(sum(abserr)/len(abserr)) - # print(sum(relerr)/len(relerr)) + #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -1804,12 +1902,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) + reldiff = diff/torch.abs(A1+1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(sum(abserr)/len(abserr)) - # print(sum(relerr)/len(relerr)) + #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -1818,48 +1916,50 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) + reldiff = diff/torch.abs(A1+1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(3, sum(abserr)/len(abserr)) - # print(3, sum(relerr)/len(relerr)) + #assert diff < 0.0075 + #print(3, sum(abserr)/len(abserr)) + #print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - # print('') + + #print('') for bits in range(2, 9): - # print('='*30, bits, '='*30) - for method in ["linear", "fp8", "dynamic", "quantile"]: + #print('='*30, bits, '='*30) + for method in ['linear', 'fp8', 'dynamic', 'quantile']: abserrs = [] relerrs = [] code = None - if method == "linear": + if method == 'linear': code = F.create_linear_map(True, total_bits=bits).cuda() - elif method == "fp8": - ebits = math.ceil(bits / 2) - pbits = bits - ebits - 1 + elif method == 'fp8': + ebits = math.ceil(bits/2) + pbits = bits-ebits-1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - elif method == "dynamic": - code = F.create_dynamic_map(True, bits - 0, bits).cuda() - elif method == "quantile": - values = torch.randn(2048, 2048, device="cuda") + elif method == 'dynamic': + code = F.create_dynamic_map(True, bits-0, bits).cuda() + elif method == 'quantile': + values = torch.randn(2048, 2048, device='cuda') code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero # for some data types we have one zero # for some data types we have two zeros - assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" - # print(method, (code==0).sum()) + assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' + #print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): - values = torch.randn(1, 32, device="cuda") + + values = torch.randn(1, 32, device='cuda') values /= values.abs().max() - # values[values.abs() < 1e-6] += 1e-5 + #values[values.abs() < 1e-6] += 1e-5 q1 = [] v1 = [] for v in values[0]: - idx = torch.abs(v - code).argmin() + idx = torch.abs(v-code).argmin() q1.append(idx.item()) v1.append(code[idx].item()) @@ -1870,61 +1970,62 @@ def test_few_bit_quant(): v2 = F.dequantize_blockwise(q2, S2) idx = torch.isclose(q1.int(), q2.int()) - err2 = torch.abs(v2 - values) + err2 = torch.abs(v2-values) abserrs.append(err2.mean().item()) - relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) + relerrs.append((err2/(1e-10+values).abs()).mean().item()) if idx.sum(): # some weird cases - err1 = torch.abs(v1 - values).mean() - # assert err2.mean() <= err1 + err1 = torch.abs(v1-values).mean() + #assert err2.mean() <= err1 else: torch.testing.assert_close(q1, q2) - # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) - # assert False + #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + #assert False def test_kbit_quantile_estimation(): for i in range(100): - data = torch.randn(1024, 1024, device="cuda") + data = torch.randn(1024, 1024, device='cuda') for bits in range(2, 9): - p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) + p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) val1 = torch.Tensor(norm.ppf(p)).cuda() val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1 - val2).mean() + err = torch.abs(val1-val2).mean() assert err < 0.038 for i in range(100): - data = torch.randn(1024, 1024, device="cuda") + data = torch.randn(1024, 1024, device='cuda') for bits in range(2, 4): - total_values = 2**bits - 1 - p = np.linspace(0, 1, 2 * total_values + 1) - idx = np.arange(1, 2 * total_values + 1, 2) + total_values = 2**bits-1 + p = np.linspace(0, 1, 2*total_values+1) + idx = np.arange(1, 2*total_values+1, 2) p = p[idx] - offset = 1 / (2 * total_values) - p = np.linspace(offset, 1 - offset, total_values) + offset = 1/(2*total_values) + p = np.linspace(offset, 1-offset, total_values) val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) - err = torch.abs(val1 - val2).mean() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) + err = torch.abs(val1-val2).mean() assert err < 0.035 @pytest.mark.benchmark def test_bench_dequantization(): - a = torch.rand(1024, 1024, device="cuda").half() - code = F.create_fp8_map(True, 3, 0, 4).cuda() + a = torch.rand(1024, 1024, device='cuda').half() + code =F.create_fp8_map(True, 3, 0, 4).cuda() qa, SA = F.quantize_blockwise(a, code=code) print(qa.max()) - max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 - # print(max_theoretical_mu) + max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 + #print(max_theoretical_mu) torch.cuda.synchronize() t0 = time.time() for i in range(100): qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() - # print((time.time()-t0)/1e6) + #print((time.time()-t0)/1e6) + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @@ -1936,28 +2037,26 @@ def test_fp4_quant(dtype): result = 0 bias = 3 sign, e1, e2, p1 = bits - idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1 + idx = sign*8 + e1*4 + e2*2 + p1*1 sign = -1.0 if sign else 1.0 - exp = e1 * 2 + e2 * 1 + exp = e1*2 + e2*1 if exp == 0: # sub-normal - if p1 == 0: - result = 0 - else: - result = sign * 0.0625 + if p1 == 0: result = 0 + else: result = sign*0.0625 else: # normal - exp = 2 ** (-exp + bias + 1) + exp = 2**(-exp + bias + 1) frac = 1.5 if p1 else 1.0 - result = sign * exp * frac + result = sign*exp*frac code[idx] = result - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) err = (A1 - A2).abs().float() - relerr = (err / (A1.abs().float() + 1e-8)).mean() + relerr = (err/(A1.abs().float()+1e-8)).mean() idx = err > 1.0 err = err.mean() @@ -1966,29 +2065,31 @@ def test_fp4_quant(dtype): assert relerr.item() < 0.28 -@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) def test_4bit_compressed_stats(quant_type): for blocksize in [128, 64]: errs1 = [] errs2 = [] for i in range(10): - A1 = torch.randn(1024, 1024, device="cuda").half() + A1 = torch.randn(1024, 1024, device='cuda').half() q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) + err = (A1 - A2).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() + relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() errs1.append(err.item()) + assert err.item() < 0.11 assert relerr.item() < 0.28 err = (A1 - A3).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() + relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() errs2.append(err.item()) @@ -1996,71 +2097,70 @@ def test_4bit_compressed_stats(quant_type): assert err.item() < 0.11 assert relerr.item() < 0.28 - # print(sum(errs1)/len(errs1), blocksize, quant_type) - # print(sum(errs2)/len(errs2), blocksize, quant_type) + #print(sum(errs1)/len(errs1), blocksize, quant_type) + #print(sum(errs2)/len(errs2), blocksize, quant_type) + -# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -@pytest.mark.parametrize("quant_type", ["nf4"]) + +#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ['nf4']) @pytest.mark.benchmark def test_bench_4bit_dequant(quant_type): blocksize = 256 - a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() + a = torch.rand(1024*12*4, 1024*12, device='cuda').half() qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) - input_size = a.numel() / 2 - output_size = a.numel() * 2 - num_bytes = input_size + output_size - GB = num_bytes / 1e9 - max_theoretical_s = GB / 768 - # print(max_theoretical_s*1e6) - b = torch.randn(128, 1024 * 12, device="cuda").half() + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + #print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() iters = 100 torch.cuda.synchronize() t0 = time.time() for i in range(iters): F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - # b.copy_(a) + #b.copy_(a) torch.cuda.synchronize() - # print((time.time()-t0)/iters*1e6) + #print((time.time()-t0)/iters*1e6) - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): # torch.matmul(b, a.t()) - # torch.cuda.synchronize() - # print((time.time()-t0)/iters*1e6) + #torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + def test_normal_map_tree(): code = F.create_normal_map() - values = code[:8].tolist() + code[-8:].tolist() + values =code[:8].tolist() + code[-8:].tolist() num_pivots = 1 - # print(values) - while num_pivots < 16: - idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots)) - # print(idx) + #print(values) + while num_pivots <16: + idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) + #print(idx) num_pivots *= 2 pivots = [] for i in idx: - pivots.append((values[i - 1] + values[i]) / 2) - # print(pivots) + pivots.append((values[i-1]+values[i])/2) + #print(pivots) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") -@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("storage_type", ['nf4', 'fp4']) +@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize( - "quant_storage", - [torch.uint8, torch.float16, torch.bfloat16, torch.float32], - ids=describe_dtype, -) +@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: - # for dim in [4*1024]: - # for dim in [1*16]: + #for dim in [4*1024]: + #for dim in [1*16]: errs1 = [] errs2 = [] errs3 = [] @@ -2071,42 +2171,38 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2 = [] max_errs3 = [] + for i in range(100): - if kind == "fc1": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "fc2": - A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") - B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn_packed": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - - qB, state = F.quantize_4bit( - B, - quant_type=storage_type, - compress_statistics=double_quant, - quant_storage=quant_storage, - ) + if kind == 'fc1': + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + elif kind == 'fc2': + A = torch.randn(1, 4*dim, dtype=dtype, device='cuda') + B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim) + elif kind == 'attn': + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + elif kind == 'attn_packed': + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) C3 = torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) - err1 = (C1 - C2).abs().float() - err2 = (C3 - C2).abs().float() - err3 = (C3 - C1).abs().float() + err1 = (C1-C2).abs().float() + err2 = (C3-C2).abs().float() + err3 = (C3-C1).abs().float() - mag1 = torch.abs(C1).float() + 1e-5 - mag2 = torch.abs(C3).float() + 1e-5 - mag3 = torch.abs(C3).float() + 1e-5 + mag1 = torch.abs(C1).float()+1e-5 + mag2 = torch.abs(C3).float()+1e-5 + mag3 = torch.abs(C3).float()+1e-5 - relerr1 = err1 / mag1 - relerr2 = err2 / mag2 - relerr3 = err3 / mag3 + relerr1 = err1/mag1 + relerr2 = err2/mag2 + relerr3 = err3/mag3 max_err1 = err1.max() max_err2 = err2.max() @@ -2124,34 +2220,34 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2.append(max_err2.item()) max_errs3.append(max_err3.item()) - c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + c = int(C1.numel()*0.0014*(dim/256))+1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - err1 = sum(errs1) / len(errs1) / math.sqrt(dim) - err2 = sum(errs2) / len(errs2) / math.sqrt(dim) - err3 = sum(errs3) / len(errs3) / math.sqrt(dim) - relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) - relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) - relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) - maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) - maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) - maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) - absratio = err2 / err3 - relratio = relerr2 / relerr3 - maxratio = relerr2 / relerr3 + err1 = sum(errs1)/len(errs1)/math.sqrt(dim) + err2 = sum(errs2)/len(errs2)/math.sqrt(dim) + err3 = sum(errs3)/len(errs3)/math.sqrt(dim) + relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim) + relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim) + relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim) + maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim) + maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim) + maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim) + absratio = err2/err3 + relratio = relerr2/relerr3 + maxratio = relerr2/relerr3 # for debugging if the tests fails # - # print('='*80) - # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - # print(C1.flatten()[-20:]) - # print(C2.flatten()[-20:]) - # print(f'inference vs training abs: {err1}') - # print(f'inference vs training rel: {relerr1}') - # print(f'inference vs training max: {maxerr1}') - # print(f'inference vs training vs torch err ratio abs: {absratio}') - # print(f'inference vs training vs torch err ratio rel: {relratio}') - # print(f'inference vs training vs torch err ratio max: {maxratio}') + #print('='*80) + #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + #print(C1.flatten()[-20:]) + #print(C2.flatten()[-20:]) + #print(f'inference vs training abs: {err1}') + #print(f'inference vs training rel: {relerr1}') + #print(f'inference vs training max: {maxerr1}') + #print(f'inference vs training vs torch err ratio abs: {absratio}') + #print(f'inference vs training vs torch err ratio rel: {relratio}') + #print(f'inference vs training vs torch err ratio max: {maxratio}') if dtype == torch.float16: if dim <= 512: assert err1 < 7e-5 @@ -2187,59 +2283,56 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): - n = 32 * 10 + n = 32*10 A = F.get_paged(n, n, dtype=torch.float32) B = F.get_paged(n, n, dtype=torch.uint8) B2 = F.get_paged(n, n, dtype=torch.float32) assert A.is_paged assert B.is_paged - assert A.page_deviceid == 0 - assert B.page_deviceid == 0 + assert A.page_deviceid==0 + assert B.page_deviceid==0 F.fill(A, 17.0) F.fill(B, 17) F.fill(B2, 2) - assert (A == 17).sum().item() == n * n - assert (B == 17).sum().item() == n * n - C = A * B.float() - assert (C == 289).sum().item() == n * n + assert (A==17).sum().item() == n*n + assert (B==17).sum().item() == n*n + C = A*B.float() + assert (C==289).sum().item() == n*n F._mul(A, B2) F._mul(A, B2) F._mul(A, B2) - assert (A == 17 * (2**3)).sum().item() == n * n - - -# F.prefetch_tensor(A) -# F.prefetch_tensor(B) + assert (A==17*(2**3)).sum().item() == n*n + # F.prefetch_tensor(A) + # F.prefetch_tensor(B) -# F.fill(B2, 17.0) -# F._mul(A, B2) + # F.fill(B2, 17.0) + # F._mul(A, B2) -# F.prefetch_tensor(A, to_cpu=True) -# F.prefetch_tensor(B, to_cpu=True) -# F.prefetch_tensor(B2, to_cpu=True) -# torch.cuda.synchronize() + # F.prefetch_tensor(A, to_cpu=True) + # F.prefetch_tensor(B, to_cpu=True) + # F.prefetch_tensor(B2, to_cpu=True) + # torch.cuda.synchronize() -# assert (A==17).sum().item() == n*n + # assert (A==17).sum().item() == n*n -# torch.testing.assert_close(A, torch.ones(A.shape)*289) + # torch.testing.assert_close(A, torch.ones(A.shape)*289) -@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) +@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) +@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) - dims = [dim + (64 - (dim % 64)) for dim in dims] - # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: + dims = [dim + (64-(dim % 64)) for dim in dims] + #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") - B = torch.eye(dim, dtype=dtype, device="cuda") + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda') + B = torch.eye(dim, dtype=dtype, device='cuda') qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) C3 = torch.matmul(A, B.t()) @@ -2250,5 +2343,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C3) torch.testing.assert_close(A, C1) torch.testing.assert_close(A, C2) - # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) - # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) + #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) + #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) diff --git a/tests/test_generation.py b/tests/test_generation.py index 911aa14da..b05749bf8 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -3,68 +3,66 @@ import pytest import torch +import transformers +from transformers import ( + AutoModelForCausalLM, + BitsAndBytesConfig, +) from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter -transformers = pytest.importorskip("transformers") - def get_4bit_config(): - return transformers.BitsAndBytesConfig( - load_in_4bit=True, - load_in_8bit=False, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) + return BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4', + ) def get_model_and_tokenizer(config): model_name_or_path, quant_type = config bnb_config = get_4bit_config() - if quant_type == "16bit": + if quant_type == '16bit': bnb_config.load_in_4bit = False else: - bnb_config.bnb_4bit_quant_type = quant_type - model = transformers.AutoModelForCausalLM.from_pretrained( - model_name_or_path, + bnb_config.bnb_4bit_quant_type= quant_type + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, quantization_config=bnb_config, - max_memory={0: "48GB"}, - device_map="auto", - torch_dtype=torch.bfloat16, - ).eval() + max_memory={0:'48GB'}, + device_map='auto', + torch_dtype=torch.bfloat16 + ).eval() tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) return model, tokenizer - def get_prompt_for_generation_eval(text, add_roles=True): description = ( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ) if add_roles: - prompt = f"{description} ### Human: {text} ### Assistant:" + prompt = f'{description} ### Human: {text} ### Assistant:' else: - prompt = f"{description} {text}" + prompt = f'{description} {text}' return prompt - def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): text = prompt_func(text) - inputs = tokenizer(text, return_tensors="pt").to("cuda:0") - outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config) + inputs = tokenizer(text, return_tensors="pt").to('cuda:0') + outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) return tokenizer.decode(outputs[0], skip_special_tokens=True) +models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] +dtypes = ['nf4', 'fp4'] -models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"] -dtypes = ["nf4", "fp4"] - - -@pytest.fixture(scope="session", params=product(models, dtypes)) +@pytest.fixture(scope='session', params=product(models, dtypes)) def model_and_tokenizer(request): model, tokenizer = get_model_and_tokenizer(request.param) yield request.param, model, tokenizer @@ -86,19 +84,20 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): ) generation_config.max_new_tokens = 20 - # text = 'Please write down the first 50 digits of pi.' - # text = get_prompt_for_generation_eval(text) - # text += ' Sure, here the first 50 digits of pi: 3.14159' + + #text = 'Please write down the first 50 digits of pi.' + #text = get_prompt_for_generation_eval(text) + #text += ' Sure, here the first 50 digits of pi: 3.14159' n_cases = 6 - text = "3.14159" - if hasattr(model.config, "quantization_config"): + text = '3.14159' + if hasattr(model.config, 'quantization_config'): model.config.quantization_config.bnb_4bit_compute_dtype = dtype model.config.quantization_config.bnb_4bit_use_double_quant = DQ if not inference_kernel: - text = [text] * n_cases - inputs = tokenizer(text, return_tensors="pt").to("cuda:0") - x = inputs["input_ids"] + text = [text]*n_cases + inputs = tokenizer(text, return_tensors="pt").to('cuda:0') + x = inputs['input_ids'] outputs = [] if inference_kernel: for i in range(n_cases): @@ -109,14 +108,15 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): outputs = model.generate(x, generation_config=generation_config) outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] + assert len(outputs) == n_cases failure_count = 0 for i in range(n_cases): - if not outputs[i][: len(str(math.pi))] == str(math.pi): + if not outputs[i][:len(str(math.pi))] == str(math.pi): failure_count += 1 - failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 + failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) if failure_count > failure_max: print(math.pi) for out in outputs: print(out) - raise ValueError(f"Failure count: {failure_count}/{n_cases}") + raise ValueError(f'Failure count: {failure_count}/{n_cases}') diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index bbbd05335..567e1a466 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -28,7 +28,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora device = "cuda" layer_shape = (300, 400) - linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu") # original layer + linear = torch.nn.Linear( + *layer_shape, dtype=original_dtype, device="cpu" + ) # original layer # Quantizing original layer linear_q = bnb.nn.Linear4bit( @@ -40,7 +42,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_type=quant_type, device="meta", ) - new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False) + new_weight = bnb.nn.Params4bit( + data=linear.weight, quant_type=quant_type, requires_grad=False + ) linear_q.weight = new_weight if bias: linear_q.bias = torch.nn.Parameter(linear.bias) @@ -168,9 +172,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora target_compression = ( 0.143 if original_dtype == torch.float32 else 0.29 ) # these numbers get lower as weight shape increases - ratio_error_msg = ( - f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" - ) + ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" assert size_ratio < target_compression, ratio_error_msg diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 4b62abd6d..edc3409cd 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -19,7 +19,6 @@ # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py - @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", @@ -51,9 +50,7 @@ def test_linear_no_igemmlt(): linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), - requires_grad=False, - has_fp16_weights=False, + linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False ).to(linear.weight.dtype) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -80,14 +77,7 @@ def test_linear_no_igemmlt(): @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) @pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) -def test_linear_serialization( - has_fp16_weights, - serialize_before_forward, - deserialize_before_cuda, - force_no_igemmlt, - save_before_forward, - load_before_cuda, -): +def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda): linear = torch.nn.Linear(32, 96) x = torch.randn(3, 32, dtype=torch.half) @@ -102,9 +92,7 @@ def test_linear_serialization( linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), - requires_grad=has_fp16_weights, - has_fp16_weights=has_fp16_weights, + linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() diff --git a/tests/test_modules.py b/tests/test_modules.py index db4d72410..674620e29 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -19,18 +19,12 @@ class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super().__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, - dim2, - has_fp16_weights=has_fp16_weights, - memory_efficient_backward=memory_efficient_backward, - threshold=threshold, + dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, - dim1, - has_fp16_weights=has_fp16_weights, - memory_efficient_backward=memory_efficient_backward, - threshold=threshold, + dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) def forward(self, x): @@ -58,7 +52,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -126,7 +122,9 @@ def dequant_min_max(xq, A, B, SA, SB, dtype): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -135,7 +133,9 @@ def get_8bit_linear(x, stochastic=False): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -219,7 +219,9 @@ def forward(ctx, x, weight, bias=None, args=None): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) + output = LinearFunction.dequant( + outputq, S1, S2, x.dtype, args.quant_type + ) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -248,25 +250,37 @@ def backward(ctx, grad_output): # weight and x are already 8bit # -> transform grad_output to 8-bit if args.use_8bit_training == "forward+wgrad": - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) grad_input = grad_output.matmul(weight) elif args.use_8bit_training == "full": - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=2 + ) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) + grad_input = LinearFunction.dequant( + grad_input8, S1, S3, grad_output.dtype, args.quant_type + ) else: grad_input = grad_output.matmul(weight) @@ -342,8 +356,12 @@ def test_linear8bitlt_accumulated_gradient(): opt1.zero_grad(True) opt2.step() opt2.zero_grad(True) - assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) - assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close( + l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 + ) + assert_all_approx_close( + l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 + ) # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) @@ -357,17 +375,7 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = ( - bnb.nn.Linear8bitLt( - 32, - 64, - threshold=threshold, - has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, - ) - .cuda() - .half() - ) + l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) assert l1.weight.dtype == torch.int8 l1.eval() @@ -389,7 +397,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .cuda() + .half() + ) assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -402,7 +414,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .cuda() + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -415,17 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( - MLP8bit( - 32, - 64, - threshold=threshold, - has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, - ) - .half() - .to("cuda") - ) + mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -441,12 +447,8 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc2.weight.device.type == "cuda" mlp = MLP8bit( - 32, - 64, - threshold=threshold, - has_fp16_weights=False, - memory_efficient_backward=memory_efficient_backward, - ) + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization @@ -487,7 +489,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4, ], - ids=["Int8Lt", "FP4"], + ids=['Int8Lt', 'FP4'], ) def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically @@ -542,7 +544,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() - kbit = kbit.half().to("cuda") + kbit = kbit.half().to('cuda') errs1 = [] errs2 = [] @@ -560,10 +562,10 @@ def test_kbit_backprop(module): bgrad1 = ref[0].bias.grad bgrad2 = kbit[0].bias.grad - err1 = (out1 - out2).abs().float() - err2 = (grad1 - grad2).abs().float() - relerr1 = err1 / (out1.abs().float() + 1e-9) - relerr2 = err2 / (grad1.abs().float() + 1e-9) + err1 = (out1-out2).abs().float() + err2 = (grad1-grad2).abs().float() + relerr1 = (err1/(out1.abs().float()+1e-9)) + relerr2 = (err2/(grad1.abs().float()+1e-9)) errs1.append(err1.mean().item()) errs2.append(err2.mean().item()) relerrs1.append(relerr1.mean().item()) @@ -580,20 +582,20 @@ def test_kbit_backprop(module): assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 - # print('out', sum(errs1)/len(errs1)) - # print('grad', sum(errs2)/len(errs2)) - # print('rel out', sum(relerrs1)/len(relerrs1)) - # print('rel grad', sum(relerrs2)/len(relerrs2)) - + #print('out', sum(errs1)/len(errs1)) + #print('grad', sum(errs2)/len(errs2)) + #print('rel out', sum(relerrs1)/len(relerrs1)) + #print('rel grad', sum(relerrs2)/len(relerrs2)) def test_fp8linear(): + b = 10 h = 1024 inp = torch.randn(b, h).cuda() - fp32 = torch.nn.Linear(h, h * 2).cuda() - fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda() - fp32b = torch.nn.Linear(h * 2, h).cuda() - fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda() + fp32 = torch.nn.Linear(h, h*2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() + fp32b = torch.nn.Linear(h*2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() fp8.weight.data.copy_(fp32.weight.data) fp8.bias.data.copy_(fp32.bias.data) @@ -603,34 +605,34 @@ def test_fp8linear(): a = fp32b(torch.nn.functional.gelu(fp32(inp))) b = fp8b(torch.nn.functional.gelu(fp8(inp))) - err = (a - b).abs().mean() + err = (a-b).abs().mean() a.mean().backward() b.mean().backward() - graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean() - bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean() + graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() assert err < 0.05 assert graderr < 0.00002 assert bgraderr < 0.00002 - def test_4bit_warnings(): dim1 = 64 - with pytest.warns(UserWarning, match=r"inference or training"): + with pytest.warns(UserWarning, match=r'inference or training'): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() net(inp) - with pytest.warns(UserWarning, match=r"inference."): + with pytest.warns(UserWarning, match=r'inference.'): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(1, dim1).cuda().half() net(inp) with pytest.warns(UserWarning) as record: + net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() diff --git a/tests/test_optim.py b/tests/test_optim.py index d8c46e415..9395b8820 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -16,7 +16,6 @@ k = 20 - def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() @@ -34,7 +33,6 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) - str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) @@ -68,14 +66,8 @@ def rm_path(path): ) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers["paged_adamw8bit_blockwise"] = ( - torch.optim.AdamW, - lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True), -) -str2optimizers["paged_adam8bit_blockwise"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True), -) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( @@ -98,18 +90,9 @@ def rm_path(path): str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] -str2statenames["adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["paged_adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["paged_adamw8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] @@ -118,7 +101,7 @@ def rm_path(path): str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] -optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"] +optimizer_names_32bit = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] @pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) @@ -126,7 +109,7 @@ def rm_path(path): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: + if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() if dim1 == 1 and dim2 == 1: return @@ -178,13 +161,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close( - torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], - atol=atol, - rtol=rtol, - max_error_count=10, - ) + assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], + atol=atol, rtol=rtol, + max_error_count=10) if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit @@ -214,9 +193,13 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) + bnb.optim.GlobalOptimManager.get_instance().override_config( + p3, "optim_bits", 8 + ) - bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) + bnb.optim.GlobalOptimManager.get_instance().register_parameters( + [p1, p2, p3] + ) p1 = p1.cuda() p2 = p2.cuda() p3 = p3.cuda() @@ -259,8 +242,7 @@ def test_global_config(dim1, dim2, gtype): @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]: - pytest.skip() + if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -312,12 +294,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 - # assert num_not_close.sum().item() < 20 + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) + #assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) - relerr = err / (torch.abs(p1) + 1e-9) + relerr = err / (torch.abs(p1)+1e-9) if g.dtype == torch.bfloat16: assert err.mean() < 0.00015 assert relerr.mean() < 0.0016 @@ -329,7 +316,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() @@ -359,7 +348,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) torch.testing.assert_close(s1cpy, s1) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -406,11 +395,15 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): for i in range(50): step += 1 - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( + 0.01 * i + ) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( + g1, gnorm_vec, step, 5 + ) g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 @@ -504,8 +497,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) -@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) -@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) +@pytest.mark.parametrize("optim_name", ['paged_adamw'], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ['bnb'], ids=id_formatter("mode")) @pytest.mark.benchmark def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) @@ -513,24 +506,24 @@ def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = layers1.cuda() large_tensor = None - if mode == "torch": + if mode == 'torch': optim = str2optimizers[optim_name][0](layers1.parameters()) else: optim = str2optimizers[optim_name][1](layers1.parameters()) # 12 GB - large_tensor = torch.empty((int(4.5e9),), device="cuda") + large_tensor = torch.empty((int(4.5e9),), device='cuda') torch.cuda.synchronize() time.sleep(5) num_batches = 5 - batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) - lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() + batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) + lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() for i in range(num_batches): print(i) b = batches[i] - if i == 2: + if i ==2: torch.cuda.synchronize() t0 = time.time() diff --git a/tests/test_triton.py b/tests/test_triton.py index 3624fb5e9..218a533d5 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -7,18 +7,15 @@ from tests.helpers import TRUE_FALSE -@pytest.mark.skipif( - not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, - reason="This test requires triton and a GPU with compute capability 8.0 or higher.", -) +@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.") @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) def test_switchback(vector_wise_quantization): for dim in [83]: for batch in [13]: + standard = torch.nn.Linear(dim, 4 * dim).cuda().half() - switchback = ( - SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() - ) + switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() baseline = Linear8bitLt(dim, 4 * dim).cuda().half() switchback.weight.data.copy_(standard.weight) switchback.bias.data.copy_(standard.bias) @@ -41,23 +38,23 @@ def test_switchback(vector_wise_quantization): err_sb = (out_standard - out_sb).abs().mean() err_baseline = (out_standard - out_baseline).abs().mean() - print("OUT", err_sb, err_baseline) + print('OUT', err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() - print("GW2", err_sb, err_baseline) + print('GW2', err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() - print("GW1", err_sb, err_baseline) + print('GW1', err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (x1.grad - x2.grad).abs().mean() err_baseline = (x1.grad - x3.grad).abs().mean() - print("GX1", err_sb, err_baseline) + print('GX1', err_sb, err_baseline) assert err_sb < 2 * err_baseline