|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +from triton.testing import do_bench |
| 4 | + |
| 5 | +from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher |
| 6 | +from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher |
| 7 | +from torchao.prototype.galore.kernels.matmul import triton_mm_launcher |
| 8 | +from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector |
| 9 | + |
| 10 | +torch.manual_seed(0) |
| 11 | + |
| 12 | +BETA1 = 0.9 |
| 13 | +BETA2 = 0.999 |
| 14 | +EPS = 1e-8 |
| 15 | +STEP_SIZE = 1e-4 |
| 16 | + |
| 17 | + |
| 18 | +def make_data(M, N, rank, dtype): |
| 19 | + grad = torch.randn(M, N, device="cuda", dtype=dtype) |
| 20 | + params = torch.randn(M, N, device="cuda", dtype=dtype) |
| 21 | + |
| 22 | + galore_proj = GaLoreProjector(rank=rank) |
| 23 | + galore_proj.update_orthogonal_matrix(grad) |
| 24 | + |
| 25 | + if M >= N: |
| 26 | + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) |
| 27 | + else: |
| 28 | + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) |
| 29 | + exp_avg2 = exp_avg**2 |
| 30 | + |
| 31 | + return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params |
| 32 | + |
| 33 | + |
| 34 | +def make_copy(*args): |
| 35 | + return [t.detach().clone() for t in args] |
| 36 | + |
| 37 | + |
| 38 | +def _ref_op( |
| 39 | + grad, |
| 40 | + proj_matrix, |
| 41 | + exp_avg, |
| 42 | + exp_avg2, |
| 43 | + params, |
| 44 | + beta1=BETA1, |
| 45 | + beta2=BETA2, |
| 46 | + eps=EPS, |
| 47 | + step_size=STEP_SIZE, |
| 48 | + **kwargs, |
| 49 | +): |
| 50 | + |
| 51 | + # Step 1: Down proj grad |
| 52 | + M, N = grad.shape |
| 53 | + if M >= N: |
| 54 | + a, b = grad, proj_matrix.t() |
| 55 | + else: |
| 56 | + a, b = proj_matrix.t(), grad |
| 57 | + low_rank_grad = a @ b |
| 58 | + |
| 59 | + # Step 2: update adam state |
| 60 | + exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) |
| 61 | + exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) |
| 62 | + denom = exp_avg2.sqrt().add_(eps) |
| 63 | + low_rank_norm_grad = exp_avg / denom |
| 64 | + |
| 65 | + # Step 3: project normalized low rank grad to full rank |
| 66 | + if M >= N: |
| 67 | + a, b = low_rank_norm_grad, proj_matrix |
| 68 | + else: |
| 69 | + a, b = proj_matrix, low_rank_norm_grad |
| 70 | + full_grad_norm = a @ b |
| 71 | + |
| 72 | + # Finally, update params with updated grad |
| 73 | + params.add_(full_grad_norm, alpha=-step_size) |
| 74 | + |
| 75 | + return exp_avg, exp_avg2, params |
| 76 | + |
| 77 | + |
| 78 | +def _tt_hybrid( |
| 79 | + grad, |
| 80 | + proj_matrix, |
| 81 | + exp_avg, |
| 82 | + exp_avg2, |
| 83 | + params, |
| 84 | + store=True, |
| 85 | + step_size=STEP_SIZE, |
| 86 | + fp8_fast_accum=False, |
| 87 | + allow_tf32=False, |
| 88 | +): |
| 89 | + M, N = grad.shape |
| 90 | + if M >= N: |
| 91 | + a, b = grad, proj_matrix.t() |
| 92 | + else: |
| 93 | + a, b = proj_matrix.t(), grad |
| 94 | + low_rank_grad = a @ b |
| 95 | + |
| 96 | + exp_avg, exp_avg2, norm_grad = triton_adam_launcher( |
| 97 | + exp_avg, exp_avg2, low_rank_grad, store=store |
| 98 | + ) |
| 99 | + |
| 100 | + if M >= N: |
| 101 | + a, b = low_rank_grad, proj_matrix |
| 102 | + else: |
| 103 | + a, b = proj_matrix, low_rank_grad |
| 104 | + params = triton_mm_launcher( |
| 105 | + a, |
| 106 | + b, |
| 107 | + epilogue_alpha=-step_size, |
| 108 | + epilogue_source=params, |
| 109 | + allow_tf32=allow_tf32, |
| 110 | + fp8_fast_accum=fp8_fast_accum, |
| 111 | + ) |
| 112 | + return exp_avg, exp_avg2, params |
| 113 | + |
| 114 | + |
| 115 | +def _tt_fused( |
| 116 | + grad, |
| 117 | + proj_matrix, |
| 118 | + exp_avg, |
| 119 | + exp_avg2, |
| 120 | + params, |
| 121 | + store=True, |
| 122 | + step_size=STEP_SIZE, |
| 123 | + fp8_fast_accum=False, |
| 124 | + allow_tf32=False, |
| 125 | +): |
| 126 | + M, N = grad.shape |
| 127 | + |
| 128 | + if M >= N: |
| 129 | + a, b = grad, proj_matrix.t() |
| 130 | + else: |
| 131 | + a, b = proj_matrix.t(), grad |
| 132 | + exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( |
| 133 | + a, |
| 134 | + b, |
| 135 | + exp_avg=exp_avg, |
| 136 | + exp_avg2=exp_avg2, |
| 137 | + store=store, |
| 138 | + fp8_fast_accum=fp8_fast_accum, |
| 139 | + allow_tf32=allow_tf32, |
| 140 | + ) |
| 141 | + |
| 142 | + if M >= N: |
| 143 | + a, b = low_rank_grad, proj_matrix |
| 144 | + else: |
| 145 | + a, b = proj_matrix, low_rank_grad |
| 146 | + params = triton_mm_launcher( |
| 147 | + a, |
| 148 | + b, |
| 149 | + epilogue_alpha=-step_size, |
| 150 | + epilogue_source=params, |
| 151 | + allow_tf32=allow_tf32, |
| 152 | + fp8_fast_accum=fp8_fast_accum, |
| 153 | + ) |
| 154 | + return exp_avg, exp_avg2, params |
| 155 | + |
| 156 | + # logging.basicConfig(level=logging.INFO) |
| 157 | + |
| 158 | + |
| 159 | +def get_kernel(kernel): |
| 160 | + if kernel == "ref": |
| 161 | + op = _ref_op |
| 162 | + elif kernel == "ref": |
| 163 | + op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") |
| 164 | + elif kernel == "hybrid": |
| 165 | + op = _tt_hybrid |
| 166 | + elif kernel == "fused": |
| 167 | + op = _tt_fused |
| 168 | + else: |
| 169 | + raise ValueError(f"Unknown kernel {kernel}") |
| 170 | + |
| 171 | + return lambda *args, **kwargs: op(*args, **kwargs) |
| 172 | + |
| 173 | + |
| 174 | +def get_benchmark( |
| 175 | + M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8] |
| 176 | +): |
| 177 | + config = triton.testing.Benchmark( |
| 178 | + x_names=["rank"], # Argument names to use as an x-axis for the plot |
| 179 | + x_vals=[ |
| 180 | + 32, |
| 181 | + 64, |
| 182 | + 128, |
| 183 | + 256, |
| 184 | + 512, |
| 185 | + ], # Different possible values for `x_name` |
| 186 | + line_arg="kernel", # Argument name whose value corresponds to a different line in the plot |
| 187 | + # Possible values for `line_arg` |
| 188 | + line_vals=["torch", "hybrid", "fused", "compiled"], |
| 189 | + # Label name for the lines |
| 190 | + line_names=["torch", "hybrid", "fused", "compiled"], |
| 191 | + # Line styles |
| 192 | + styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")], |
| 193 | + ylabel="ms", # Label name for the y-axis |
| 194 | + plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)", # Name for the plot, used also as a file name for saving the plot. |
| 195 | + args={}, |
| 196 | + ) |
| 197 | + |
| 198 | + def benchmark(rank, kernel): |
| 199 | + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 |
| 200 | + |
| 201 | + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) |
| 202 | + |
| 203 | + if kernel == "torch": |
| 204 | + ms, min_ms, max_ms = do_bench( |
| 205 | + lambda: _ref_op( |
| 206 | + grad, |
| 207 | + proj_matrix, |
| 208 | + exp_avg, |
| 209 | + exp_avg2, |
| 210 | + params, |
| 211 | + ), |
| 212 | + quantiles=quantiles, |
| 213 | + ) |
| 214 | + if kernel == "hybrid": |
| 215 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 216 | + lambda: _tt_hybrid( |
| 217 | + grad, |
| 218 | + proj_matrix, |
| 219 | + exp_avg, |
| 220 | + exp_avg2, |
| 221 | + params, |
| 222 | + store=True, |
| 223 | + allow_tf32=allow_tf32, |
| 224 | + fp8_fast_accum=fp8_fast_accum, |
| 225 | + ), |
| 226 | + quantiles=quantiles, |
| 227 | + ) |
| 228 | + if kernel == "fused": |
| 229 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 230 | + lambda: _tt_fused( |
| 231 | + grad, |
| 232 | + proj_matrix, |
| 233 | + exp_avg, |
| 234 | + exp_avg2, |
| 235 | + params, |
| 236 | + store=True, |
| 237 | + allow_tf32=allow_tf32, |
| 238 | + fp8_fast_accum=fp8_fast_accum, |
| 239 | + ), |
| 240 | + quantiles=quantiles, |
| 241 | + ) |
| 242 | + if kernel == "compiled": |
| 243 | + compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") |
| 244 | + ms, min_ms, max_ms = triton.testing.do_bench( |
| 245 | + lambda: compiled_op( |
| 246 | + grad, |
| 247 | + proj_matrix, |
| 248 | + exp_avg, |
| 249 | + exp_avg2, |
| 250 | + params, |
| 251 | + ), |
| 252 | + quantiles=quantiles, |
| 253 | + ) |
| 254 | + |
| 255 | + return ms, max_ms, min_ms |
| 256 | + |
| 257 | + return triton.testing.perf_report(config)(benchmark) |
0 commit comments