diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000..01084d44f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,19 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: + +jobs: + Lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.12" + - uses: pre-commit/action@v3.0.0 + env: + RUFF_OUTPUT_FORMAT: github diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..d568a849f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.15 + hooks: + - id: ruff + args: + - --fix + # - id: ruff-format # TODO: enable when the time is right diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py index 8897564e7..3ef87d6b2 100644 --- a/benchmarking/switchback/make_plot_with_jsonl.py +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -1,9 +1,7 @@ -import matplotlib.pyplot as plt -import pandas as pd -import numpy as np -import os import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt +import pandas as pd cmap=plt.get_cmap('cool') diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py index b0983d0b8..d70df0386 100644 --- a/benchmarking/switchback/speed_benchmark.py +++ b/benchmarking/switchback/speed_benchmark.py @@ -1,14 +1,22 @@ import json - import time + import torch -import torch.nn as nn +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 01d5527f5..87307a9d2 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,14 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils, research +from . 
import cuda_setup, research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, + matmul_4bit, matmul_cublas, mm_cublas, - matmul_4bit ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 8f58e1665..af5c1c523 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,11 +1,7 @@ import os +from os.path import isdir import sys -import shlex -import subprocess - from warnings import warn -from typing import Tuple -from os.path import isdir import torch @@ -20,7 +16,7 @@ def find_file_recursive(folder, filename): out = glob.glob(os.path.join(folder, "**", filename + ext)) outs.extend(out) except Exception as e: - raise RuntimeError('Error: Something when wrong when trying to find file. {e}') + raise RuntimeError('Error: Something when wrong when trying to find file.') from e return outs @@ -62,14 +58,11 @@ def generate_bug_report_information(): print_header(f"{path} CUDA PATHS") paths = find_file_recursive(path, '*cuda*') print(paths) - except: - print(f'Could not read LD_LIBRARY_PATH: {path}') + except Exception as e: + print(f'Could not read LD_LIBRARY_PATH: {path} ({e})') print('') - - - def print_header( txt: str, width: int = HEADER_WIDTH, filler: str = "+" ) -> None: @@ -78,67 +71,61 @@ def print_header( def print_debug_info() -> None: + from . import PACKAGE_GITHUB_URL print( "\nAbove we output some debug information. Please provide this info when " f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" ) -generate_bug_report_information() +def main(): + generate_bug_report_information() + from . import COMPILED_WITH_CUDA + from .cuda_setup.main import get_compute_capabilities -from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities - + print_header("OTHER") + print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") + print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + print("\nWARNING: Please be sure to sanitize sensitive info from any such env vars!\n") -print_header("OTHER") -print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") -print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") -print_header("") -print_header("DEBUG INFO END") -print_header("") -print( - """ -Running a quick check that: - + library is importable - + CUDA function is callable -""" -) -print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") + try: + from bitsandbytes.optim import Adam -try: - from bitsandbytes.optim import Adam + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() - p = torch.nn.Parameter(torch.rand(10, 10).cuda()) - a = torch.rand(10, 10).cuda() + p1 = p.data.sum().item() - p1 = p.data.sum().item() + adam = Adam([p]) - adam = Adam([p]) + out = a * p + loss = out.sum() + loss.backward() + adam.step() - out = a * p - loss = out.sum() - loss.backward() - adam.step() + p2 = p.data.sum().item() - p2 = p.data.sum().item() + assert p1 != p2 + print("SUCCESS!") + print("Installation was successful!") + except ImportError: + print() + warn( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report 
an issue!" + ) + print_debug_info() + except Exception as e: + print(e) + print_debug_info() + sys.exit(1) - assert p1 != p2 - print("SUCCESS!") - print("Installation was successful!") - sys.exit(0) -except ImportError: - print() - warn( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" - ) - print_debug_info() - sys.exit(0) -except Exception as e: - print(e) - print_debug_info() - sys.exit(1) +if __name__ == "__main__": + main() diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index 6b9a7e4d1..f262d89ed 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -1 +1 @@ -from ._functions import undo_layout, get_inverse_transform_indices +from ._functions import get_inverse_transform_indices, undo_layout diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 9917e326e..6cbb6efd9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,8 +1,8 @@ -import operator -import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional, Callable +import operator +from typing import Callable, Optional, Tuple +import warnings from warnings import warn import torch diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index d52a6d607..858365f02 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,12 +1,9 @@ import ctypes as ct -import os -import torch - -from pathlib import Path from warnings import warn -from bitsandbytes.cuda_setup.main import CUDASetup +import torch +from bitsandbytes.cuda_setup.main import CUDASetup setup = CUDASetup.get_instance() if setup.initialized != True: @@ -25,7 +22,7 @@ Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them to your LD_LIBRARY_PATH. 
If you suspect a bug, please take the information from python -m bitsandbytes and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False + _ = lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index a5931ef5e..a34385b1f 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -17,15 +17,15 @@ """ import ctypes as ct -import os import errno +import os +from pathlib import Path import platform -import torch +from typing import Set, Union from warnings import warn -from itertools import product -from pathlib import Path -from typing import Set, Union +import torch + from .env_vars import get_potentially_lib_path_containing_env_vars # these are the most common libs names @@ -111,14 +111,16 @@ def manual_override(self): if torch.cuda.is_available(): if 'BNB_CUDA_VERSION' in os.environ: if len(os.environ['BNB_CUDA_VERSION']) > 0: - warn((f'\n\n{"="*80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: try: if path.exists(): existent_directories.add(path) - except PermissionError as pex: + except PermissionError: # Handle the PermissionError first as it is a subtype of OSError # https://docs.python.org/3/library/exceptions.html#exception-hierarchy pass @@ -217,8 +219,10 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: non_existent_directories: Set[Path] = candidate_paths - existent_directories if non_existent_directories: - CUDASetup.get_instance().add_log_entry("The following directories listed in your path were found to " - f"be non-existent: {non_existent_directories}", is_warning=False) + CUDASetup.get_instance().add_log_entry( + f"The following directories listed in your path were found to be non-existent: {non_existent_directories}", + is_warning=False, + ) return existent_directories @@ -360,8 +364,10 @@ def evaluate_cuda_setup(): cuda_version_string = get_cuda_version() cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") - cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:" - "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md") + cuda_setup.add_log_entry( + "CUDA SETUP: To manually override the PyTorch CUDA version please see:" + "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md" + ) # 7.5 is the minimum CC vor cublaslt diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 25aa4e531..1f624a7a8 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,17 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this 
source tree. import ctypes as ct +from functools import reduce # Required in Python 3 import itertools import operator -import random -import torch -import itertools -import math -import numpy as np +from typing import Any, Dict, Optional, Tuple -from functools import reduce # Required in Python 3 -from typing import Tuple, Any, Dict, Optional +import numpy as np +import torch from torch import Tensor + from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import COMPILED_WITH_CUDA, lib @@ -178,7 +176,9 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): +FIRST_CUDA_DEVICE = torch.device('cuda', index=0) + +def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): num_bytes = dtype2bytes[dtype]*prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) @@ -242,7 +242,7 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): if gap == 0: return values else: - l = values.numel()//2 + l = values.numel()//2 # noqa: E741 return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) @@ -283,7 +283,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2**(exponent_bits-has_sign)), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) @@ -345,7 +345,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -899,7 +899,7 @@ def get_4bit_type(typename, device=None, blocksize=64): -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, 0.42563882, 0.55496234, 0.72424863, 1.][::-1] else: - raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError('4-bit AbnormalFloats currently only support blocksize 64.') if data is None: raise NotImplementedError(f'Typename {typename} not supported') @@ -1635,10 +1635,10 @@ def gemv_4bit( prev_device = pre_call(A.device) #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError('state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') if A.numel() != A.shape[-1]: - raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError('Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. 
[1, 1, 2048]') Bshape = state.shape bout = Bshape[0] diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 6fa6d1183..96f4359bf 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding -from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear +from .modules import ( + Embedding, + Int8Params, + Linear4bit, + Linear8bitLt, + LinearFP4, + LinearNF4, + OutlierAwareLinear, + Params4bit, + StableEmbedding, + SwitchBackLinearBnb, +) +from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, +) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b1f6deb21..922feae15 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -3,17 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from typing import Any, Dict, Optional, TypeVar, Union, overload - import warnings + import torch -import torch.nn.functional as F from torch import Tensor, device, dtype, nn +import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.autograd._functions import get_tile_inds, undo_layout from bitsandbytes.functional import QuantState -from bitsandbytes.autograd._functions import undo_layout, get_tile_inds from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims +from bitsandbytes.utils import OutlierTracer T = TypeVar("T", bound="torch.nn.Module") @@ -242,10 +242,10 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') warnings.filterwarnings('ignore', message='.*inference.') if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') + warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). 
This will lead to slow inference or training speed.') warnings.filterwarnings('ignore', message='.*inference or training') def _save_to_state_dict(self, destination, prefix, keep_vars): @@ -337,8 +337,8 @@ def cuda(self, device): del CBt del SCBt self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + self.CB = CB + self.SCB = SCB return self diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index 67b45f4a5..9c7738c59 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -1,16 +1,24 @@ -import torch -import torch.nn as nn -import time from functools import partial -from bitsandbytes.triton.triton_utils import is_triton_available +import torch +import torch.nn as nn from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize +from bitsandbytes.triton.triton_utils import is_triton_available class _switchback_global(torch.autograd.Function): diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 83a57bd9f..6796b8e0e 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -7,10 +7,17 @@ from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit +from .adamw import ( + AdamW, + AdamW8bit, + AdamW32bit, + PagedAdamW, + PagedAdamW8bit, + PagedAdamW32bit, +) from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 21077f1a0..9ea5812ea 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,7 +5,6 @@ from bitsandbytes.optim.optimizer import Optimizer2State - class AdamW(Optimizer2State): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2bde1a447..b6ba4a9f1 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
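Most hunks in this patch are mechanical reorderings produced by ruff's isort rules (the "I" entry in the pyproject.toml `select` list further down, together with `combine-as-imports = true`). The resulting convention: standard-library imports first, third-party next, first-party last, with `import x` and `from x import y` statements interleaved alphabetically by module inside each section. The interleaving suggests `force-sort-within-sections` is enabled, though that option is outside the visible pyproject.toml context. A minimal before/after sketch with illustrative module names:

    # Before: sections mixed, names unsorted (the style ruff flags).
    import warnings
    import torch
    from typing import Optional
    import operator
    from dataclasses import dataclass

    # After: alphabetical by module within each section, regardless of
    # whether the statement is `import x` or `from x import y`.
    from dataclasses import dataclass
    import operator
    from typing import Optional
    import warnings

    import torch
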
from bitsandbytes.optim.optimizer import Optimizer1State + class Lion(Optimizer1State): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index fb83eddf0..8254d16b4 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import abc as container_abcs -from collections import defaultdict +from collections import abc as container_abcs, defaultdict from copy import deepcopy from itertools import chain diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py index 47b720d78..31db4f282 100644 --- a/bitsandbytes/research/__init__.py +++ b/bitsandbytes/research/__init__.py @@ -1,6 +1,6 @@ from . import nn from .autograd._functions import ( - switchback_bnb, matmul_fp8_global, matmul_fp8_mixed, + switchback_bnb, ) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 06b0748ff..e515bfeff 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -1,15 +1,13 @@ -import operator -import warnings -from dataclasses import dataclass from functools import reduce # Required in Python 3 +import operator from typing import Optional +import warnings import torch +from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState import bitsandbytes.functional as F -from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler - # math.prod not compatible with python < 3.8 def prod(iterable): @@ -186,7 +184,9 @@ def backward(ctx, grad_output): class SwitchBackBnb(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # TODO: the B008 on the line below is a likely bug; the current implementation will + # have each SwitchBackBnb instance share a single MatmulLtState instance!!! 
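The TODO above deserves a concrete illustration: Python evaluates default arguments once, at function definition time, so `state=MatmulLtState()` creates a single instance that every call omitting `state` will share. That is exactly what ruff's B008 check flags, and why the line below carries a `# noqa: B008` rather than a behavior-changing fix. A minimal sketch of the conventional remedy, offered as a follow-up suggestion rather than something this patch applies:

    from bitsandbytes.autograd._functions import MatmulLtState

    # Evaluated once at definition time: all callers that omit `state`
    # share one object, so mutations leak between calls.
    def forward_shared(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
        ...

    # Conventional fix: default to None and construct per call.
    def forward_fresh(ctx, A, B, out=None, bias=None, state=None):
        if state is None:
            state = MatmulLtState()  # fresh instance for each call
        ...
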
+ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008 # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py index 8faec10bb..417011218 100644 --- a/bitsandbytes/research/nn/__init__.py +++ b/bitsandbytes/research/nn/__init__.py @@ -1 +1 @@ -from .modules import LinearFP8Mixed, LinearFP8Global +from .modules import LinearFP8Global, LinearFP8Mixed diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 2a46b40c3..7fca34d23 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -1,12 +1,9 @@ -from typing import Optional, TypeVar, Union, overload +from typing import TypeVar import torch -import torch.nn.functional as F -from torch import Tensor, device, dtype, nn +from torch import nn import bitsandbytes as bnb -from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims T = TypeVar("T", bound="torch.nn.Module") diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index e092680b8..daa59da9c 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -1,6 +1,7 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -9,7 +10,6 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index b0961f558..1b80ab1a0 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -1,4 +1,5 @@ import torch + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -57,7 +58,8 @@ def get_configs_io_bound(): triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), + *get_configs_io_bound(), + ], key=['M', 'N', 'K'], prune_configs_by={ 'early_config_prune': early_config_prune, diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 33f4d13f2..1f28b0d10 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -57,7 +57,8 @@ def get_configs_io_bound(): triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), + *get_configs_io_bound(), + ], key=['M', 'N', 'K'], prune_configs_by={ 'early_config_prune': early_config_prune, diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index 
54220d95a..fcadaba3e 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -1,6 +1,7 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -9,7 +10,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # This kernel does fused columnwise quantization and transpose. diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index 845db6ecd..a73a5bbaa 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,6 +1,6 @@ -import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): @@ -10,7 +10,6 @@ def quantize_global(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # global quantize @triton.autotune( diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index 26d218321..fce464b19 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -1,6 +1,6 @@ import math + import torch -import time from bitsandbytes.triton.triton_utils import is_triton_available @@ -10,7 +10,6 @@ def quantize_rowwise(x: torch.Tensor): return None import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index c74c23962..6bbdbf1c1 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -1,4 +1,5 @@ import importlib + def is_triton_available(): return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 48373a1fe..0582f7fc0 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,9 +1,11 @@ import json import shlex import subprocess -import torch from typing import Tuple +import torch + + def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) tracer = OutlierTracer.get_instance() @@ -37,7 +39,7 @@ def outlier_hook(module, input): hook.remove() -class OutlierTracer(object): +class OutlierTracer: _instance = None def __init__(self): @@ -122,7 +124,13 @@ def execute_and_return_decoded_std_streams(command_string): -def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): +def replace_linear( + model, + linear_replacement, + skip_modules=("lm_head",), + copy_weights=False, + post_processing_function=None, +): """ Replace linear modules with a new Linear module. 
Parameters: diff --git a/check_bnb_install.py b/check_bnb_install.py index 77cd03ec4..5a7f74f89 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -1,6 +1,7 @@ -import bitsandbytes as bnb import torch +import bitsandbytes as bnb + p = torch.nn.Parameter(torch.rand(10,10).cuda()) a = torch.rand(10,10).cuda() diff --git a/install_cuda.py b/install_cuda.py index e90f6b6fb..77e258609 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -1,6 +1,6 @@ import os -import sys import subprocess +import sys from urllib.request import urlretrieve cuda_versions = { diff --git a/pyproject.toml b/pyproject.toml index c73f579e0..53942bc41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,7 @@ src = [ "tests", "benchmarking" ] -fix = true select = [ - "A", # prevent using keywords that clobber python builtins "B", # bugbear: security warnings "E", # pycodestyle "F", # pyflakes @@ -24,12 +22,29 @@ select = [ ] target-version = "py38" ignore = [ - "E712", # Allow using if x == False, as it's not always equivalent to if x. + "B007", # Loop control variable not used within the loop body (TODO: enable) + "B028", # Warning without stacklevel (TODO: enable) "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. - "F401", + "E701", # Multiple statements on one line (TODO: enable) + "E712", # Allow using if x == False, as it's not always equivalent to if x. + "E731", # Do not use lambda + "F841", # Local assigned but not used (TODO: enable, these are likely bugs) + "RUF012", # Mutable class attribute annotations ] ignore-init-module-imports = true # allow to expose in __init__.py via imports +[tool.ruff.extend-per-file-ignores] +"**/__init__.py" = ["F401"] # allow unused imports in __init__.py +"{benchmarking,tests}/**/*.py" = [ + "B007", + "B011", + "B023", + "E701", + "E731", + "F841", + "UP030", +] + [tool.ruff.isort] combine-as-imports = true detect-same-package = true diff --git a/scripts/stale.py b/scripts/stale.py index b7f34c1fb..c299643ae 100644 --- a/scripts/stale.py +++ b/scripts/stale.py @@ -15,13 +15,11 @@ Script to close stale issue. Taken in part from the AllenNLP repository. https://github.com/allenai/allennlp. """ +from datetime import datetime as dt, timezone import os -from datetime import datetime as dt -from datetime import timezone from github import Github - # All labels that we don't want to touch LABELS_TO_EXEMPT = [ "feature-request", diff --git a/setup.py b/setup.py index 7a82b7717..407116fbe 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,6 @@ from setuptools import find_packages, setup - libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) libs += list(glob.glob("./bitsandbytes/libbitsandbytes*.dll")) libs = [os.path.basename(p) for p in libs] @@ -19,7 +18,7 @@ def read(fname): setup( - name=f"bitsandbytes", + name="bitsandbytes", version="0.42.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 27b010105..ed482b356 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,4 +1,4 @@ -from itertools import permutations, product +from itertools import product import pytest import torch diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 596d0a030..5e1a548e5 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,8 +1,9 @@ import os -import pytest -import torch from pathlib import Path +import torch + + # hardcoded test. 
Not good, but a sanity check for now # TODO: improve this def test_manual_override(requires_cuda): diff --git a/tests/test_functional.py b/tests/test_functional.py index 970b4dbdb..5b7f83bc3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,16 +1,16 @@ +from itertools import product import math import random import time -from itertools import product import einops +import numpy as np import pytest +from scipy.stats import norm import torch -import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F -from scipy.stats import norm torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 diff --git a/tests/test_generation.py b/tests/test_generation.py index ecafdddf8..753623b27 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,22 +1,15 @@ -import pytest -import torch -import math - from itertools import product +import math +import pytest +import torch import transformers from transformers import ( - AutoConfig, AutoModelForCausalLM, - AutoTokenizer, BitsAndBytesConfig, - GenerationConfig, - set_seed, - ) - def get_4bit_config(): return BitsAndBytesConfig( load_in_4bit=True, diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 478255eee..d396a910b 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -1,6 +1,5 @@ -import os -from contextlib import nullcontext from itertools import product +import os from tempfile import TemporaryDirectory import pytest diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 8904aaf1b..d4967969c 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,6 +1,6 @@ -import os from contextlib import nullcontext from itertools import product +import os from tempfile import TemporaryDirectory import pytest @@ -11,7 +11,6 @@ from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.nn.modules import Linear8bitLt - # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py diff --git a/tests/test_modules.py b/tests/test_modules.py index cabd7cf54..c98f7a6d4 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,5 +1,6 @@ -from itertools import product +import math +import einops import pytest import torch from torch import nn diff --git a/tests/test_optim.py b/tests/test_optim.py index 49d4f442a..993ac8b60 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,14 +1,12 @@ -import ctypes +from itertools import product import os +from os.path import join import shutil import time import uuid -from itertools import product -from os.path import join -import pytest from lion_pytorch import Lion - +import pytest import torch import bitsandbytes as bnb @@ -27,7 +25,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def get_temp_dir(): - path = f"/tmp/autoswap/{str(uuid.uuid4())}" + path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) return path diff --git a/tests/test_triton.py b/tests/test_triton.py index e18c7a930..d0397ee4a 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -1,9 +1,10 @@ import pytest import torch -from bitsandbytes.triton.triton_utils import is_triton_available -from bitsandbytes.nn.triton_based_modules import SwitchBackLinear from bitsandbytes.nn import Linear8bitLt +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear +from bitsandbytes.triton.triton_utils import 
is_triton_available + @pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
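
The `skipif` marker in the tests/test_triton.py hunk combines three requirements: Triton installed, CUDA available, and compute capability 8.0 or higher. When several tests share such a guard, pytest lets you build the marker once and reuse it; a small sketch, where `requires_triton_gpu` and the test name are hypothetical:

    import pytest
    import torch

    from bitsandbytes.triton.triton_utils import is_triton_available

    # Build the marker once; `not cc >= 8` from the original condition is
    # rewritten as the equivalent `cc < 8`. Short-circuiting `or` ensures
    # get_device_capability() is only called when CUDA is available.
    requires_triton_gpu = pytest.mark.skipif(
        not is_triton_available()
        or not torch.cuda.is_available()
        or torch.cuda.get_device_capability()[0] < 8,
        reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
    )

    @requires_triton_gpu
    def test_switchback_linear_example():
        ...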
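
The largest behavioral change in the patch is the bitsandbytes/__main__.py rewrite: the bug-report dump, the CUDA sanity check, and the `sys.exit` call move from module import time into a `main()` function behind an `if __name__ == "__main__":` guard, so the diagnostics still run under `python -m bitsandbytes` but no longer fire as a side effect of importing the module. A stripped-down sketch of the pattern (names are illustrative, not the exact file contents):

    import sys

    def run_sanity_check():
        # Stand-in for the real check: an Adam step on a CUDA parameter.
        pass

    def main():
        # Side-effectful diagnostics belong here, not at module level.
        print("running diagnostics...")
        try:
            run_sanity_check()
        except Exception as e:
            print(e)
            sys.exit(1)

    if __name__ == "__main__":
        # Executes under `python -m bitsandbytes`; plain imports stay silent.
        main()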