GaLore and fused kernel prototypes (#95)
* initial commit

* add placeholders for cutlass and triton

* update readme

* fix versions

* minor text edits

* clean up

* add triton bnb quant kernel and test

* add notes on triton quant kernel

* refactor code structure

* add galore downproj test

* refactor test utils

* add fused kernel tests

* add fused benchmark

* add dequant kernel

* update docs

* add galore memory test

* add adamw8bit

* fix README

* clean up binaries

* remove notebook, add instructions to README

* remove sample data

* Update galore tests

Skip tests if no GPU

* rename galore docs

* More test edits

Additional conditions for skipping tests to avoid CI failure.
Rename files as they are not actual tests but profiling tools to avoid
triggering CI runs.

* decrease fused matmul parametrizations

* remove long-running tests

* remove tf32 test for now

---------

Co-authored-by: Mark Saroufim <[email protected]>
jeromeku and msaroufim authored Apr 16, 2024
1 parent d76ecc2 commit b0a649e
Showing 28 changed files with 3,918 additions and 0 deletions.
62 changes: 62 additions & 0 deletions benchmarks/bench_galore_fused_kernels.py
@@ -0,0 +1,62 @@
import argparse
import os

import torch
from fused_benchmark_utils import get_benchmark # , make_data


def run(args):
    dtype = getattr(torch, args.dtype)
    allow_tf32 = args.allow_tf32
    fp8_fast_accum = False
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    kernel = args.kernel
    M, N = args.M, args.N
    rank = args.rank

    # exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype)

    benchmark = get_benchmark(M, N, dtype, allow_tf32=allow_tf32)
    save_path = (
        f'benchmark_{M}x{N}_{rank}_{args.dtype}_{"tf32" if allow_tf32 else "no-tf32"}'
    )
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    print(
        f"Running benchmark for {M}x{N}, dtype {args.dtype}, allow_tf32 {allow_tf32}",
        flush=True,
    )
    benchmark.run(show_plots=False, print_data=True, save_path=save_path)
    print(f"Finished benchmark, results saved to {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--kernel",
        choices=["hybrid", "fused", "compiled"],
        default="hybrid",
        type=str,
        help="Kernel to test",
    )

    parser.add_argument(
        "--allow_tf32", action="store_true", help="Allow tf32 for matmuls"
    )
    parser.add_argument("--M", type=int, default=4096, help="Grad (param) shape M")
    parser.add_argument("--N", type=int, default=4096, help="Grad (param) shape N")
    parser.add_argument(
        "--rank", type=int, default=128, help="Rank of GaLore projection"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="Data type of grad (param) tensors",
    )

    args = parser.parse_args()
    run(args)
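The script can also be driven without the CLI; a minimal sketch (assuming a CUDA device and that this file is importable, e.g. run from the benchmarks/ directory):

# Sketch: invoke run() programmatically. argparse.Namespace stands in for the
# parsed CLI arguments; field names mirror the parser definitions above.
from argparse import Namespace

args = Namespace(
    kernel="hybrid",   # one of "hybrid", "fused", "compiled"
    allow_tf32=False,
    M=4096,
    N=4096,
    rank=128,
    dtype="float32",
)
run(args)  # results land in benchmark_4096x4096_128_float32_no-tf32/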
257 changes: 257 additions & 0 deletions benchmarks/fused_benchmark_utils.py
@@ -0,0 +1,257 @@
import torch
import triton
from triton.testing import do_bench

from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher
from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector

torch.manual_seed(0)

BETA1 = 0.9
BETA2 = 0.999
EPS = 1e-8
STEP_SIZE = 1e-4


def make_data(M, N, rank, dtype):
    grad = torch.randn(M, N, device="cuda", dtype=dtype)
    params = torch.randn(M, N, device="cuda", dtype=dtype)

    galore_proj = GaLoreProjector(rank=rank)
    galore_proj.update_orthogonal_matrix(grad)

    if M >= N:
        exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype)
    else:
        exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype)
    exp_avg2 = exp_avg**2

    return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params
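The projection side follows the GaLore convention: a tall gradient (M >= N) is projected on the right, so the optimizer state has shape (M, rank); a wide gradient is projected on the left, giving (rank, N). A quick shape check (a sketch, assuming a CUDA device):

# Sketch: confirm which side make_data projects on.
ea, ea2, g, proj, p = make_data(4096, 1024, 128, torch.float32)
assert ea.shape == (4096, 128)  # tall grad: state is (M, rank)
ea, ea2, g, proj, p = make_data(1024, 4096, 128, torch.float32)
assert ea.shape == (128, 4096)  # wide grad: state is (rank, N)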


def make_copy(*args):
    return [t.detach().clone() for t in args]


def _ref_op(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    beta1=BETA1,
    beta2=BETA2,
    eps=EPS,
    step_size=STEP_SIZE,
    **kwargs,
):
    # Step 1: down-project grad
    M, N = grad.shape
    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    low_rank_grad = a @ b

    # Step 2: update Adam state
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1))
    exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2)
    denom = exp_avg2.sqrt().add_(eps)
    low_rank_norm_grad = exp_avg / denom

    # Step 3: project normalized low-rank grad back to full rank
    if M >= N:
        a, b = low_rank_norm_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_norm_grad
    full_grad_norm = a @ b

    # Finally, update params with the normalized grad
    params.add_(full_grad_norm, alpha=-step_size)

    return exp_avg, exp_avg2, params
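In equations, the step above is a standard Adam moment update applied in the rank-r subspace, with the result projected back to full rank (a summary of the code in the notation of the constants above; note the reference applies no bias correction):

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,\tilde{g}_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,\tilde{g}_t^2

\theta_t = \theta_{t-1} - \eta \, P\!\left(\frac{m_t}{\sqrt{v_t} + \epsilon}\right)

where \tilde{g}_t is the down-projected gradient, P the up-projection by proj_matrix, and \eta the step size.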


def _tt_hybrid(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    store=True,
    step_size=STEP_SIZE,
    fp8_fast_accum=False,
    allow_tf32=False,
):
    M, N = grad.shape
    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    low_rank_grad = a @ b

    exp_avg, exp_avg2, norm_grad = triton_adam_launcher(
        exp_avg, exp_avg2, low_rank_grad, store=store
    )

    if M >= N:
        a, b = low_rank_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_grad
    params = triton_mm_launcher(
        a,
        b,
        epilogue_alpha=-step_size,
        epilogue_source=params,
        allow_tf32=allow_tf32,
        fp8_fast_accum=fp8_fast_accum,
    )
    return exp_avg, exp_avg2, params


def _tt_fused(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    store=True,
    step_size=STEP_SIZE,
    fp8_fast_accum=False,
    allow_tf32=False,
):
    M, N = grad.shape

    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher(
        a,
        b,
        exp_avg=exp_avg,
        exp_avg2=exp_avg2,
        store=store,
        fp8_fast_accum=fp8_fast_accum,
        allow_tf32=allow_tf32,
    )

    if M >= N:
        a, b = low_rank_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_grad
    params = triton_mm_launcher(
        a,
        b,
        epilogue_alpha=-step_size,
        epilogue_source=params,
        allow_tf32=allow_tf32,
        fp8_fast_accum=fp8_fast_accum,
    )
    return exp_avg, exp_avg2, params
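A minimal numerical cross-check between the reference and Triton paths (a sketch, assuming a CUDA device; make_copy keeps the in-place updates from aliasing across runs, and the tolerances are illustrative):

# Sketch: compare _ref_op against _tt_hybrid on identical inputs.
exp_avg, exp_avg2, grad, proj, params = make_data(4096, 4096, 128, torch.float32)
e1, v1, g1, p1, w1 = make_copy(exp_avg, exp_avg2, grad, proj, params)
e2, v2, g2, p2, w2 = make_copy(exp_avg, exp_avg2, grad, proj, params)

ref_out = _ref_op(g1, p1, e1, v1, w1)
tt_out = _tt_hybrid(g2, p2, e2, v2, w2, store=True, allow_tf32=False)

for r, t in zip(ref_out, tt_out):
    torch.testing.assert_close(r, t, rtol=1e-4, atol=1e-4)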

# logging.basicConfig(level=logging.INFO)


def get_kernel(kernel):
    if kernel == "ref":
        op = _ref_op
    elif kernel == "compiled":
        op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune")
    elif kernel == "hybrid":
        op = _tt_hybrid
    elif kernel == "fused":
        op = _tt_fused
    else:
        raise ValueError(f"Unknown kernel {kernel}")

    return lambda *args, **kwargs: op(*args, **kwargs)
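get_kernel normalizes all four implementations behind one callable; a sketch of use, with inputs as produced by make_data above:

# Sketch: uniform dispatch; every op shares _ref_op's positional signature.
op = get_kernel("fused")
exp_avg, exp_avg2, params = op(grad, proj_matrix, exp_avg, exp_avg2, params)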


def get_benchmark(
    M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8]
):
    config = triton.testing.Benchmark(
        x_names=["rank"],  # Argument names to use as an x-axis for the plot
        x_vals=[
            32,
            64,
            128,
            256,
            512,
        ],  # Different possible values for `x_name`
        line_arg="kernel",  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=["torch", "hybrid", "fused", "compiled"],
        # Label name for the lines
        line_names=["torch", "hybrid", "fused", "compiled"],
        # Line styles
        styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="ms",  # Label name for the y-axis
        plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    )

    def benchmark(rank, kernel):
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32

        exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype)

        if kernel == "torch":
            ms, min_ms, max_ms = do_bench(
                lambda: _ref_op(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                ),
                quantiles=quantiles,
            )
        if kernel == "hybrid":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _tt_hybrid(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                    store=True,
                    allow_tf32=allow_tf32,
                    fp8_fast_accum=fp8_fast_accum,
                ),
                quantiles=quantiles,
            )
        if kernel == "fused":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _tt_fused(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                    store=True,
                    allow_tf32=allow_tf32,
                    fp8_fast_accum=fp8_fast_accum,
                ),
                quantiles=quantiles,
            )
        if kernel == "compiled":
            compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune")
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: compiled_op(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                ),
                quantiles=quantiles,
            )

        return ms, max_ms, min_ms

    return triton.testing.perf_report(config)(benchmark)
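get_benchmark returns a triton.testing.perf_report-wrapped function, so it can also be used directly, mirroring what bench_galore_fused_kernels.py does (a sketch, assuming a CUDA device):

# Sketch: build and run the rank-sweep benchmark outside the CLI script.
import os

os.makedirs("./galore_bench_out", exist_ok=True)
bench = get_benchmark(4096, 4096, torch.float32, allow_tf32=False)
bench.run(show_plots=False, print_data=True, save_path="./galore_bench_out")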
4 changes: 4 additions & 0 deletions dev-requirements.txt
@@ -3,3 +3,7 @@ expecttest
parameterized
packaging
transformers
bitsandbytes  # needed for testing triton quant / dequant ops for 8-bit optimizers
matplotlib  # needed for triton benchmarking
pandas  # also for triton benchmarking
transformers  # for galore testing
1 change: 1 addition & 0 deletions setup.py
@@ -32,6 +32,7 @@ def read_requirements(file_path):
"torchao.kernel.configs": ["*.pkl"],
},
install_requires=read_requirements("requirements.txt"),
extras_require={"dev": read_requirements("dev-requirements.txt")},
description="Package for applying ao techniques to GPU models",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
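With this change, the dev dependencies above can be installed through the new extra, e.g. `pip install -e '.[dev]'` from the repository root.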