
Commit 4b80284

jeromeku and msaroufim authored
GaLore and fused kernel prototypes (pytorch#95)
* initial commit
* add placeholders for cutlass and triton
* update readme
* fix versions
* minor text edits
* clean up
* add triton bnb quant kernel and test
* add notes on triton quant kernel
* refactor code structure
* add galore downproj test
* refactor test utils
* add fused kernel tests
* add fused benchmark
* add dequant kernel
* update docs
* add galore memory test
* add adamw8bit
* fix README
* clean up binaries
* remove notebook, add instructions to README
* remove sample data
* Update galore tests: skip tests if no GPU
* rename galore docs
* More test edits: additional conditions for skipping tests to avoid CI failure; rename files, as they are not actual tests but profiling tools, to avoid triggering CI runs
* decrease fused matmul parametrizations
* remove long-running tests
* remove tf32 test for now

---------

Co-authored-by: Mark Saroufim <[email protected]>
1 parent eff80e9 commit 4b80284

28 files changed (+3918, −0 lines)
+62
@@ -0,0 +1,62 @@
import argparse
import os

import torch
from fused_benchmark_utils import get_benchmark  # , make_data


def run(args):
    dtype = getattr(torch, args.dtype)
    allow_tf32 = args.allow_tf32
    fp8_fast_accum = False
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    kernel = args.kernel
    M, N = args.M, args.N
    rank = args.rank

    # exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype)

    benchmark = get_benchmark(M, N, dtype, allow_tf32=allow_tf32)
    save_path = (
        f'benchmark_{M}x{N}_{rank}_{args.dtype}_{"tf32" if allow_tf32 else "no-tf32"}'
    )
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    print(
        f"Running benchmark for {M}x{N}, dtype {args.dtype}, allow_tf32 {allow_tf32}",
        flush=True,
    )
    benchmark.run(show_plots=False, print_data=True, save_path=save_path)
    print(f"Finished benchmark, results saved to {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--kernel",
        choices=["hybrid", "fused", "compiled"],
        default="hybrid",
        type=str,
        help="Kernel to test",
    )

    parser.add_argument(
        "--allow_tf32", action="store_true", help="Allow tf32 for matmuls"
    )
    parser.add_argument("--M", type=int, default=4096, help="Grad (param) shape M")
    parser.add_argument("--N", type=int, default=4096, help="Grad (param) shape N")
    parser.add_argument(
        "--rank", type=int, default=128, help="Rank of GaLore projection"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        default="float32",
        help="Data type of grad (param) tensors",
    )

    args = parser.parse_args()
    run(args)
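The script above is driven entirely from the command line via argparse. As a usage sketch (the script's file name is cut off in this diff, so benchmark_fused.py is a stand-in; it must sit next to fused_benchmark_utils.py for the import to resolve):

    python benchmark_fused.py --M 4096 --N 4096 --rank 128 --dtype bfloat16 --allow_tf32

Results are printed and saved under a directory named after the shape, rank, dtype, and tf32 setting, e.g. benchmark_4096x4096_128_bfloat16_tf32.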

benchmarks/fused_benchmark_utils.py

+257
@@ -0,0 +1,257 @@
import torch
import triton
from triton.testing import do_bench

from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher
from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector

torch.manual_seed(0)

BETA1 = 0.9
BETA2 = 0.999
EPS = 1e-8
STEP_SIZE = 1e-4


def make_data(M, N, rank, dtype):
    grad = torch.randn(M, N, device="cuda", dtype=dtype)
    params = torch.randn(M, N, device="cuda", dtype=dtype)

    galore_proj = GaLoreProjector(rank=rank)
    galore_proj.update_orthogonal_matrix(grad)

    # The projection acts on the smaller dimension, so the Adam state is
    # low-rank: (M, rank) when M >= N, else (rank, N).
    if M >= N:
        exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype)
    else:
        exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype)
    exp_avg2 = exp_avg**2

    return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params


def make_copy(*args):
    return [t.detach().clone() for t in args]


def _ref_op(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    beta1=BETA1,
    beta2=BETA2,
    eps=EPS,
    step_size=STEP_SIZE,
    **kwargs,
):

    # Step 1: Down proj grad
    M, N = grad.shape
    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    low_rank_grad = a @ b

    # Step 2: update adam state
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1))
    exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2)
    denom = exp_avg2.sqrt().add_(eps)
    low_rank_norm_grad = exp_avg / denom

    # Step 3: project normalized low rank grad to full rank
    if M >= N:
        a, b = low_rank_norm_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_norm_grad
    full_grad_norm = a @ b

    # Finally, update params with updated grad
    params.add_(full_grad_norm, alpha=-step_size)

    return exp_avg, exp_avg2, params


def _tt_hybrid(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    store=True,
    step_size=STEP_SIZE,
    fp8_fast_accum=False,
    allow_tf32=False,
):
    # Hybrid path: eager matmul for the down-projection, a Triton kernel for
    # the Adam state update, then a Triton matmul with a fused epilogue for
    # the up-projection and parameter update.
    M, N = grad.shape
    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    low_rank_grad = a @ b

    exp_avg, exp_avg2, norm_grad = triton_adam_launcher(
        exp_avg, exp_avg2, low_rank_grad, store=store
    )

    if M >= N:
        a, b = low_rank_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_grad
    params = triton_mm_launcher(
        a,
        b,
        epilogue_alpha=-step_size,
        epilogue_source=params,
        allow_tf32=allow_tf32,
        fp8_fast_accum=fp8_fast_accum,
    )
    return exp_avg, exp_avg2, params


def _tt_fused(
    grad,
    proj_matrix,
    exp_avg,
    exp_avg2,
    params,
    store=True,
    step_size=STEP_SIZE,
    fp8_fast_accum=False,
    allow_tf32=False,
):
    # Fused path: the down-projection matmul and the Adam state update run in
    # a single fused Triton kernel; the up-projection is the same epilogue
    # matmul as in the hybrid path.
    M, N = grad.shape

    if M >= N:
        a, b = grad, proj_matrix.t()
    else:
        a, b = proj_matrix.t(), grad
    exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher(
        a,
        b,
        exp_avg=exp_avg,
        exp_avg2=exp_avg2,
        store=store,
        fp8_fast_accum=fp8_fast_accum,
        allow_tf32=allow_tf32,
    )

    if M >= N:
        a, b = low_rank_grad, proj_matrix
    else:
        a, b = proj_matrix, low_rank_grad
    params = triton_mm_launcher(
        a,
        b,
        epilogue_alpha=-step_size,
        epilogue_source=params,
        allow_tf32=allow_tf32,
        fp8_fast_accum=fp8_fast_accum,
    )
    return exp_avg, exp_avg2, params


# logging.basicConfig(level=logging.INFO)

def get_kernel(kernel):
    if kernel == "ref":
        op = _ref_op
    elif kernel == "compiled":
        op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune")
    elif kernel == "hybrid":
        op = _tt_hybrid
    elif kernel == "fused":
        op = _tt_fused
    else:
        raise ValueError(f"Unknown kernel {kernel}")

    return lambda *args, **kwargs: op(*args, **kwargs)


def get_benchmark(
    M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8]
):
    config = triton.testing.Benchmark(
        x_names=["rank"],  # Argument names to use as an x-axis for the plot
        x_vals=[
            32,
            64,
            128,
            256,
            512,
        ],  # Different possible values for `x_name`
        line_arg="kernel",  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=["torch", "hybrid", "fused", "compiled"],
        # Label name for the lines
        line_names=["torch", "hybrid", "fused", "compiled"],
        # Line styles
        styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="ms",  # Label name for the y-axis
        plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    )

    def benchmark(rank, kernel):
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32

        exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype)

        if kernel == "torch":
            ms, min_ms, max_ms = do_bench(
                lambda: _ref_op(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                ),
                quantiles=quantiles,
            )
        if kernel == "hybrid":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _tt_hybrid(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                    store=True,
                    allow_tf32=allow_tf32,
                    fp8_fast_accum=fp8_fast_accum,
                ),
                quantiles=quantiles,
            )
        if kernel == "fused":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _tt_fused(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                    store=True,
                    allow_tf32=allow_tf32,
                    fp8_fast_accum=fp8_fast_accum,
                ),
                quantiles=quantiles,
            )
        if kernel == "compiled":
            compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune")
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: compiled_op(
                    grad,
                    proj_matrix,
                    exp_avg,
                    exp_avg2,
                    params,
                ),
                quantiles=quantiles,
            )

        # do_bench returns one time per requested quantile: median, 20th, 80th
        return ms, max_ms, min_ms

    return triton.testing.perf_report(config)(benchmark)
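Not part of this commit, but a minimal correctness sketch for the utilities above, assuming a CUDA device and a working Triton install. Since _ref_op updates exp_avg, exp_avg2, and params in place, each path gets its own copies via make_copy; the tolerances are illustrative, not prescribed by the source:

    exp_avg, exp_avg2, grad, proj_matrix, params = make_data(4096, 4096, 128, torch.float32)
    exp_avg_c, exp_avg2_c, params_c = make_copy(exp_avg, exp_avg2, params)

    # Eager reference step vs. hybrid Triton step on identical inputs
    ref_avg, ref_avg2, ref_params = _ref_op(grad, proj_matrix, exp_avg, exp_avg2, params)
    tt_avg, tt_avg2, tt_params = _tt_hybrid(
        grad, proj_matrix, exp_avg_c, exp_avg2_c, params_c, store=True
    )

    # Illustrative tolerances; small accumulation differences are expected
    torch.testing.assert_close(ref_params, tt_params, rtol=1e-4, atol=1e-4)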

dev-requirements.txt

+4
@@ -3,3 +3,7 @@ expecttest
 parameterized
 packaging
 transformers
+bitsandbytes  # needed for testing triton quant / dequant ops for 8-bit optimizers
+matplotlib  # needed for triton benchmarking
+pandas  # also for triton benchmarking
+transformers  # for galore testing

setup.py

+1
@@ -32,6 +32,7 @@ def read_requirements(file_path):
         "torchao.kernel.configs": ["*.pkl"],
     },
     install_requires=read_requirements("requirements.txt"),
+    extras_require={"dev": read_requirements("dev-requirements.txt")},
     description="Package for applying ao techniques to GPU models",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
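With the new dev extra, the development dependencies from dev-requirements.txt can be installed together with the package in one step, e.g. from a source checkout:

    pip install -e ".[dev]"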
