122 changes: 53 additions & 69 deletions benchmarking/switchback/make_plot_with_jsonl.py
@@ -1,11 +1,13 @@

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd

cmap = plt.get_cmap("cool")
cmap=plt.get_cmap('cool')

if __name__ == '__main__':

if __name__ == "__main__":
fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 2)

dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
@@ -17,28 +19,25 @@
ax = fig.add_subplot(gs[0, 0])

# TODO: change this to what you want.
rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
df = rdf[rdf.batch_size == batch_size_for_plot1]

# first plot the time occupied by different operations
for k, marker, ls, color, name in [
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (sum of parts)",
),
("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),

('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),

('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
]:
xs = []
ys = []
@@ -48,104 +47,89 @@
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split("+"):
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split("+"):
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)

ax.plot(
xs,
ys,
color=color,
label=name,
marker=marker,
markersize=5 if marker == "s" else 5,
linestyle=ls,
linewidth=2 if "+" in k else 1.0,
)

ax.set_xlabel("dim", fontsize=13)
ax.set_ylabel("time (ms)", fontsize=13)
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)


ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13)

ax.grid()

ax.set_xscale("log")
ax.set_xscale('log')
if logscale_plot1:
ax.set_yscale("log")
ax.set_yscale('log')

ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)

ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)

leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight("bold")
leg.get_texts()[1].set_fontweight("bold")
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
leg.get_texts()[0].set_fontweight('bold')
leg.get_texts()[1].set_fontweight('bold')
plt.subplots_adjust(left=0.1)
ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)


ax = fig.add_subplot(gs[0, 1])

# now plot the % speedup for different batch sizes
for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
(
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
"o",
"-",
"C4",
"SwitchBack int8 (total time)",
),
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:

xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in dims_to_consider:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split("+"):
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split("+"):
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)

color = cmap(j * 0.25)
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
markers = ["^", "v", "P", "o"]
ax.plot(
all_xs[0],
real_ys,
color=color,
label=f"batch * sequence length = {batch_size}",
marker=markers[j],
markersize=5 if marker == "s" else 5,
)
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)

ax.legend()
ax.set_xlabel("dim", fontsize=13)
ax.set_xscale("log")
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r"% speedup", fontsize=13)
ax.set_ylabel(r'% speedup', fontsize=13)

ax.tick_params(axis="x", labelsize=11)
ax.tick_params(axis="y", labelsize=11)

ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)

ax.set_xticks(dims_to_xtick)
ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True)

ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)



plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
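For readers tracing the second panel's math: the script converts the two total-time series into a percentage speedup by negating the relative change, so a faster SwitchBack run plots as a positive value. A minimal standalone sketch of that computation, with made-up timings:

```python
# Percentage speedup as computed for the right-hand panel: the negated
# relative change of SwitchBack total time vs. the fp16 baseline.
standard_ms = [2.0, 4.1, 8.3]    # hypothetical fp16 totals per dim
switchback_ms = [1.6, 3.1, 6.0]  # hypothetical int8 totals per dim

speedup_pct = [-(sb - std) / std * 100 for std, sb in zip(standard_ms, switchback_ms)]
print(speedup_pct)  # [20.0, 24.4, 27.7] -- e.g. 2.0 ms -> 1.6 ms is a 20% speedup
```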
122 changes: 36 additions & 86 deletions benchmarking/switchback/speed_benchmark.py
@@ -20,31 +20,32 @@

# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embed_dim is too large.


def get_time(k, fn, info_dict):

for _ in range(repeat // 2):
fn()
fn()

torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
fn()

torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms


if __name__ == "__main__":
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:

# switch switches dim_in and dim_out
for switch in [False, True]:

# hparams
repeat = 64
batch_size = batch_size
@@ -72,86 +73,35 @@ def get_time(k, fn, info_dict):
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()

info = {
"repeat": repeat,
"batch_size": batch_size,
"dim_out": dim_out,
"dim_in": dim_in,
"wm": wm,
"switch": switch,
}

get_time("standard_fwd", lambda: x.matmul(w.t()), info)
get_time("standard_gw", lambda: g.t().matmul(x), info)
get_time("standard_gx", lambda: g.matmul(w), info)
get_time(
"rowwise_fwd",
lambda: int8_matmul_rowwise_dequantize(
x_int8,
w_int8.t(),
state_x_rowwise,
state_w_columnwise,
None,
),
info,
)
get_time(
"rowwise_bwd",
lambda: int8_matmul_rowwise_dequantize(
g_int8,
wt_int8.t(),
state_x_rowwise,
state_w_rowwise,
None,
),
info,
)
get_time(
"global_fwd",
lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time(
"global_bwd",
lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
info,
)
get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
get_time("w_quantize_global", lambda: quantize_global(w), info)
get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)

time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
time_rowwise = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_colwise_transpose"]
+ info["w_quantize_rowwise"]
+ info["standard_gw"]
+ info["rowwise_fwd"]
+ info["rowwise_bwd"]
)
time_global = (
info["x_quantize_rowwise"]
+ info["g_quantize_rowwise"]
+ info["w_quantize_global"]
+ info["w_quantize_global_transpose"]
+ info["standard_gw"]
+ info["global_fwd"]
+ info["global_bwd"]
)

print("TOTAL STANDARD", time_standard)
print("TOTAL ROWWISE", time_rowwise)
print("TOTAL GLOBAL", time_global)

print("speedup", -100 * (time_global - time_standard) / time_standard)

info["time_standard"] = time_standard
info["time_rowwise"] = time_rowwise
info["time_global"] = time_global
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)

time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)

print('speedup', -100*(time_global - time_standard)/time_standard)

info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global

info_json = json.dumps(info)

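The timing discipline in `get_time` — run half the repeats as warmup, synchronize the device, then average wall-clock time over `repeat` calls — applies to any CUDA kernel, not just these matmuls. A self-contained sketch of the same pattern (hypothetical shapes; assumes a CUDA device is available):

```python
import time

import torch

def time_cuda_ms(fn, repeat=64):
    # Warmup so lazy initialization and autotuning stay out of the timed region.
    for _ in range(repeat // 2):
        fn()
    torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()  # wait for the last kernel to finish
    return (time.time() - start) / repeat * 1000

x = torch.randn(8192, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(16384, 4096, dtype=torch.float16, device="cuda")
print(f"standard_fwd: {time_cuda_ms(lambda: x.matmul(w.t())):.3f} ms")
```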
11 changes: 8 additions & 3 deletions bitsandbytes/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import research, utils
from . import cuda_setup, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
@@ -12,13 +12,18 @@
matmul_cublas,
mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
from .optim import adam

if COMPILED_WITH_CUDA:
from .optim import adam

__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
"optim.optimizer.MockArgs": False,
}

__version__ = "0.44.0.dev"
__version__ = "0.43.0"

PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes"
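The `COMPILED_WITH_CUDA` flag re-exported above gives callers a way to guard CUDA-only features before touching them. A hypothetical usage sketch (the `Adam8bit` choice and the toy model are illustrative, not part of this diff):

```python
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(1024, 1024)

if bnb.COMPILED_WITH_CUDA:
    # 8-bit optimizer states; only meaningful with the CUDA binary present.
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```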