Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ NVCC := $(CUDA_HOME)/bin/nvcc
###########################################

CSRC := $(ROOT_DIR)/csrc
BUILD_DIR:= $(ROOT_DIR)/cuda_build
BUILD_DIR:= $(ROOT_DIR)/build

FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/pythonInterface.c
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c

INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
Expand Down Expand Up @@ -46,27 +46,30 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cuda110: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cuda11x: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)

cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so

env:
@echo "ENVIRONMENT"
Expand All @@ -80,7 +83,7 @@ env:
@echo "============================"

$(BUILD_DIR):
mkdir -p cuda_build
mkdir -p build
mkdir -p dependencies

$(ROOT_DIR)/dependencies/cub:
Expand Down
11 changes: 8 additions & 3 deletions bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .optim import adam

from .nn import modules
__pdoc__ = {'libBitsNBytes' : False,
from .cextension import COMPILED_WITH_CUDA

if COMPILED_WITH_CUDA:
from .optim import adam

__pdoc__ = {'libBitsNBytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False
}
}
13 changes: 13 additions & 0 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import ctypes as ct
import os
from warnings import warn

lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')

try:
lib.cadam32bit_g32
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
52 changes: 26 additions & 26 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,38 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import os
import random
from typing import Tuple

import torch
from torch import Tensor

lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
from .cextension import lib, COMPILED_WITH_CUDA

name2qmap = {}

''' C FUNCTIONS FOR OPTIMIZERS '''

str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

str2optimizer8bit = {}
str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)

str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
if COMPILED_WITH_CUDA:
''' C FUNCTIONS FOR OPTIMIZERS '''
str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

str2optimizer8bit = {}
str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)
str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16)
str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16)

str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)

optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]

Expand Down Expand Up @@ -138,7 +138,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementError(f'Not supported data type {A.dtype}')
raise NotImplementedError(f'Not supported data type {A.dtype}')
return out

def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
Expand Down Expand Up @@ -384,7 +384,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
param_norm = torch.norm(p.data.float())

if optimizer_name not in str2optimizer32bit:
raise NotImplementError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')

if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),
Expand Down
19 changes: 12 additions & 7 deletions bitsandbytes/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit

from bitsandbytes.cextension import COMPILED_WITH_CUDA

if COMPILED_WITH_CUDA:
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit

from .optimizer import GlobalOptimManager
2 changes: 1 addition & 1 deletion bitsandbytes/optim/rmsprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, moment
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
raise NotImplementError(f'Centered RMSprop is not supported!')
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
4 changes: 2 additions & 2 deletions compile_from_source.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Compiling from source

Basic steps.
1. `make cudaXXX` where `cudaXXX` is among `cuda92, cuda10x, cuda110, cuda11x`
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
2. `CUDA_VERSION=XXX python setup.py install`

To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).

For your convenience, there is a install script int he root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
```bash
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64/" >> ~/.bashrc
echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
Expand Down
39 changes: 39 additions & 0 deletions csrc/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <common.h>
#include <float.h>

void *quantize_block(void *arguments) {
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index

struct quantize_block_args *args = (quantize_block_args *) arguments;

// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (int i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));

args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;

for (int i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block;
int idx = args->bin_searcher->scalar(normed_value);

// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) {
float dist_left = fabs(normed_value - (args->code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
}

// 5. store index
args->out[i] = (unsigned char) idx;
}

return NULL;
}
23 changes: 23 additions & 0 deletions csrc/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <BinSearch.h>

#ifndef common
#define common

using namespace BinSearch;

struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
int block_end;
int block_idx;
int threadidx;
};

#define BLOCK_SIZE 4096

void *quantize_block(void *arguments);

#endif
Loading