14 changes: 8 additions & 6 deletions README.md
@@ -23,12 +23,12 @@ Resources:
1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
3. There are two modes:
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``use_fp16_weights=True`` (default)
- Int8 inference. Pass the argument ``use_fp16_weights=False``
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
- Int8 inference. Pass the argument ``has_fp16_weights=False``
4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
```python
# LLM.int8()
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, use_fp16_weights=False, threshold=6.0)
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
# inputs need to be fp16
out = linear(x.to(torch.float16))
```
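
For the mixed 8-bit training mode in step 3, a minimal sketch along the same lines (hypothetical layer sizes; assumes a CUDA device is available):

```python
import torch
import bitsandbytes as bnb

# Mixed 8-bit training: 16-bit main weights are kept, so the layer remains trainable
linear = bnb.nn.Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=True, threshold=6.0).cuda()
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")  # inputs need to be fp16
out = linear(x)
```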
@@ -115,7 +115,8 @@ We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fa

## How to cite us
If you found this library and found LLM.int8() useful, please consider citing our work:
```

```bibtex
@article{dettmers2022llmint8,
title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
@@ -124,8 +125,9 @@ If you found this library and found LLM.int8() useful, please consider citing our
}
```

For 8-bit optimizers or quantization routines please consider citing the following work.
```
For 8-bit optimizers or quantization routines, please consider citing the following work:

```bibtex
@article{dettmers2022optimizers,
title={8-bit Optimizers via Block-wise Quantization},
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py
@@ -12,7 +12,7 @@
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
from . import cuda_setup
from . import cuda_setup, utils

if COMPILED_WITH_CUDA:
from .optim import adam
7 changes: 3 additions & 4 deletions bitsandbytes/__main__.py
@@ -3,8 +3,9 @@
# cli()
import os
import sys
import torch
from warnings import warn

import torch

HEADER_WIDTH = 60

@@ -32,8 +33,6 @@ def print_debug_info() -> None:
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
from .cuda_setup.env_vars import to_be_ignored
from .utils import print_stderr


print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():
@@ -84,7 +83,7 @@ def print_debug_info() -> None:

except ImportError:
print()
print_stderr(
warn(
f"WARNING: {__package__} is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!"
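The CLI debug script above now reports the CPU-only case through warnings.warn instead of the removed print_stderr helper; a tiny standard-library-only sketch of the two patterns:

```python
import sys
from warnings import warn

# Old pattern (removed print_stderr helper): write the message straight to stderr
print("WARNING: bitsandbytes is currently running as CPU-only!", file=sys.stderr)

# New pattern: emit a UserWarning, which also goes to stderr but can be
# filtered or escalated to an error via the warnings module
warn("WARNING: bitsandbytes is currently running as CPU-only!")
```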
4 changes: 0 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -1,6 +1,5 @@
import operator
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from dataclasses import dataclass
@@ -378,9 +377,6 @@ def backward(ctx, grad_output):
return grad_A, grad_B, None, grad_bias, None


matmul = MatMul8bitLt.apply


def matmul(
A: tensor,
B: tensor,
79 changes: 0 additions & 79 deletions bitsandbytes/cuda_setup/compute_capability.py

This file was deleted.

11 changes: 6 additions & 5 deletions bitsandbytes/cuda_setup/main.py
@@ -17,9 +17,7 @@
"""

import ctypes
from pathlib import Path

from ..utils import execute_and_return
from .paths import determine_cuda_runtime_lib_path


@@ -28,7 +26,7 @@ def check_cuda_result(cuda, result_val):
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
print(f"CUDA exception! Error code: {error_str.value.decode()}")

def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
@@ -57,7 +55,7 @@ def get_cuda_lib_handle():
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
raise Exception('CUDA SETUP: ERROR! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))

@@ -80,7 +78,6 @@ def get_compute_capabilities(cuda):
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()

result = ctypes.c_int()
device = ctypes.c_int()

check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
@@ -119,6 +116,10 @@ def evaluate_cuda_setup():
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name

cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(
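The cuda_setup changes keep the same ctypes-based probing (load libcuda.so, call cuInit, then query each device's compute capability) but print warnings instead of raising. A standalone sketch of that probing pattern, assuming an NVIDIA driver that exports the standard CUDA driver API; this is not the library's own function:

```python
import ctypes

def probe_compute_capabilities():
    """Sketch of the driver probing done in cuda_setup/main.py (not the library's API)."""
    try:
        cuda = ctypes.CDLL("libcuda.so")
    except OSError:
        print("CUDA SETUP: WARNING! libcuda.so not found. Is a CUDA driver installed?")
        return []

    cuda.cuInit(0)
    n_gpus = ctypes.c_int()
    cuda.cuDeviceGetCount(ctypes.byref(n_gpus))

    capabilities = []
    for ordinal in range(n_gpus.value):
        device = ctypes.c_int()
        cc_major, cc_minor = ctypes.c_int(), ctypes.c_int()
        cuda.cuDeviceGet(ctypes.byref(device), ordinal)
        # cuDeviceComputeCapability is deprecated but still exported by the driver
        cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device)
        capabilities.append(f"{cc_major.value}.{cc_minor.value}")
    return capabilities
```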
16 changes: 1 addition & 15 deletions bitsandbytes/cuda_setup/paths.py
@@ -2,23 +2,11 @@
from typing import Set, Union
from warnings import warn

from ..utils import print_stderr
from .env_vars import get_potentially_lib_path_containing_env_vars


CUDA_RUNTIME_LIB: str = "libcudart.so"


def purge_unwanted_semicolon(tentative_path: Path) -> Path:
"""
Special function to handle the following exception:
__LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1
"""
# if ';' in str(tentative_path):
# path_as_str, _ = str(tentative_path).split(';')
pass


def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}

@@ -29,7 +17,7 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
}

if non_existent_directories:
print_stderr(
warn(
"WARNING: The following directories listed in your path were found to "
f"be non-existent: {non_existent_directories}"
)
@@ -117,8 +105,6 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
}



cuda_runtime_libs = set()
for env_var, value in remaining_candidate_env_vars.items():
cuda_runtime_libs.update(find_cuda_lib_in(value))
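The surviving helpers in paths.py split colon-separated environment-variable values into candidate directories, warn about (and drop) the ones that do not exist, and then look for libcudart.so in what remains. A small standalone sketch of that filtering, with hypothetical helper names rather than the library's own:

```python
from pathlib import Path
from typing import Set
from warnings import warn

CUDA_RUNTIME_LIB = "libcudart.so"

def candidate_dirs(paths_list_candidate: str) -> Set[Path]:
    # Split an LD_LIBRARY_PATH-style value ("/a:/b:/c") into individual directories
    return {Path(p) for p in paths_list_candidate.split(":") if p}

def existing_dirs(candidates: Set[Path]) -> Set[Path]:
    missing = {p for p in candidates if not p.exists()}
    if missing:
        warn(f"The following directories listed in your path were found to be non-existent: {missing}")
    return candidates - missing

def find_cuda_runtime(paths_list_candidate: str) -> Set[Path]:
    # Keep only the candidate directories that actually contain the CUDA runtime library
    dirs = existing_dirs(candidate_dirs(paths_list_candidate))
    return {d / CUDA_RUNTIME_LIB for d in dirs if (d / CUDA_RUNTIME_LIB).is_file()}
```

For example, find_cuda_runtime(os.environ.get("LD_LIBRARY_PATH", "")) would return the set of candidate libcudart.so paths found along that variable.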
42 changes: 2 additions & 40 deletions bitsandbytes/functional.py
@@ -5,7 +5,6 @@
import ctypes as ct
import operator
import random
import math
import torch

from typing import Tuple
@@ -185,14 +184,9 @@ def create_dynamic_map(signed=True, n=7):


def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7

if major == 7:
if major <= 7:
return "col_turing"
elif major == 8:
return "col_ampere"
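The consolidated get_special_format_str (its duplicate further down in functional.py is deleted in this diff) maps the GPU's compute capability to an int8 tile layout and now falls back to "col_turing" when no GPU is visible instead of asserting. A minimal sketch of the resulting dispatch, assuming torch is importable:

```python
import torch

def special_format_str() -> str:
    # Sketch of the diffed logic: pick the int8 tile layout from the compute capability
    if not torch.cuda.is_available():
        return "col_turing"
    major, _minor = torch.cuda.get_device_capability()
    if major <= 7:
        return "col_turing"   # Turing and older
    if major == 8:
        return "col_ampere"   # Ampere
    return "col_turing"       # conservative default for anything newer
```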
@@ -248,23 +242,6 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
return getattr(lib, name)


class GlobalData(object):
_instance = None

def __init__(self):
raise RuntimeError("Call get_instance() instead")

def initialize(self):
self.data = {}

@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance


def get_transform_buffer(
shape, dtype, device, to_order, from_order="row", transpose=False
):
@@ -1685,21 +1662,6 @@ def double_quant(
return out_row, out_col, row_stats, col_stats, coo_tensor


def get_special_format_str():
major, minor = torch.cuda.get_device_capability()
if major < 7:
print(
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
)
assert major >= 7

if major == 7: return 'col_turing'
elif major == 8: return 'col_ampere'
else: return 'col_turing'




def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)
15 changes: 7 additions & 8 deletions bitsandbytes/optim/__init__.py
@@ -5,13 +5,12 @@

from bitsandbytes.cextension import COMPILED_WITH_CUDA

if COMPILED_WITH_CUDA:
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit

from .optimizer import GlobalOptimManager
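
With the COMPILED_WITH_CUDA gate removed, the optimizer classes above are always importable; the 8-bit variants remain drop-in replacements for their torch counterparts. A minimal usage sketch (hypothetical model and hyperparameters; still requires a CUDA build at runtime):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit Adam as a drop-in replacement for torch.optim.Adam
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.995))

loss = model(torch.randn(16, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```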
9 changes: 0 additions & 9 deletions bitsandbytes/utils.py
@@ -1,6 +1,5 @@
import shlex
import subprocess
import sys
from typing import Tuple


@@ -22,11 +21,3 @@ def execute_and_return_decoded_std_streams(command_string):

std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err


def print_stderr(s: str) -> None:
print(s, file=sys.stderr)


def warn_of_missing_prerequisite(s: str) -> None:
print_stderr("WARNING, missing pre-requisite: " + s)
6 changes: 5 additions & 1 deletion csrc/ops.cu
@@ -371,7 +371,11 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
printf("ERROR: Your GPU does not support Int8 Matmul!");
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);

return 0;
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@ def read(fname):

setup(
name=f"bitsandbytes",
version=f"0.32.1",
version=f"0.32.3",
author="Tim Dettmers",
author_email="[email protected]",
description="8-bit optimizers and matrix multiplication routines.",