diff --git a/benchmarks/bench_galore_fused_kernels.py b/benchmarks/bench_galore_fused_kernels.py new file mode 100644 index 0000000000..c05f31e921 --- /dev/null +++ b/benchmarks/bench_galore_fused_kernels.py @@ -0,0 +1,62 @@ +import argparse +import os + +import torch +from fused_benchmark_utils import get_benchmark # , make_data + + +def run(args): + dtype = getattr(torch, args.dtype) + allow_tf32 = args.allow_tf32 + fp8_fast_accum = False + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + kernel = args.kernel + M, N = args.M, args.N + rank = args.rank + + # exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + + benchmark = get_benchmark(M, N, dtype, allow_tf32=allow_tf32) + save_path = ( + f'benchmark_{M}x{N}_{rank}_{args.dtype}_{"tf32" if allow_tf32 else "no-tf32"}' + ) + if not os.path.exists(save_path): + os.makedirs(save_path) + print( + f"Running benchmark for {M}x{N}, dtype {args.dtype}, allow_tf32 {allow_tf32}", + flush=True, + ) + benchmark.run(show_plots=False, print_data=True, save_path=save_path) + print(f"Finished benchmark, results saved to {save_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--kernel", + choices=["hybrid", "fused", "compiled"], + default="hybrid", + type=str, + help="Kernel to test", + ) + + parser.add_argument( + "--allow_tf32", action="store_true", help="Allow tf32 for matmuls" + ) + parser.add_argument("--M", type=int, default=4096, help="Grad (param) shape M") + parser.add_argument("--N", type=int, default=4096, help="Grad (param) shape N") + parser.add_argument( + "--rank", type=int, default=128, help="Rank of GaLore projection" + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + default="float32", + help="Data type of grad (param) tensors", + ) + + args = parser.parse_args() + run(args) diff --git a/benchmarks/fused_benchmark_utils.py b/benchmarks/fused_benchmark_utils.py new file mode 100644 index 0000000000..5456154c30 --- /dev/null +++ b/benchmarks/fused_benchmark_utils.py @@ -0,0 +1,257 @@ +import torch +import triton +from triton.testing import do_bench + +from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher +from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher +from torchao.prototype.galore.kernels.matmul import triton_mm_launcher +from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector + +torch.manual_seed(0) + +BETA1 = 0.9 +BETA2 = 0.999 +EPS = 1e-8 +STEP_SIZE = 1e-4 + + +def make_data(M, N, rank, dtype): + grad = torch.randn(M, N, device="cuda", dtype=dtype) + params = torch.randn(M, N, device="cuda", dtype=dtype) + + galore_proj = GaLoreProjector(rank=rank) + galore_proj.update_orthogonal_matrix(grad) + + if M >= N: + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) + else: + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) + exp_avg2 = exp_avg**2 + + return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params + + +def make_copy(*args): + return [t.detach().clone() for t in args] + + +def _ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + beta1=BETA1, + beta2=BETA2, + eps=EPS, + step_size=STEP_SIZE, + **kwargs, +): + + # Step 1: Down proj grad + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + # Step 2: update adam state + exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) + exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) + denom = exp_avg2.sqrt().add_(eps) + low_rank_norm_grad = exp_avg / denom + + # Step 3: project normalized low rank grad to full rank + if M >= N: + a, b = low_rank_norm_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_norm_grad + full_grad_norm = a @ b + + # Finally, update params with updated grad + params.add_(full_grad_norm, alpha=-step_size) + + return exp_avg, exp_avg2, params + + +def _tt_hybrid( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + exp_avg, exp_avg2, norm_grad = triton_adam_launcher( + exp_avg, exp_avg2, low_rank_grad, store=store + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + +def _tt_fused( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( + a, + b, + exp_avg=exp_avg, + exp_avg2=exp_avg2, + store=store, + fp8_fast_accum=fp8_fast_accum, + allow_tf32=allow_tf32, + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + # logging.basicConfig(level=logging.INFO) + + +def get_kernel(kernel): + if kernel == "ref": + op = _ref_op + elif kernel == "ref": + op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") + elif kernel == "hybrid": + op = _tt_hybrid + elif kernel == "fused": + op = _tt_fused + else: + raise ValueError(f"Unknown kernel {kernel}") + + return lambda *args, **kwargs: op(*args, **kwargs) + + +def get_benchmark( + M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8] +): + config = triton.testing.Benchmark( + x_names=["rank"], # Argument names to use as an x-axis for the plot + x_vals=[ + 32, + 64, + 128, + 256, + 512, + ], # Different possible values for `x_name` + line_arg="kernel", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=["torch", "hybrid", "fused", "compiled"], + # Label name for the lines + line_names=["torch", "hybrid", "fused", "compiled"], + # Line styles + styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="ms", # Label name for the y-axis + plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)", # Name for the plot, used also as a file name for saving the plot. + args={}, + ) + + def benchmark(rank, kernel): + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + + if kernel == "torch": + ms, min_ms, max_ms = do_bench( + lambda: _ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ), + quantiles=quantiles, + ) + if kernel == "hybrid": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: _tt_hybrid( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ), + quantiles=quantiles, + ) + if kernel == "fused": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: _tt_fused( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ), + quantiles=quantiles, + ) + if kernel == "compiled": + compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: compiled_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ), + quantiles=quantiles, + ) + + return ms, max_ms, min_ms + + return triton.testing.perf_report(config)(benchmark) diff --git a/dev-requirements.txt b/dev-requirements.txt index 3d70d57ab9..74f75e9093 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -3,3 +3,7 @@ expecttest parameterized packaging transformers +bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers +matplotlib # needed for triton benchmarking +pandas # also for triton benchmarking +transformers #for galore testing \ No newline at end of file diff --git a/setup.py b/setup.py index 378339fa16..27c1f260e8 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ def read_requirements(file_path): "torchao.kernel.configs": ["*.pkl"], }, install_requires=read_requirements("requirements.txt"), + extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/test/galore/README.md b/test/galore/README.md new file mode 100644 index 0000000000..fc479267d8 --- /dev/null +++ b/test/galore/README.md @@ -0,0 +1,170 @@ +### GaLore Memory Profiler + +Tests memory usage of `GaLore` optimizers. + +Uses `torch.profiler` under the hood with additional options for `nsys`, [`torch.cuda.memory`](https://pytorch.org/docs/stable/torch_cuda_memory.html) analyses. + +Runs an untrained Llama model with configs for various model sizes (see `configs`) from the original GaLore [repo](https://github.com/jiaweizzhao/GaLore/tree/master/configs) on a sample batch of data for a configurable set of iterations. + +The results of the profiler are saved and can be analyzed using the provided notebook. + +#### Examples + +Run memory profiler with `torch.optim.AdamW` + +``` +python galore_mem_prof.py -t --optimizer=adamw +``` + +Run profiler with `GaLoreAdamW` reference implementation with rank 128 + +``` +python galore_mem_prof.py -t --optimizer=galore_adamw --rank=128 +``` + +More options + +``` +python profile_memory_usage.py --help + +usage: profile_memory_usage.py [-h] [-t] [-m] [-ns] [--optimizer {adamw,galore_adamw}] [--rank RANK] [--update_proj_gap UPDATE_PROJ_GAP] + [--galore_scale GALORE_SCALE] [--wait_steps WAIT_STEPS] [--warmup_steps WARMUP_STEPS] [--profiler_steps PROFILER_STEPS] + [--max_steps MAX_STEPS] [--model_config MODEL_CONFIG] [--data_path DATA_PATH] [--output_dir OUTPUT_DIR] [-lr LEARNING_RATE] + [--weight_decay WEIGHT_DECAY] [--seed SEED] + +options: + -h, --help show this help message and exit + -t, --torch_profiler Enable torch profiler (default: False) + -m, --torch_memory_snapshot + Enable torch memory snapshot (default: False) + -ns, --nsys_profiler Enable nsys profiling context managerSurrounds training loop with cudaProfilerApi.{Start,Stop} (default: False) + --optimizer {adamw,galore_adamw} + Which optimizer to use (default: adamw) + --rank RANK + --update_proj_gap UPDATE_PROJ_GAP + --galore_scale GALORE_SCALE + --wait_steps WAIT_STEPS + Number of steps to run before starting torch profiler (default: 0) + --warmup_steps WARMUP_STEPS + Number of warmup steps for torch profiler (default: 0) + --profiler_steps PROFILER_STEPS + Number of active steps for torch profiler (default: 5) + --max_steps MAX_STEPS + Max number of train steps to run.Total train steps will be min of `max_steps` and the sum of torch profiler steps (`wait_steps` + + `warmup_steps` + `profiler_steps`). (default: 100) + --model_config MODEL_CONFIG + Path to Llama config file see `https://github.com/jiaweizzhao/GaLore/tree/master/configs` (default: ./configs/llama_100m.json) + --data_path DATA_PATH + Path to sample batch (default: ./data/sample_batch.pt) + --output_dir OUTPUT_DIR + Directory for profiler outputs (default: profiler_out) + -lr LEARNING_RATE, --learning_rate LEARNING_RATE + Learning rate (default: 0.001) + --weight_decay WEIGHT_DECAY + Weight decay for AdamW (default: 0.01) + --seed SEED Random seed for torch (default: 0) +``` + +#### Analysis + +After running the `profile_memory_usage`, the output directory (defaults to `profiler_out`) will have three types of files: + +- `*.{json,html} - these are the memory trace exports of `torch.profiler` + - the `html` contains the memory timeline plot + - the `json` file contains the raw data for this plot, which can be analyzed to extract summary stats. + - `galore_memory_analysis.py` along with `galore_memory_analysis_utils.py` demonstrate such analysis. +- `*.json.gz` - these are the complete `torch.profiler` traces which can be viewed using `perfetto`. + +#### Preliminary Observations + +- Memory Usage over Time + + - We can see a long delay between the first backwards step for `GaLoreAdamW` due to the calculation of the projection matrix (calls `torch.linalg.svd` on the `grad`). + - To visualize, paste the following into a jupyter notebook (replacing the filenames with the those after running the profiler script): + + ```python + adamW_html_trace = "./profiler_out/adamw_04-09-23.html" + adamW8bit_html_trace = "./profiler_out/adamw8bit_04-11-01.html" + galore_adamw_128_html_trace = "./profiler_out/galore_adamw-128-1.0-50_04-09-23.html" + galore_adamw8bit_128_html_trace = "./profiler_out/galore_adamw8bit-128-1.0-50_04-11-01.html" + + plot_memory_timeline(adamW_html_trace) + plot_memory_timeline(adamW8bit_html_trace) + plot_memory_timeline(galore_adamw_128_html_trace) + plot_memory_timeline(galore_adamw8bit_128_html_trace) + ``` + +- Memory Usage Stats + + - Summary stats for memory usage by type as well as total across all types can be viewed by running the following in jupyter notebook, again replacing the respective filepaths: + + ```python + adamW_trace = "./profiler_out/adamw_04-11-21-memory-timeline.json" + adamW8bit_trace = "./profiler_out/adamw8bit_04-11-21-memory-timeline.json" + galore_adamW_trace_128 = "./profiler_out/galore_adamw-128-1.0-50_04-11-21-memory-timeline.json" + galore_adamW8bit_trace_128 = "./profiler_out/galore_adamw8bit-128-1.0-50_04-11-21-memory-timeline.json" + + adamW_df = create_mem_df(adamW_trace, units="MB") + adamW8bit_df = create_mem_df(adamW8bit_trace, units="MB") + galore_adamW_df_128 = create_mem_df(galore_adamW_trace_128, units="MB") + galore_adamW8bit_df_128 = create_mem_df(galore_adamW8bit_trace_128, units="MB") + + show_memory_stats(adamW_df) + show_memory_stats(adamW8bit_df) + show_memory_stats(galore_adamW_df_128) + show_memory_stats(galore_adamW8bit_df_128) + ``` + + The following are results from sample runs of `Llama1B` model config with the following optimizers (all units in MB): + +- torch.optim.AdamW + + | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | + | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | + | mean | 5,108.2 | 8,330.3 | 0.0 | 0.6 | 2,249.5 | 2,113.8 | 19.0 | 197.3 | 18,018.8 | + | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | + | median | 5,108.2 | 10,216.4 | 0.0 | 0.0 | 2,151.1 | 1,930.1 | 10.0 | 16.3 | 20,306.5 | + | max | 5,108.3 | 10,216.4 | 0.3 | 20.0 | 5,946.4 | 5,108.2 | 312.2 | 5,124.4 | 25,557.3 | + +- GaLoreAdamW reference, rank 128 + + | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | + | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | + | mean | 7,298.0 | 1,348.4 | 0.0 | 0.7 | 1,455.6 | 3,183.6 | 12.2 | 31.3 | 13,330.0 | + | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | + | median | 7,796.2 | 1,576.7 | 0.0 | 0.0 | 545.4 | 3,898.2 | 0.0 | 26.2 | 14,422.8 | + | max | 8,047.2 | 1,576.7 | 0.3 | 42.7 | 5,960.0 | 5,108.2 | 312.2 | 518.2 | 15,349.2 | + +- bitsandbytes AdamW8bit + + | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | + | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | + | mean | 5,108.2 | 2,047.4 | 0.0 | 0.7 | 2,390.0 | 1,925.2 | 20.1 | 20.3 | 11,511.9 | + | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | + | median | 5,108.2 | 2,560.4 | 0.0 | 0.0 | 2,351.0 | 1,738.1 | 10.0 | 16.3 | 12,621.3 | + | max | 5,108.3 | 2,560.4 | 0.3 | 20.0 | 5,946.4 | 5,108.2 | 312.2 | 46.9 | 13,631.3 | + +- GaLore AdamW8bit + + | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | + | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | + | mean | 4,971.0 | 334.7 | 0.1 | 0.8 | 1,644.0 | 2,130.9 | 13.8 | 2,360.3 | 11,455.6 | + | min | 500.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | + | median | 5,108.2 | 395.6 | 0.0 | 0.0 | 1,076.4 | 2,106.1 | 0.0 | 2,704.3 | 11,673.8 | + | max | 5,153.5 | 395.6 | 85.4 | 42.7 | 5,947.8 | 5,109.2 | 312.2 | 7,685.4 | 14,155.9 | + +- The `optimizer state` is indeed smaller for the `GaLoreAdamW` optimizer. +- Interestingly, the `Parameter` sizes balloons in the `GaLore` optimizer, likely due to extra data copies. Admittedly, the implementation is only a reference (per original repo) and leaves much room for optimization. +- The memory usage is in terms of memory allocated, which we can confirm by printing the max cuda memory allocated vs reserved (which the profiler script prints automatically). +- The `Total` column shows the allocation stats across all categories across all sampled timepoints. (Should not be interpreted as the row-wise sums). + +**NOTE**: The `json` output of the torch profiler memory trace is unlabeled. However, we can infer -- and confirm -- the labels by comparing the plots of the parsed dataframe with that of the direct `html` export of the profiler. + +- For example, after creating the dataframes per above, the following will plot the raw data, which should roughly reproduce the direct `html` export from `torch.profiler`, albeit with different timescale: + +```python +_ = adamW_df.plot(kind="area", stacked=True, ylabel="Memory (MB)" ) +_ = adamW8bit_df.plot(kind="area", stacked=True, ylabel="Memory (MB)" ) +_ = galore_adamW_df_128.plot(kind="area", stacked=True, ylabel="Memory (MB)" ) +_ = galore_adamW8bit_df_128.plot(kind="area", stacked=True, ylabel="Memory (MB)" ) +``` diff --git a/test/galore/memory_analysis_utils.py b/test/galore/memory_analysis_utils.py new file mode 100644 index 0000000000..6e464e8766 --- /dev/null +++ b/test/galore/memory_analysis_utils.py @@ -0,0 +1,73 @@ +from functools import partial + +import pandas as pd +from IPython.display import HTML + + +def plot_memory_timeline(trace_file): + """Plots html output of torch profiler memory trace + For use within Jupyter Notebook only! + See https://pytorch.org/docs/main/profiler.html#torch.profiler._KinetoProfile.export_memory_timeline + + Args: + trace_file: path to html export of torch profiler memory timeline + """ + with open(trace_file) as f: + return HTML(f.read()) + + +# These are the (unlabeled) columns in the json export of a torch profiler memory timeline trace +COL_NAMES = [ + "Parameter", + "Optimizer_State", + "Input", + "Temporary", + "Activation", + "Gradient", + "Autograd_Detail", + "Unknown", +] + + +def create_mem_df(mem_trace, units="GB"): + """Create dataframe from json export of torch profiler CUDA memory timeline trace + Columns per COL_NAMES, in units of MB + These are the (unlabeled) columns in the json export of a torch profiler memory timeline trace but can be + inferred (and confirmed) by comparing the plots of the json export with the plots of the html export + + E.g., df.plot(kind="area", stacked=True, ylabel="MB") + + See https://pytorch.org/docs/main/profiler.html#torch.profiler._KinetoProfile.export_memory_timeline + Args: + mem_trace: path to json export of torch profiler memory timeline + units: "MB" or "GB" + """ + df = pd.read_json(mem_trace).T.explode(0) + + def _convert_to_units(df, col): + return df[col] / 1024 ** (3 if units == "GB" else 2) + + convert_cols_to_MB = {col: partial(_convert_to_units, col=col) for col in COL_NAMES} + + df = pd.DataFrame( + [l[1:] for l in df.iloc[:, 1].to_list()], columns=COL_NAMES + ).assign(**convert_cols_to_MB) + df["Total"] = df.sum(axis=1) + return df + + +def show_memory_stats(df, stats=["mean", "min", "50%", "max"]): + """Show summary statistics for torch profiler CUDA memory timeline trace + Args: + df: dataframe created by create_mem_df + stats: list of statistics to show. Valid stats are "mean", "min", "25%", "50%", "75%", "max" + + """ + mem_sum = ( + df.describe() + .loc[stats] + .rename(index={"50%": "median"}) + .style.format(precision=1, thousands=",") + ) + + return mem_sum diff --git a/test/galore/model_configs.py b/test/galore/model_configs.py new file mode 100644 index 0000000000..358f5a6868 --- /dev/null +++ b/test/galore/model_configs.py @@ -0,0 +1,176 @@ +# LLAMA100M = { +# "architectures": ["LLaMAForCausalLM"], +# "attention_bias": False, +# "attention_dropout": 0.0, +# "bos_token_id": 0, +# "eos_token_id": 1, +# "hidden_act": "silu", +# "hidden_size": 640, +# "initializer_range": 0.02, +# "intermediate_size": 1708, +# "max_position_embeddings": 2048, +# "max_sequence_length": 1024, +# "model_type": "llama", +# "num_attention_heads": 10, +# "num_hidden_layers": 12, +# "num_key_value_heads": 10, +# "pad_token_id": -1, +# "pretraining_tp": 1, +# "rms_norm_eps": 1e-06, +# "rope_scaling": None, +# "rope_theta": 10000.0, +# "tie_word_embeddings": False, +# "transformers_version": "4.39.3", +# "use_cache": True, +# "vocab_size": 32100, +# } +LLAMA1B = { + "vocab_size": 32000, + "max_position_embeddings": 2048, + "hidden_size": 2048, + "intermediate_size": 5461, + "num_hidden_layers": 24, + "num_attention_heads": 32, + "num_key_value_heads": 32, + "hidden_act": "silu", + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "return_dict": True, + "output_hidden_states": False, + "output_attentions": False, + "torchscript": False, + "torch_dtype": None, + "use_bfloat16": False, + "tf_legacy_loss": False, + "pruned_heads": {}, + "tie_word_embeddings": False, + "chunk_size_feed_forward": 0, + "is_encoder_decoder": False, + "is_decoder": False, + "cross_attention_hidden_size": None, + "add_cross_attention": False, + "tie_encoder_decoder": False, + "max_length": 20, + "min_length": 0, + "do_sample": False, + "early_stopping": False, + "num_beams": 1, + "num_beam_groups": 1, + "diversity_penalty": 0.0, + "temperature": 1.0, + "top_k": 50, + "top_p": 1.0, + "typical_p": 1.0, + "repetition_penalty": 1.0, + "length_penalty": 1.0, + "no_repeat_ngram_size": 0, + "encoder_no_repeat_ngram_size": 0, + "bad_words_ids": None, + "num_return_sequences": 1, + "output_scores": False, + "return_dict_in_generate": False, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "remove_invalid_values": False, + "exponential_decay_length_penalty": None, + "suppress_tokens": None, + "begin_suppress_tokens": None, + "architectures": ["LLaMAForCausalLM"], + "finetuning_task": None, + "id2label": {0: "LABEL_0", 1: "LABEL_1"}, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "tokenizer_class": None, + "prefix": None, + "bos_token_id": 0, + "pad_token_id": -1, + "eos_token_id": 1, + "sep_token_id": None, + "decoder_start_token_id": None, + "task_specific_params": None, + "problem_type": None, + "_name_or_path": "./configs/llama_1b.json", + "transformers_version": "4.39.3", + "max_sequence_length": 1024, + "model_type": "llama", +} +LLAMA100M = { + "vocab_size": 32100, + "max_position_embeddings": 2048, + "hidden_size": 640, + "intermediate_size": 1708, + "num_hidden_layers": 12, + "num_attention_heads": 10, + "num_key_value_heads": 10, + "hidden_act": "silu", + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "use_cache": True, + "rope_theta": 10000.0, + "rope_scaling": None, + "attention_bias": False, + "attention_dropout": 0.0, + "return_dict": True, + "output_hidden_states": False, + "output_attentions": False, + "torchscript": False, + "torch_dtype": None, + "use_bfloat16": False, + "tf_legacy_loss": False, + "pruned_heads": {}, + "tie_word_embeddings": False, + "chunk_size_feed_forward": 0, + "is_encoder_decoder": False, + "is_decoder": False, + "cross_attention_hidden_size": None, + "add_cross_attention": False, + "tie_encoder_decoder": False, + "max_length": 20, + "min_length": 0, + "do_sample": False, + "early_stopping": False, + "num_beams": 1, + "num_beam_groups": 1, + "diversity_penalty": 0.0, + "temperature": 1.0, + "top_k": 50, + "top_p": 1.0, + "typical_p": 1.0, + "repetition_penalty": 1.0, + "length_penalty": 1.0, + "no_repeat_ngram_size": 0, + "encoder_no_repeat_ngram_size": 0, + "bad_words_ids": None, + "num_return_sequences": 1, + "output_scores": False, + "return_dict_in_generate": False, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "remove_invalid_values": False, + "exponential_decay_length_penalty": None, + "suppress_tokens": None, + "begin_suppress_tokens": None, + "architectures": ["LLaMAForCausalLM"], + "finetuning_task": None, + "id2label": {0: "LABEL_0", 1: "LABEL_1"}, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "tokenizer_class": None, + "prefix": None, + "bos_token_id": 0, + "pad_token_id": -1, + "eos_token_id": 1, + "sep_token_id": None, + "decoder_start_token_id": None, + "task_specific_params": None, + "problem_type": None, + "_name_or_path": "./configs/llama_100m.json", + "transformers_version": "4.39.3", + "max_sequence_length": 1024, + "model_type": "llama", +} diff --git a/test/galore/profile_memory_usage.py b/test/galore/profile_memory_usage.py new file mode 100644 index 0000000000..fce4e18a87 --- /dev/null +++ b/test/galore/profile_memory_usage.py @@ -0,0 +1,292 @@ +import argparse +import contextlib +import logging +import os + +import model_configs +import profiling_utils +import torch +import torch.nn as nn +import torch.utils.data +from bitsandbytes.optim import AdamW8bit +from torch.profiler import record_function +from transformers import LlamaConfig, LlamaForCausalLM + +from torchao.prototype.galore.optim.galore_torch import AdamW as GaLoreAdamW +from torchao.prototype.galore.optim.galore_torch import AdamW8bit as GaLoreAdamW8bit + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def setup_galore(model, lr, weight_decay, rank, galore_scale, update_proj_gap): + galore_params = [] + target_modules_list = ["attn", "mlp"] + for module_name, module in model.named_modules(): + if not isinstance(module, nn.Linear): + continue + + if not any(target_key in module_name for target_key in target_modules_list): + continue + + logger.debug("Enabling GaLore for weights in module: ", module_name) + galore_params.append(module.weight) + id_galore_params = [id(p) for p in galore_params] + # make parameters without "rank" to another group + regular_params = [p for p in model.parameters() if id(p) not in id_galore_params] + # then call galore_adamw + + total_galore_params = sum(p.numel() for p in galore_params) + total_regular_params = sum(p.numel() for p in regular_params) + total_params = sum(p.numel() for p in model.parameters()) + assert total_galore_params + total_regular_params == total_params + + print( + f"Total params: {total_params} = GaLore params: {total_galore_params} + Regular params: {total_regular_params}" + ) + param_groups = [ + {"params": regular_params}, + { + "params": galore_params, + "rank": rank, + "update_proj_gap": update_proj_gap, + "scale": galore_scale, + "proj_type": "std", + }, + ] + if "adamw" in args.optimizer: + if "8bit" in args.optimizer: + optimizer = GaLoreAdamW8bit(param_groups, lr=lr, weight_decay=weight_decay) + else: + optimizer = GaLoreAdamW(param_groups, lr=lr, weight_decay=weight_decay) + else: + raise ValueError(f"Unknown optimizer: {args.optimizer}") + return optimizer + + +def train_step(model, batch, labels, optimizer, profiler=None): + with record_function("MODEL_FORWARD"): + loss = model(**batch, labels=labels).loss + + with record_function("MODEL_BACKWARD"): + loss.backward() + + with record_function("OPTIMIZER_STEP"): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + if profiler: + profiler.step() + + +def run(args, file_prefix): + torch.manual_seed(args.seed) + + # Initialize model from config dict + model_config = LlamaConfig() + try: + model_config_dict = getattr(model_configs, args.model_config.upper()) + except: + raise ValueError(f"Model config {args.model_config} not found") + model_config.update(model_config_dict) + model = LlamaForCausalLM(model_config).to("cuda") + + # Load sample batch + input_ids = torch.randint( + 0, + model_config.vocab_size, + size=(args.batch_size, args.max_seq_len), + dtype=torch.int64, + device="cuda", + ) + attention_mask = torch.ones_like(input_ids) + batch = dict(input_ids=input_ids, attention_mask=attention_mask) + labels = batch["input_ids"].clone() + + n_total_params = sum(p.numel() for p in model.parameters()) + trainable_params = [p for p in model.parameters() if p.requires_grad] + print( + f"Trainable params: {sum(p.numel() for p in trainable_params)} / {n_total_params}" + ) + + if args.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay + ) + + elif "galore" in args.optimizer.lower(): + optimizer = setup_galore( + model, + args.learning_rate, + args.weight_decay, + rank=args.rank, + galore_scale=args.galore_scale, + update_proj_gap=args.update_proj_gap, + ) + elif args.optimizer.lower() == "adamw8bit": + optimizer = AdamW8bit( + trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay + ) + else: + raise "Unsupported optimizer" + + if args.torch_profiler: + prof_ctx = profiling_utils.get_torch_profiler( + name=file_prefix, + output_dir=args.output_dir, + wait_steps=args.wait_steps, + warmup_steps=args.warmup_steps, + active_steps=args.profiler_steps, + ) + elif args.nsys_profiler: + prof_ctx = profiling_utils.nsys_profiler() + else: + prof_ctx = contextlib.nullcontext() + + total_steps = min( + args.wait_steps + args.warmup_steps + args.profiler_steps, args.max_steps + ) + print( + f"Profiling {args.model_config} with {args.optimizer.upper()} for {total_steps} steps (wait_steps={args.wait_steps}, warmup_steps={args.warmup_steps}, profiler_steps={args.profiler_steps})" + ) + with prof_ctx as prof: + logger.debug(f"Profiler: {prof}") + for _ in range(total_steps): + with record_function("TRAIN_STEP"): + train_step( + model, + batch, + labels, + optimizer, + profiler=prof if args.torch_profiler else None, + ) + if args.torch_profiler: + print(f"Finished profiling, outputs saved to {args.output_dir}/{file_prefix}*") + else: + print(f"Finished profiling") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-t", "--torch_profiler", action="store_true", help="Enable torch profiler" + ) + parser.add_argument( + "-m", + "--torch_memory_snapshot", + action="store_true", + help="Enable torch memory snapshot", + ) + + parser.add_argument( + "-ns", + "--nsys_profiler", + action="store_true", + help="Enable nsys profiling context manager" + "Surrounds training loop with cudaProfilerApi.{Start,Stop}", + ) + parser.add_argument( + "--optimizer", + default="adamw", + type=str, + choices=["adamw", "galore_adamw", "adamw8bit", "galore_adamw8bit"], + help="Which optimizer to use", + ) + parser.add_argument("--rank", type=int, default=128) + parser.add_argument("--update_proj_gap", type=int, default=50) + parser.add_argument("--galore_scale", type=float, default=1.0) + # parser.add_argument("--proj_type", type=str, default="std") + parser.add_argument( + "--wait_steps", + type=int, + default=0, + help="Number of steps to run before starting torch profiler", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=0, + help="Number of warmup steps for torch profiler", + ) + + parser.add_argument( + "--profiler_steps", + type=int, + default=5, + help="Number of active steps for torch profiler", + ) + parser.add_argument( + "--max_steps", + type=int, + default=100, + help="Max number of train steps to run." + "Total train steps will be min of `max_steps` and the sum of torch profiler steps (`wait_steps` + `warmup_steps` + `profiler_steps`).", + ) + parser.add_argument( + "--model_config", + default="llama100M", + type=str, + choices=["llama100M", "llama1B"], + help="Model configuration (see model_configs.py)", + ) + parser.add_argument( + "--batch_size", default=5, type=int, help="Batch size to use for train step" + ) + parser.add_argument( + "--max_seq_len", + default=256, + type=int, + help="Sequence length to use for train step, should be less than that in the specific model config", + ) + parser.add_argument( + "--output_dir", + default="profiler_out", + type=str, + help="Directory for profiler outputs", + ) + + parser.add_argument( + "-lr", + "--learning_rate", + default=1e-3, + type=float, + help="Learning rate", + ) + parser.add_argument( + "--weight_decay", + default=1e-2, + type=float, + help="Weight decay for AdamW", + ) + + parser.add_argument("--seed", default=0, type=int, help="Random seed for torch") + args = parser.parse_args() + output_dir = args.output_dir + # output_prefix = args.output_prefix + if not os.path.exists(output_dir): + os.makedirs(output_dir) + if "galore" not in args.optimizer.lower(): + file_prefix = args.optimizer.lower() + else: + file_prefix = "-".join( + [ + args.optimizer.lower(), + str(args.rank), + str(args.galore_scale), + str(args.update_proj_gap), + ] + ) + mem_ctx = ( + profiling_utils.memory_recorder( + file_name=os.path.join(output_dir, f"{file_prefix}-memory-snapshot") + ) + if args.torch_memory_snapshot + else contextlib.nullcontext() + ) + profiling_utils.flush_cuda_mem() + with mem_ctx: + run(args, file_prefix) + + profiling_utils.get_cuda_memory_usage(units="MB", show=True) diff --git a/test/galore/profiling_utils.py b/test/galore/profiling_utils.py new file mode 100644 index 0000000000..80d6c03d84 --- /dev/null +++ b/test/galore/profiling_utils.py @@ -0,0 +1,193 @@ +import gc +import logging +import os +from contextlib import contextmanager +from datetime import datetime +from functools import partial + +import torch + +logging.basicConfig( + format="%(levelname)s:%(asctime)s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(level=logging.INFO) + +TIME_FORMAT_STR: str = "%m-%d-%H" + +# Keep a max of 100,000 alloc/free events in the recorded history +# leading up to the snapshot. +MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000 + + +def flush_cuda_mem(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() + torch.cuda.reset_accumulated_memory_stats() + + +@contextmanager +def cuda_max_memory(): + try: + flush_cuda_mem() + yield + + finally: + mem_miB = torch.cuda.max_memory_allocated() // (1024 * 1024) + print(f"{mem_miB} MB of CUDA memory allocated") + flush_cuda_mem() + return mem_miB + + +def get_cuda_memory_usage(units="MB", show=True): + """ + Get maximum allocated / reserved CUDA memory in given units + + Args: + units: MB, GB, or B + """ + units = units.upper() + if units == "MB": + divisor = 1024**2 + elif units == "GB": + divisor = 1024**3 + else: + units = "B" + divisor = 1 + max_memory_allocated = torch.cuda.max_memory_allocated() / divisor + max_memory_reserved = torch.cuda.max_memory_reserved() / divisor + if show: + print( + "Max Memory Allocated:", + f"{max_memory_allocated:,.1f} {units}", + ) + print( + "Max Memory Reserved:", + f"{max_memory_reserved:,.1f} {units}", + ) + + return max_memory_allocated, max_memory_reserved + + +def export_memory_snapshot(prefix) -> None: + + # Prefix for file names. + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = f"{prefix}_{timestamp}" + + try: + logger.info(f"Saving snapshot to local file: {file_prefix}.pickle") + torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle") + except Exception as e: + logger.error(f"Failed to capture memory snapshot {e}") + return + + +@contextmanager +def memory_recorder(file_name="cuda_memory_snapshot", export=False) -> None: + assert ( + torch.cuda.is_available() + ), "Memory profiler requires GPU, check torch.cuda.is_available()" + try: + logger.info("Starting snapshot record_memory_history") + torch.cuda.memory._record_memory_history( + max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT + ) + yield + finally: + logger.info("Stopping snapshot record_memory_history") + torch.cuda.memory._record_memory_history(enabled=None) + if export: + export_memory_snapshot(file_name) + + +def trace_handler( + prof: torch.profiler.profile, + prefix: str = "profile", + output_dir="./", + sort_key="cuda_time_total", + export_trace=True, + export_memory_timeline=True, + print_table=True, +): + + timestamp = datetime.now().strftime(TIME_FORMAT_STR) + file_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}") + + if export_trace: + prof.export_chrome_trace(f"{file_prefix}-trace.json.gz") + + if export_memory_timeline: + prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0") + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.json", device="cuda:0" + ) + if print_table: + print(prof.key_averages().table(sort_by=sort_key, row_limit=10)) + + +def get_torch_profiler( + name: str = "profile", + output_dir: str = "./profiler_out", + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + wait_steps=1, + warmup_steps=1, + active_steps=10, + repeat=1, + # options for profiler outputs + on_trace_ready=trace_handler, + export_trace=True, + export_memory_timeline=True, + print_table=True, +): + """ + Args: + name: name of the profiler, used for output files + table_key: key to sort profiler table by: one of `cpu_time`, `cuda_time`, `cpu_time_total`, + `cuda_time_total`, `cpu_memory_usage`, `cuda_memory_usage`, + `self_cpu_memory_usage`, `self_cuda_memory_usage`, `count`. + + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + return torch.profiler.profile( + activities=activities, + schedule=torch.profiler.schedule( + wait=wait_steps, warmup=warmup_steps, active=active_steps, repeat=repeat + ), + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + on_trace_ready=partial( + on_trace_ready, + prefix=name, + output_dir=output_dir, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + print_table=print_table, + ), + ) + + +@contextmanager +def nsys_profiler(): + try: + torch.cuda.cudart().cudaProfilerStart() + free, total = torch.cuda.mem_get_info() + print(f"Start, Memory Usage: Free {free:.2e}, Used {(total - free):.2e}") + yield "nsys" + finally: + free, total = torch.cuda.mem_get_info() + print(f"End, Memory Usage: Free {free:.2e}, Used {(total - free):.2e}") + torch.cuda.cudart().cudaProfilerStop() diff --git a/test/kernel/galore_test_utils.py b/test/kernel/galore_test_utils.py new file mode 100644 index 0000000000..1e83b7b48f --- /dev/null +++ b/test/kernel/galore_test_utils.py @@ -0,0 +1,176 @@ +import torch + +from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher +from torchao.prototype.galore.kernels.adam_downproj_fused import ( + set_tuner_top_k as adam_downproj_tuner_topk, +) +from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher +from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk +from torchao.prototype.galore.kernels.matmul import triton_mm_launcher +from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector + +torch.manual_seed(0) + +adam_downproj_tuner_topk(10) +matmul_tuner_topk(10) + +BETA1 = 0.9 +BETA2 = 0.999 +EPS = 1e-8 +STEP_SIZE = 1e-4 + + +def make_data(M, N, rank, dtype): + grad = torch.randn(M, N, device="cuda", dtype=dtype) + params = torch.randn(M, N, device="cuda", dtype=dtype) + + galore_proj = GaLoreProjector(rank=rank) + galore_proj.update_orthogonal_matrix(grad) + + if M >= N: + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) + else: + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) + exp_avg2 = exp_avg**2 + + return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params + + +def make_copy(*args): + return [t.detach().clone() for t in args] + + +def _ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + beta1=BETA1, + beta2=BETA2, + eps=EPS, + step_size=STEP_SIZE, + **kwargs, +): + + # Step 1: Down proj grad + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + # Step 2: update adam state + exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) + exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) + denom = exp_avg2.sqrt().add_(eps) + low_rank_norm_grad = exp_avg / denom + + # Step 3: project normalized low rank grad to full rank + if M >= N: + a, b = low_rank_norm_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_norm_grad + full_grad_norm = a @ b + + # Finally, update params with updated grad + params.add_(full_grad_norm, alpha=-step_size) + + return exp_avg, exp_avg2, params + + +def _tt_hybrid( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + exp_avg, exp_avg2, norm_grad = triton_adam_launcher( + exp_avg, exp_avg2, low_rank_grad, store=store + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + +def _tt_fused( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( + a, + b, + exp_avg=exp_avg, + exp_avg2=exp_avg2, + store=store, + fp8_fast_accum=fp8_fast_accum, + allow_tf32=allow_tf32, + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + # logging.basicConfig(level=logging.INFO) + + +def get_kernel(kernel): + if kernel == "ref": + op = _ref_op + elif kernel == "ref": + op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") + elif kernel == "hybrid": + op = _tt_hybrid + elif kernel == "fused": + op = _tt_fused + else: + raise ValueError(f"Unknown kernel {kernel}") + + return lambda *args, **kwargs: op(*args, **kwargs) diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py new file mode 100644 index 0000000000..b43abead45 --- /dev/null +++ b/test/kernel/test_fused_kernels.py @@ -0,0 +1,111 @@ +import itertools + +import pytest + +# Skip entire test if triton is not available, otherwise CI failure +try: + import triton +except ImportError: + pytest.skip("triton is not installed", allow_module_level=True) + +import torch +from galore_test_utils import get_kernel, make_copy, make_data + +torch.manual_seed(0) +MAX_DIFF_no_tf32 = 1e-5 +MAX_DIFF_tf32 = 1e-3 + + +def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): + # Copy to use for first run -- needed because of autotuning and inplace ops + ( + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + grad_autotune_copy, + proj_matrix_autotune_copy, + params_autotune_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + # Copy to use for second run to check accuracy + ( + exp_avg_test_copy, + exp_avg2_test_copy, + grad_test_copy, + proj_matrix_test_copy, + params_test_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + print( + f"Running with {grad.shape[0]} x {grad.shape[1]} grad (param) shape, GaLore orthogonal matrix {list(proj_matrix.shape)}, dtype {grad.dtype} and allow_tf32 {allow_tf32}\n" + f"Kernel: {kernel}", + flush=True, + ) + + ref_op = get_kernel("ref") + test_op = get_kernel(kernel) + + # Reference run + ref_out = ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ) + + # Autotune + _ = test_op( + grad_autotune_copy, + proj_matrix_autotune_copy, + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + params_autotune_copy, + store=False, + allow_tf32=allow_tf32, + ) + + # Accuracy run + test_out = test_op( + grad_test_copy, + proj_matrix_test_copy, + exp_avg_test_copy, + exp_avg2_test_copy, + params_test_copy, + store=True, + allow_tf32=allow_tf32, + ) + print("Accuracy:") + + output_names = [ + "adam state - running grad mean", + "adam state - running grad var", + "params (after update)", + ] + MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 + for name, ref, tt in zip(output_names, ref_out, test_out): + max_diff = (ref - tt).abs().max() + print(f"-> {name}:\n Max err: {max_diff:.6f}") + assert max_diff < MAX_DIFF + + +KERNELS = ["hybrid"] # "fused"] +DTYPES = [torch.float32] # torch.float16 +ROW_DIMS = [4096] +COL_DIMS = [4096] # , 11008] +RANKS = [128] +ALLOW_TF32 = [False] # , True] + +TEST_CONFIGS = list( + itertools.product(KERNELS, DTYPES, ROW_DIMS, COL_DIMS, RANKS, ALLOW_TF32) +) + +# TEST_CONFIGS = TEST_CONFIGS[0:1] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) +def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32): + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32) diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py new file mode 100644 index 0000000000..cc06b6812e --- /dev/null +++ b/test/kernel/test_galore_downproj.py @@ -0,0 +1,49 @@ +import pytest + +# Skip entire test if triton is not available, otherwise CI failure +try: + import triton +except ImportError: + pytest.skip("triton is not installed", allow_module_level=True) + +import torch +from galore_test_utils import make_data + +from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk +from torchao.prototype.galore.kernels.matmul import triton_mm_launcher +from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector + +torch.manual_seed(0) + +matmul_tuner_topk(10) +MAX_DIFF_no_tf32 = 1e-4 +MAX_DIFF_tf32 = 1e-2 + + +TEST_CONFIGS = [ + # (4096, 4096, 128, True, False, torch.float32), + (4096, 4096, 128, False, False, torch.float32), + # (4096, 11008, 128, True, False, torch.float32), + (4096, 11008, 128, False, False, torch.float32), +] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) +def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 + exp_avg, exp_avg2, grad, galore_proj, params = make_data(M, N, rank, dtype) + + if M >= N: + a, b = grad, galore_proj.t() + else: + a, b = galore_proj.t(), grad + low_rank_ref = lambda: a @ b + low_rank_tt = lambda: triton_mm_launcher( + a, b, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum + ) + diff = torch.max(torch.abs(low_rank_ref() - low_rank_tt())) + if not diff < MAX_DIFF: + print("diff: ", torch.max(torch.abs(low_rank_ref() - low_rank_tt()))) + assert diff < MAX_DIFF diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py new file mode 100644 index 0000000000..1eabf479ce --- /dev/null +++ b/test/quantization/test_galore_quant.py @@ -0,0 +1,91 @@ +import itertools + +import pytest + +# Skip entire test if triton is not available, otherwise CI failure +try: + import triton +except ImportError: + pytest.skip("triton is not installed", allow_module_level=True) + +import bitsandbytes.functional as F +import torch + +from torchao.prototype.galore.kernels import ( + triton_dequant_blockwise, + triton_quantize_blockwise, +) + +SEED = 0 +torch.manual_seed(SEED) + +DIM1 = [64, 1024, 4096] +DIM2 = [1024, 2048, 4096] +SIGNS = [True, False] +DTYPES = [torch.float32] # , torch.float16] +BLOCKSIZE = [2048] + +TEST_CONFIGS = list(itertools.product(DIM1, DIM2, DTYPES, SIGNS, BLOCKSIZE)) + + +@pytest.mark.skip("skipping for now, see comments below") +@pytest.mark.parametrize( + "dim1,dim2,dtype,signed,blocksize", + TEST_CONFIGS, +) +def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): + g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 + + qmap = F.create_dynamic_map(signed).to(g.device) + + ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape) + + tt_q, tt_norm, tt_absmax = triton_quantize_blockwise( + g, qmap, group_size=blocksize, return_normalized=True + ) + tt_check = torch.allclose(ref_bnb, tt_q) + + # see notes.md under `prototype.galore.kernels` for an explanation of the following conditions + if not tt_check: + print( + f"Failed quantization check for {dim1} x {dim2}, {dtype}, signed {signed}" + ) + print(f"Absmax: {(qstate.absmax - tt_absmax).abs().max()}") + print(f"Norm diff: {(bnb_norm - tt_norm).abs().max()}") + + idx_diff = (ref_bnb != tt_q).to("cuda") + print(f"Num code idx diffs: {idx_diff.sum()}") + max_idx_diff = (ref_bnb - tt_q).abs().max() + print(f"Max code idx diff: {max_idx_diff}") + + # This below checks that the value being quantized falls half-way between two code buckets + # where bitsandbytes assigns to one and the triton implementation assigns to the other + # Since either bucket is technically valid, we only check that the distance between the value and the + # adjacent buckets are the same. I.e., we don't require that the triton implementation exactly matches + # bitsandbytes. + + bnb_code = qmap[ref_bnb[idx_diff].tolist()] + tt_code = qmap[tt_q[idx_diff].tolist()] + bnb_dist = torch.abs(bnb_code - bnb_norm[idx_diff]) + torch_dist = torch.abs(tt_code - bnb_norm[idx_diff]) + + dist_sum = torch.sum(bnb_dist - torch_dist) + print(f"Distance sum: {torch.sum(bnb_dist - torch_dist)}") + assert tt_check or (not tt_check and dist_sum < 1e-4) + + +@pytest.mark.parametrize( + "dim1,dim2,dtype,signed,blocksize", + TEST_CONFIGS, +) +def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): + g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 + + qmap = F.create_dynamic_map(signed).to(g.device) + + q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + + dq_ref = F.dequantize_blockwise(q, qstate) + dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize) + assert torch.allclose(dq, dq_ref) diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md new file mode 100644 index 0000000000..02ee2dd3be --- /dev/null +++ b/torchao/prototype/README.md @@ -0,0 +1,19 @@ +# Prototype + +### Experimental kernels and utilities for efficient inference and training + +> The goal isn't to reproduce all emerging methodologies but to extract common components across prevalent, proven paradigms that can be modularized and composed with the `torch` stack as well as other OSS ML frameworks. + +#### Code structure + +- `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507) + - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm + - `galore/docs` - implementation notes and discussion of issues faced in kernel design. + +#### Roadmap + +- `hqq`, `awq`, `marlin`,`QuaRot`, and other well-researched methodologies for quantized fine-tuning and inference. + - ideally, techniques that are both **theoretically sound** and have **practical hardware-aware implementations** + - AWQ and GPTQ are good examples. +- `cutlass` / `triton` utilities for common quantization ops (numeric conversion, quant / dequant, mixed type gemm, etc.) + - goal is to create a set of kernels and components that can expedite the implementation & optimization across the spectrum of quantization, fine-tuning, and inference patterns. diff --git a/torchao/prototype/__init__.py b/torchao/prototype/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/galore/README.md b/torchao/prototype/galore/README.md new file mode 100644 index 0000000000..2a7ae1f7d9 --- /dev/null +++ b/torchao/prototype/galore/README.md @@ -0,0 +1,11 @@ +## Fused GaLore + +### Experimental kernels for fusing various parts of the GaLore algorithm + +#### AdamW + +See `docs/galore_adam.md` for implementation notes. + +#### AdamW8bit + +See `docs/galore_adam8bit.md` for implementation notes. diff --git a/torchao/prototype/galore/__init__.py b/torchao/prototype/galore/__init__.py new file mode 100644 index 0000000000..4e1edc4039 --- /dev/null +++ b/torchao/prototype/galore/__init__.py @@ -0,0 +1 @@ +from .kernels import * diff --git a/torchao/prototype/galore/docs/README.md b/torchao/prototype/galore/docs/README.md new file mode 100644 index 0000000000..74b077c4a9 --- /dev/null +++ b/torchao/prototype/galore/docs/README.md @@ -0,0 +1,198 @@ +## Fused GaLore Adam (WIP) + +### Various fused implementations of `Adam` update step per [Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) + +This is an initial attempt at optimizing the update step of the `GaLore Adam` optimizer. + +#### Overview + +The `GaLore` `Adam` optimizer introduces additional ops to the traditional `adam` update step. + +Specifically: + +1. `grad` is projected to low rank --> additional matmul +2. `adam` states are updated with `grad` elementwise (same as `Adam` except in low-rank) +3. normalized `grad` is projected to full rank --> additional matmul +4. `params` are updated with the normalized full rank grad + +#### Implementation + +Various fusions were attempted across 2 kernel implementations: + +- `Fused` + - Steps 1 & 2 are fused: the `adam` state updates are loaded and updated (inplace) during the first `matmul` + - Steps 3 & 4 are fused: the param update is folded as an epilogue into the second `matmul` +- `Hybrid` + - Step 1 is performed using standard `torch matmul` (i.e., `cuBlas`) + - Step 2 is fused as an elementwise kernel + - Steps 3 & 4 per `Fused` + +#### Performance + +Below are benchmarks for various kernels: + +- `torch` - reference `torch` implementation where each of the steps are implemented verbatim per above +- `hybrid` - see above +- `fused` - see above +- `compiled` - `torch` reference implementation compiled using `torch.compile` with `fullgraph=True` and `mode="max-autotune"`. + +Configs for each benchmark are the `grad (param)` shape, `dtype` of `grad` and `adam` states, and `allow_tf32`, whether `torch` and `triton` matmuls are allowed to use `TF32` tensor cores (see `Discussion`). + +`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `False` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 0.560128 0.347136 0.505856 0.534528 +1 64.0 0.627712 0.404480 0.600960 0.615424 +2 128.0 0.825232 0.583168 0.985072 0.833536 +3 256.0 1.378304 1.126400 1.489920 1.375232 +4 512.0 2.286080 2.101760 2.969600 2.302976 +``` + +`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `True` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 0.540672 0.321536 0.316416 0.508928 +1 64.0 0.612240 0.337728 0.345024 0.538624 +2 128.0 0.640000 0.395264 0.393216 0.693248 +3 256.0 0.777216 0.489472 0.548784 1.102848 +4 512.0 1.216512 0.864256 0.960512 1.968128 +``` + +`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `False` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 1.538672 0.915456 0.835584 1.364032 +1 64.0 1.546240 0.940032 1.022976 1.486848 +2 128.0 2.116608 1.498112 1.613312 2.098176 +3 256.0 3.423744 2.719744 2.881536 3.227136 +4 512.0 5.499904 5.036544 5.450752 5.508096 +``` + +`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `True` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 1.413120 0.871424 0.817152 1.353184 +1 64.0 1.489920 0.916480 0.854016 1.389568 +2 128.0 1.679360 0.996352 1.005568 1.563648 +3 256.0 2.152448 1.415168 1.470464 2.185216 +4 512.0 3.210240 2.460672 2.580480 3.477504 +``` + +##### Accuracy + +Comparison to reference `torch` implementation: + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True +Kernel: hybrid +Accuracy: +-> adam state - running grad mean: + Max err: 0.000000 Relative err: 0.000001 +-> adam state - running grad var: + Max err: 0.000002 Relative err: 0.000002 +-> params (after update): + Max err: 0.000000 Relative err: 0.000001 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False +Kernel: hybrid +Accuracy: +-> adam state - running grad mean: + Max err: 0.000000 Relative err: 0.000000 +-> adam state - running grad var: + Max err: 0.000002 Relative err: 0.000002 +-> params (after update): + Max err: 0.000000 Relative err: 0.000000 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True +Kernel: fused +Accuracy: +-> adam state - running grad mean: + Max err: 0.000845 Relative err: 0.001152 +-> adam state - running grad var: + Max err: 0.000162 Relative err: 0.000161 +-> params (after update): + Max err: 0.000000 Relative err: 0.000001 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False +Kernel: fused +Accuracy: +-> adam state - running grad mean: +Max err: 0.000003 Relative err: 0.000004 +-> adam state - running grad var: +Max err: 0.000002 Relative err: 0.000002 +-> params (after update): +Max err: 0.000000 Relative err: 0.000000 +``` + +#### Discussion + +##### Down Projection GEMM Shape + +The motivation for the `hybrid` approach is the unconventional matrix shapes of the down projection (Step 1): + +- The projection is always done such that the larger dimension of the `grad` matrix is maintained while other is projected to low rank per the `GaLore` algorithm + - E.g., if `M >= N`, the GEMM is of shape (`M x N`) x (`N x rank`) = (`M x rank`), (`rank x M`) x (`M x N`) = (`rank x N`) otherwise +- Since `{M, N} >> rank` by definition, this results in a large reduction dimension relative to one of the output dimensions (output matrix is either fat or skinny) +- This does not fit cleanly into the `split-k / parallel reduction` `GEMM` paradigm which is more tailored for shapes where both output dims are smaller than the reduction dimension. +- Consequently, I had trouble finding an optimal kernel config using `triton` `autotuner` for the down projection step, despite tuning across many compute and io-bound configs (see `fused.triton_utils.kernels.matmul.py`). +- Benchmarking `triton`-tuned `matmul` against default `torch.matmul` for these shapes showed worse performance, for `torch.float32` + +#### Effect of `TF32` tensor cores + +`allow_tf32`: this has significant impact on relative performance of `triton` vs `torch` matmuls: + +- Quick benchmarks of the downprojection `matmul` show that: + - with `allow_tf32=True` for both, triton exhibits `~1.30x` performance improvement over `torch`. + - with `allow_tf32=False`, performance of `triton` degrades significantly to `~.50x` of `torch`. + +See this [`torch note`](https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere) for more details on this feature. + +**Note**: This might be less of a concern given this incoming triton [PR](https://github.com/openai/triton/pull/3234), which implements a fast `TF32` trick that improves both performance and accuracy. + +#### Repro + +_Accuracy_ + +- Test accuracy of `torch` vs `hybrid` for `M=4096`, `N=4096`, `rank=128`, and `tf32` switched on: + + ```python + pytest test/kernel/test_fused_kernels.py + ``` + +_Benchmark_ + +- Benchmark across all kernels without `tf32`: + + ```python + python benchmarks/bench_galore_fused_kernels.py + ``` + +For additional benchmarking options: + +```python + python benchmarks/bench_galore_fused_kernels.py --help +``` + +#### Test Env + +- GPU Device Props: + - Name: `NVIDIA RTX A6000` + - CC: `86` + - Total_memory: `48676MB` + - SM count: `84` +- Torch: `2.2.2` +- Triton: `2.2.0` diff --git a/torchao/prototype/galore/docs/galore_adam8bit.md b/torchao/prototype/galore/docs/galore_adam8bit.md new file mode 100644 index 0000000000..ddb45c29b8 --- /dev/null +++ b/torchao/prototype/galore/docs/galore_adam8bit.md @@ -0,0 +1,35 @@ +## GaLore AdamW8bit Optimizer + +### Overview + +`GaLore` AdamW8bit optimizer utilizes `bitsandbytes` `AdamW8bit` optimizer to additionally quantize the optimizer states. + +In addition to the additional ops introduced by `GaLore` to the standard `Adam` update step (see the `galore_adam.md` for details), additional dequantize / quantize steps are needed: + +- one to to dequantize the quantized states for the state update +- after the states are updated, they need to quantized along and `quant_state` updated +- For `bitsandbytes` `AdamW8bit`, the `quant_state` consists of group-wise (`blocksize`) scaling factors. + +The `bitsandbytes` 8bit optimizer is implemented in CUDA, with handcrafted logic for implementing each of these steps. + +> The motivation for re-implementing this optimizer purely in `triton` / `torch` is to enable exploration of various fusion / optimization strategies that would be difficult with the current CUDA impl. + +#### Quantization Algorithm + +1. Weights are quantized in contiguous `blocksize` segments +2. Given tensor `M x N`, reshape to `-1 x blocksize` +3. Find columnwise `absmax` and normalize tensor by dividing by `absmax` +4. Reshape normalized tensor back to original shape +5. `bitsandbytes` then uses an `8-bit` [quantization code](https://github.com/TimDettmers/bitsandbytes/blob/76885a41df9e6c94b3f80b1c37374c8441b6933e/bitsandbytes/optim/optimizer.py#L146-L151), which can either be signed or unsigned -- signed for tracking `mean`, unsigned for tracking `var`. +6. The normalized tensor is then assigned to the code it is closest to: + - E.g., given normalized value `.0412` and buckets `.0402` and `.0416`, it will be assigned the latter code. +7. **IMPORTANT**: This gives rise to a small number of edge-case errors when trying to reproduce `bitsandbytes` quantization + - Specifically, if a normalized value falls directly between two codes there is a degree of indeterminism. + - E.g., in the previous example, if the normalized value is `.0409`, it would be equidistant to the codes `.0402` and `.0416`. + - See the assertions in the `test_galore_quant.py` unittest that checks that these are the only discrepancies arising from the triton implementation (run with `pytest -sv -k` flags to see the output from this test). + +### bitsandbytes CUDA Source + +- Adam[W]8bit [update step](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L1770) +- Adam blockwise [quantization](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L413) after update +- [Blockwise](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L726) [Quantization](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L339) kernel diff --git a/torchao/prototype/galore/kernels/__init__.py b/torchao/prototype/galore/kernels/__init__.py new file mode 100644 index 0000000000..d8a134717b --- /dev/null +++ b/torchao/prototype/galore/kernels/__init__.py @@ -0,0 +1,4 @@ +from .adam_downproj_fused import fused_adam_mm_launcher +from .adam_step import triton_adam_launcher +from .matmul import triton_mm_launcher +from .quant import triton_dequant_blockwise, triton_quantize_blockwise diff --git a/torchao/prototype/galore/kernels/adam_downproj_fused.py b/torchao/prototype/galore/kernels/adam_downproj_fused.py new file mode 100644 index 0000000000..9049baa782 --- /dev/null +++ b/torchao/prototype/galore/kernels/adam_downproj_fused.py @@ -0,0 +1,353 @@ +import logging + +import torch +import triton +import triton.language as tl +from triton.ops.matmul import get_higher_dtype, init_to_zero +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from .adam_step import BETA1, BETA2, EPS +from .custom_autotune import Config, autotune +from .matmul import TRITON_ACC_TYPES +from .matmul import get_autotuner as default_mm_autotuner +from .matmul import get_mm_heuristics, to_tl_type + +logger = logging.getLogger(__name__) + +AUTOTUNER_TOP_K = 50 + + +def set_tuner_top_k(k): + global AUTOTUNER_TOP_K + AUTOTUNER_TOP_K = k + + +@triton.jit +def _fused_adam_mm_kernel( + # matmul args + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + # adam epilogue, + exp_avg_ptr, # these will be updated inplace + exp_avg2_ptr, + store, + # grad_ptr, # low rank grad output -- not needed, C is the output + # meta params + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_M: tl.constexpr, + # Adam-specific params + BETA1: tl.constexpr = BETA1, + BETA2: tl.constexpr = BETA2, + EPS: tl.constexpr = EPS, + # matmul kernel settings + acc_dtype: tl.constexpr = tl.float32, # + allow_tf32: tl.constexpr = False, # higher precision for this phase + fp8_fast_accum: tl.constexpr = False, # + AB_DTYPE: tl.constexpr = None, # +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + epilogue_offsets = rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + + # Load adam state + exp_avg = tl.load(exp_avg_ptr + epilogue_offsets, mask=mask) + exp_avg2 = tl.load(exp_avg2_ptr + epilogue_offsets, mask=mask) + + # Perform update + exp_avg = BETA1 * exp_avg.to(acc.dtype) + (1.0 - BETA1) * acc + exp_avg2 = BETA2 * exp_avg2.to(acc.dtype) + (1.0 - BETA2) * (acc * acc) + denom = tl.sqrt(exp_avg2) + EPS + norm_grad = exp_avg / denom + # Convert to output type + norm_grad = norm_grad.to(C.dtype.element_ty) + + # acc = acc.to(C.dtype.element_ty) + C = C + epilogue_offsets + + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, norm_grad, mask=mask) + else: + tl.atomic_add(C, norm_grad, mask=mask) + + if store: + tl.store( + exp_avg_ptr + epilogue_offsets, + exp_avg, + mask=mask, + ) + tl.store( + exp_avg2_ptr + epilogue_offsets, + exp_avg2, + mask=mask, + ) + + +def _get_configs_splitk_all(): + """ + Configs specific to split-k matmuls + Not used currently + """ + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128]: + for block_k in [16, 32, 64, 128, 256]: + for block_n in [16, 32, 64, 128]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def _get_configs_splitk_small(): + """Configs for split-k, smaller version than above + Not used currently + """ + configs = [] + for num_stages in [2, 3, 4]: + for block_m in [64, 128]: + for block_k in [16, 32, 64]: + for block_n in [64, 128]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def _splitk_autotuner( + configs=_get_configs_splitk_small(), + key=["M", "N", "K"], + early_config_prune=early_config_prune, + perf_model=estimate_matmul_time, + top_k=AUTOTUNER_TOP_K, +): + """Autotuner for splitk matmuls + Not used currently + """ + autotuner = autotune( + configs=configs, + key=key, + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": perf_model, + "top_k": top_k, + }, + ) + + return autotuner + + +def _get_kernel( + tuner_fn=default_mm_autotuner, heuristics_fn=get_mm_heuristics, topk=AUTOTUNER_TOP_K +): + tuner = tuner_fn() + tuner.topk = topk + heuristics = heuristics_fn() + return tuner(heuristics(_fused_adam_mm_kernel)) + + +DEFAULT_KERNEL = _get_kernel() + + +def fused_adam_mm_launcher( + a, + b, + *, + exp_avg, + exp_avg2, + store=True, + BETA1=BETA1, + BETA2=BETA2, + EPS=EPS, + allow_tf32=False, + fp8_fast_accum=False, + acc_dtype=None, + output_dtype=None, + kernel=None, +): + + device = a.device + # handle non-contiguous inputs if necessary + # a = grad + # b = galore_proj.ortho_matrix.t() + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = [ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + if kernel is None: + kernel = DEFAULT_KERNEL + kernel[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + exp_avg, + exp_avg2, + store=store, + BETA1=BETA1, # , # + BETA2=BETA2, # , # + EPS=EPS, # + acc_dtype=acc_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, + AB_DTYPE=ab_dtype, + ) + return exp_avg, exp_avg2, c # c -> normalized low rank grad diff --git a/torchao/prototype/galore/kernels/adam_step.py b/torchao/prototype/galore/kernels/adam_step.py new file mode 100644 index 0000000000..75b2c870d2 --- /dev/null +++ b/torchao/prototype/galore/kernels/adam_step.py @@ -0,0 +1,178 @@ +import torch +import triton +import triton.language as tl +from triton.language.math import sqrt +from triton.runtime.autotuner import heuristics + +from .custom_autotune import Config, autotune + +BETA1, BETA2 = 0.9, 0.999 +EPS = 1e-8 + +AUTOTUNER_TOP_K = 10 + + +def get_configs_for_adam(num_warps=[2, 4, 8], block_sizes=[512, 1024, 2048]): + configs = [] + for w in num_warps: + for bs in block_sizes: + configs.append(Config({"BLOCK_SIZE": bs}, num_warps=w)) + return configs + + +def early_adam_prune(configs, named_args): + numels = named_args["numels"] + pruned_configs = [cfg for cfg in configs if numels % cfg.kwargs["BLOCK_SIZE"] == 0] + # print("Pruned configs:\n") + for cfg in pruned_configs: + print(f"{cfg}\n") + return pruned_configs + + +def get_adam_tuner( + configs=get_configs_for_adam(), + early_config_prune=None, # early_adam_prune, + top_k=AUTOTUNER_TOP_K, +): + return autotune( + configs=configs, + prune_configs_by={ + "early_config_prune": early_config_prune, + "top_k": top_k, + }, + key=["numels"], + ) + + +def get_adam_heuristics(): + return { + "USE_MASK": lambda args: args["numels"] % args["BLOCK_SIZE"] != 0, + } + + +@autotune(configs=get_configs_for_adam(), key=["numels"]) +@heuristics(get_adam_heuristics()) +@triton.jit +def _adam_update( + avg_ptr, + avg2_ptr, + grad_ptr, + # avg_out_ptr, + # avg2_out_ptr, + # grad_out_ptr, + numels, + store, + BLOCK_SIZE: tl.constexpr, + USE_MASK: tl.constexpr, + BETA1: tl.constexpr = BETA1, + BETA2: tl.constexpr = BETA2, + EPS: tl.constexpr = EPS, +): + pid_m = tl.program_id(0) + offset = pid_m * BLOCK_SIZE + offset = offset + tl.arange(0, BLOCK_SIZE) + # load_idx = offset + tl.arange(0, BLOCK_SIZE) + load_idx = tl.max_contiguous(tl.multiple_of(offset, BLOCK_SIZE), BLOCK_SIZE) + + mask = None + if USE_MASK: + mask = load_idx < numels + avg = tl.load(avg_ptr + load_idx, mask=mask) + avg2 = tl.load(avg2_ptr + load_idx, mask=mask) + grad = tl.load(grad_ptr + load_idx, mask=mask) + + avg = BETA1 * avg + (1.0 - BETA1) * grad + avg2 = BETA2 * avg2 + (1.0 - BETA2) * (grad * grad) + + denom = sqrt(avg2) + EPS + # denom = tl.sqrt(avg2) + EPS + + norm_grad = avg / denom + + if store: + tl.store(avg_ptr + load_idx, avg, mask=mask) + tl.store(avg2_ptr + load_idx, avg2, mask=mask) + tl.store(grad_ptr + load_idx, norm_grad, mask=mask) + # tl.store(avg_out_ptr + load_idx, avg, mask=mask) + # tl.store(avg2_out_ptr + load_idx, avg2, mask=mask) + # tl.store(grad_out_ptr + load_idx, norm_grad, mask=mask) + + +adam_update = _adam_update + + +def triton_adam_launcher( + avg, + avg2, + grad, + store=True, + beta1=BETA1, + beta2=BETA2, + eps=EPS, +): + M, N = avg.shape + # avg_out = torch.empty_like(avg) + # avg2_out = torch.empty_like(avg2) + # grad_out = torch.empty_like(grad) + + grid = lambda META: (triton.cdiv(M * N, META["BLOCK_SIZE"]),) + adam_update[grid]( + avg, + avg2, + grad, + # avg_out, + # avg2_out, + # grad_out, + avg.numel(), + store=store, + BETA1=beta1, + BETA2=beta2, + EPS=eps, + # BLOCK_SIZE=1024, + # USE_MASK=USE_MASK, + ) + return avg, avg2, grad + + +def ref_adam_step(exp_avg, exp_avg2, grad, beta1=BETA1, beta2=BETA2, eps=EPS): + exp_avg = beta1 * exp_avg + (1 - beta1) * grad + exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad) + denom = exp_avg2.sqrt() + eps + norm_grad = exp_avg / denom + return exp_avg, exp_avg2, norm_grad + + +def make_data(M, N, rank, dtype): + # full_grad = torch.randn(M, N, device="cuda", dtype=dtype) + params = torch.randn(M, N, device="cuda", dtype=dtype) + + if M >= N: + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) + else: + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) + exp_avg2 = exp_avg**2 + down_grad = torch.randn_like(exp_avg) + + return exp_avg, exp_avg2, down_grad, params + + +if __name__ == "__main__": + from triton.testing import do_bench + + M = N = 4096 + rank = 128 + dtype = torch.float32 + exp_avg, exp_avg2, grad, params = make_data(M, N, rank, dtype=dtype) + exp_avg_copy, exp_avg2_copy, grad_copy = ( + exp_avg.clone(), + exp_avg2.clone(), + grad.clone(), + ) + ref_out = ref_adam_step(exp_avg, exp_avg2, grad) + + # Autotune run -- changes exp_avg, exp_avg2, grad in-place + _ = triton_adam_launcher(exp_avg, exp_avg2, grad) + triton_out = triton_adam_launcher(exp_avg_copy, exp_avg2_copy, grad_copy) + + for ref, tt in zip(ref_out, triton_out): + print(torch.max(torch.abs(ref - tt))) diff --git a/torchao/prototype/galore/kernels/custom_autotune.py b/torchao/prototype/galore/kernels/custom_autotune.py new file mode 100644 index 0000000000..3b76605955 --- /dev/null +++ b/torchao/prototype/galore/kernels/custom_autotune.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +import builtins +import logging +import os +import time +from typing import Dict + +import numpy as np +from triton.runtime.cache import default_cache_dir +from triton.runtime.errors import OutOfResources +from triton.runtime.jit import KernelInterface +from triton.testing import do_bench + +logger = logging.getLogger(__file__) + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args: 0 + if len(self.reset_idx) > 0 or len(self.restore_idx) > 0: + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + if len(self.restore_idx) > 0: + + def _post_hook(args): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get( + "early_config_prune", self.early_config_prune + ) + + self.fn = fn + self.num_warmups = warmup + self.num_reps = rep + # self.autotune_log_path = os.path.join(default_cache_dir(), autotune_log_file) + self.kernel_name = self._find_kernel_name() + + def _find_kernel_name(self): + try: + kernel_name = self.fn.__name__ + except AttributeError: + try: # in case JITfn is wrapped in both autotune and heuristic + kernel_name = self.fn.fn.__name__ + except: # noqa + kernel_name = self.fn.__name__ + return kernel_name + + def _get_key_combination(self, args, as_str=True, sep=" "): + key_vals = [f"{self.arg_names[i]}={args[i]}" for i in self.key_idx] + return f"{sep}".join(key_vals) if as_str else key_vals + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **current, + ) + self.post_hook(args) + + try: + return do_bench( + kernel_call, + warmup=self.num_warmups, + rep=self.num_reps, + quantiles=(0.5, 0.2, 0.8), + ) + except OutOfResources: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + logger.debug("Cache miss!\n") + logger.info( + f"\n==== Autotune ====\nRunning autotune for {self.kernel_name} for {len(self.configs)} total configs" + f" for key combination {self._get_key_combination(args)}..." + ) + # prune configs + pruned_configs = self.prune_configs(kwargs) + logger.info(f"\nNum configs after pruning {len(pruned_configs)}") + bench_start = time.time() + timings = {} + for config in pruned_configs: + timings[config] = self._bench(*args, config=config, **kwargs) + # timings = { + # config: self._bench(*args, config=config, **kwargs) + # for config in pruned_configs + # } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + + sorted_timings = dict( + sorted(timings.items(), key=lambda x: np.mean(x[1])) + ) + _key_suffix = self._get_key_combination(args, sep="-") + autotune_file = f"autotune_{self.kernel_name}_{_key_suffix}.log" + autotune_log_path = os.path.join(default_cache_dir(), autotune_file) + + logger.info(f"\nFinished autotune, writing log to {autotune_log_path}") + + with open(f"{autotune_log_path}", "w") as f: + f.write( + f" ==== Autotune Results ====\nKernel name: {self.kernel_name}\nArgs: {self.arg_names}\nKeys: {self._get_key_combination(args)}\n" + ) + f.write(f"\nPruned configs:\n") + for cfg in pruned_configs: + f.write(f"{cfg}\n") + f.write(f"Timings:\n") + for cfg, timing in sorted_timings.items(): + f.write(f"{cfg} {timing} \n") + f.write(f"Best config: {self.cache[key]}\n") + config = self.cache[key] + logger.debug(f"\nAutotune: Cache hit! Running best config...") + else: + config = self.configs[0] + self.best_config = config + logger.info(f"\nAutotune Best Config: {config}\n") + + full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} + if config.pre_hook is not None: + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **kwargs, + **config.kwargs, + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append( + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + ) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + return ", ".join(res) + + +def autotune( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + restore_value=None, + warmup=25, + rep=100, +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by, + warmup, + rep, + ) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/torchao/prototype/galore/kernels/matmul.py b/torchao/prototype/galore/kernels/matmul.py new file mode 100644 index 0000000000..b183f7ed66 --- /dev/null +++ b/torchao/prototype/galore/kernels/matmul.py @@ -0,0 +1,383 @@ +import torch +import triton +import triton.language as tl +from triton.ops.matmul import get_higher_dtype, init_to_zero +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from .custom_autotune import Config, autotune, heuristics + +# Allowed types for acc_type given the types of a and b. +TRITON_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + +AUTOTUNER_TOP_K = 50 + + +def set_tuner_top_k(k): + global AUTOTUNER_TOP_K + AUTOTUNER_TOP_K = k + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def get_configs_compute_bound(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +def get_autotuner( + configs=get_configs_compute_bound() + get_configs_io_bound(), + key=["M", "N", "K"], + early_config_prune=early_config_prune, + perf_model=estimate_matmul_time, + top_k=AUTOTUNER_TOP_K, +): + autotuner = autotune( + configs=configs, + key=key, + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": perf_model, + "top_k": top_k, + }, + ) + + return autotuner + + +def get_mm_heuristics(): + return heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } + ) + + +@triton.jit +def _matmul_kernel( + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + # meta params + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_M: tl.constexpr, + # epilogue + epilogue_alpha=None, + epilogue_beta=None, + epilogue_source=None, # Corresponds to C in GEMM convention of D = AB + C + # matmul kernel settings + acc_dtype: tl.constexpr = tl.float32, # + allow_tf32: tl.constexpr = True, # + fp8_fast_accum: tl.constexpr = True, # + AB_DTYPE: tl.constexpr = None, # + EPILOGUE: tl.constexpr = False, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if EPILOGUE: + if epilogue_alpha is not None: + acc = epilogue_alpha.to(acc_dtype) * acc + if epilogue_source is not None: + epilogue_src = tl.load( + epilogue_source + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) + if epilogue_beta is not None: + epilogue_src = epilogue_src.to(acc_dtype) * epilogue_beta.to(acc_dtype) + acc = acc + epilogue_src + + acc = acc.to(C.dtype.element_ty) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +_autotuner = get_autotuner() +_heuristics = get_mm_heuristics() +matmul = _autotuner(_heuristics(_matmul_kernel)) + + +def triton_mm_launcher( + a, + b, + epilogue_alpha=None, + epilogue_beta=None, + epilogue_source=None, + allow_tf32=True, + fp8_fast_accum=True, + acc_dtype=None, + output_dtype=None, + kernel=matmul, +): + + device = a.device + # handle non-contiguous inputs if necessary + # a = grad + # b = galore_proj.ortho_matrix.t() + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = [ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + # launch kernel + # print( + # f"{__file__} triton matmul args: (AB dtype {ab_dtype}) (C dtype {c.dtype}) (allow_tf32 {allow_tf32}) (fp8_fast_accum {fp8_fast_accum})" + # ) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + matmul[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + epilogue_alpha=epilogue_alpha, # + epilogue_beta=epilogue_beta, # + epilogue_source=epilogue_source, # + acc_dtype=acc_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, + AB_DTYPE=ab_dtype, + EPILOGUE=any([epilogue_alpha, epilogue_beta, epilogue_source]), + ) + return c diff --git a/torchao/prototype/galore/kernels/quant.py b/torchao/prototype/galore/kernels/quant.py new file mode 100644 index 0000000000..516b741eab --- /dev/null +++ b/torchao/prototype/galore/kernels/quant.py @@ -0,0 +1,177 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dequant_kernel( + q_idx_ptr, + absmax_ptr, + qmap_ptr, + dq_ptr, + stride_qm, + stride_qn, + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + # rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + offsets = rm[:, None] * stride_qm + rn[None, :] * stride_qn + tl.static_print(offsets) + group_offsets = offsets // GROUP_SIZE + tl.static_print("group_offsets", group_offsets) + q_idx = tl.load(q_idx_ptr + offsets) + tl.static_print(q_idx) + # NOTE: Must upcast q_idx to int32 (q_idx is tl.uint8, which does not work for pointer indexing) + q_vals = tl.load(qmap_ptr + q_idx.to(tl.int32)) + absmax = tl.load(absmax_ptr + group_offsets) + + dq = q_vals * absmax + tl.store(dq_ptr + offsets, dq) + + +def triton_dequant_blockwise( + q: torch.Tensor, qmap: torch.Tensor, absmax: torch.Tensor, group_size: int +): + M, N = q.shape + dq = torch.empty_like(q).to(absmax.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + _dequant_kernel[grid]( + q, + absmax, + qmap, + dq, + q.stride(0), + q.stride(1), + BLOCK_M=1, + BLOCK_N=group_size, + GROUP_SIZE=group_size, + ) + return dq + + +@triton.heuristics( + values={ + "USE_MASK": lambda args: args["numels"] % args["BLOCK_SIZE"] != 0, + "NUM_GROUPS": lambda args: triton.cdiv(args["numels"], args["BLOCK_SIZE"]), + } +) +@triton.jit +def _quantize_blockwise_kernel( + t_ptr, + cutoffs_ptr, + q_ptr, + absmax_ptr, + norm_ptr, + numels, + BLOCK_SIZE: tl.constexpr, + NUM_BUCKETS: tl.constexpr, + USE_MASK: tl.constexpr, + NUM_GROUPS: tl.constexpr, + RETURN_NORM: tl.constexpr = False, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = None + absmax_mask = None + if USE_MASK: + mask = offsets < numels + absmax_mask = pid < NUM_GROUPS + t = tl.load(t_ptr + offsets, mask=mask) + + absmax = tl.max(tl.abs(t), axis=0) + normalized = t / absmax + + # Load code buckets + cutoffs = tl.load(cutoffs_ptr + tl.arange(0, NUM_BUCKETS)) + q = tl.reshape(normalized, (BLOCK_SIZE, 1)) > cutoffs + + # NOTE: explicit cast is needed, addition on tl.int1 (bool) does not work as per torch / numpy + q = q.to(tl.uint8) + q = tl.sum(q, axis=1) + + tl.store(q_ptr + offsets, q, mask=mask) + # Each block processes one group_size number of elements, hence 1 absmax + tl.store(absmax_ptr + pid, absmax, mask=absmax_mask) + + if RETURN_NORM: + tl.store(norm_ptr + offsets, normalized, mask=mask) + + +# NOTE: Each block processes one group_size number of elements, hence BLOCK_SIZE = group_size +# where group_size corresponds to the groupwise quantization blocksize +def triton_quantize_blockwise( + t: torch.Tensor, code, group_size=2048, return_normalized=False +): + """ + Params: + t: torch.Tensor, tensor to quantize + code: torch.Tensor, quantization codebook for bitsandbytes, output of `bitsandbytes.functional.create_dynamic_map` + # absmax: torch.Tensor, absolute max values for each block, if None, will be calculated from the input tensor + group_size: int, groupwise quantization blocksize, default 2048, the hardcoded blocksize for bitsandbytes 8-bit optimizers + return_normalized: bool, if True, will return the normalized tensor, primarily for debugging + """ + numel = t.numel() + q = torch.empty(numel, dtype=torch.uint8, device=t.device) + normalized = torch.empty_like(t) if return_normalized else None + num_groups = numel // group_size + abs_max = torch.empty(num_groups, dtype=t.dtype, device="cuda") + # Cutoffs for quantization + # code corresponds to actual (normalized) quant codes + # Cutoffs are used to calculate which of these codes a value belongs to + # E.g., for consecutive codes C1 and C2, the corresponding cutoff is C1 + C2 / 2 + # Hence, if a value is greater is assigned C1 if it is less than all cutoffs up to this cutoff + cutoffs = (code[:-1] + code[1:]) / 2 + + # Need to make cutoffs multiple of 2 for triton reduce + MAX_CUTOFF = torch.tensor( + torch.finfo(cutoffs.dtype).max, dtype=cutoffs.dtype, device=cutoffs.device + ).reshape( + 1, + ) + cutoffs = torch.cat([cutoffs, MAX_CUTOFF], dim=-1) + assert cutoffs.numel() % 2 == 0 + + grid = lambda META: (triton.cdiv(t.numel(), META["BLOCK_SIZE"]),) + # assert t.numel() % group_size == 0 + _quantize_blockwise_kernel[grid]( + t.view(-1), + cutoffs, + q, + abs_max, + normalized.view(-1) if return_normalized else None, + numel, + NUM_BUCKETS=len(cutoffs), + BLOCK_SIZE=group_size, + RETURN_NORM=return_normalized, + ) + return ( + q.reshape(t.shape), + normalized.reshape(t.shape) if return_normalized else None, + abs_max, + ) + + +# Reference implementation +def _torch_quantize_blockwise(tensor: torch.Tensor, code, absmax=None, blocksize=2048): + # Flatten values first + + # If not absmax, need to first normalize -> reshape to (-1, blocksize) -> max over the last dim + + # Quantize by flattening A to [numels, 1] > code[:, None], sum, then reshape back to original shape + if absmax is None: + absmax = tensor.reshape(-1, blocksize).abs().max(dim=-1).values + + normalized = tensor.reshape(-1, blocksize) / absmax[:, None] + buckets = (code[:-1] + code[1:]) / 2 + q = normalized.reshape(normalized.numel(), 1) > buckets + q = q.sum(dim=1).reshape(tensor.shape) + return q.to(torch.uint8), normalized.reshape(tensor.shape), absmax diff --git a/torchao/prototype/galore/optim/__init__.py b/torchao/prototype/galore/optim/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/prototype/galore/optim/galore_torch.py b/torchao/prototype/galore/optim/galore_torch.py new file mode 100644 index 0000000000..a44a7f449f --- /dev/null +++ b/torchao/prototype/galore/optim/galore_torch.py @@ -0,0 +1,401 @@ +"""Reference implementations +See https://github.com/jiaweizzhao/GaLore/tree/master/galore_torch +""" + +# copy dependencies from transformers/optimization.py +import math +import warnings +from typing import Callable, Iterable, Tuple + +import torch +from torch import nn +from torch.optim import Optimizer + +from bitsandbytes.optim.optimizer import Optimizer2State + + +class GaLoreProjector: + def __init__( + self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std" + ): + self.rank = rank + self.verbose = verbose + self.update_proj_gap = update_proj_gap + self.scale = scale + self.ortho_matrix = None + self.proj_type = proj_type + + def project(self, full_rank_grad, iter): + + if self.proj_type == "std": + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right" + ) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left" + ) + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == "reverse_std": + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left" + ) + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + else: + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right" + ) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == "right": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="right" + ) + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + elif self.proj_type == "left": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="left" + ) + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + elif self.proj_type == "full": + if self.ortho_matrix is None or iter % self.update_proj_gap == 0: + self.ortho_matrix = self.get_orthogonal_matrix( + full_rank_grad, self.rank, type="full" + ) + low_rank_grad = ( + torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) + @ self.ortho_matrix[1].t() + ) + + return low_rank_grad + + def project_back(self, low_rank_grad): + + if self.proj_type == "std": + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + else: + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == "reverse_std": + if ( + low_rank_grad.shape[0] <= low_rank_grad.shape[1] + ): # note this is different from std + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + else: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == "right": + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + elif self.proj_type == "left": + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + elif self.proj_type == "full": + full_rank_grad = ( + torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1] + ) + + return full_rank_grad * self.scale + + # svd decomposition + def get_orthogonal_matrix(self, weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + + # make the smaller matrix always to be orthogonal matrix + if type == "right": + # A = U[:, :rank] @ torch.diag(s[:rank]) + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type == "left": + A = U[:, :rank] + # B = torch.diag(s[:rank]) @ Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type == "full": + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError("type should be left, right or full") + + +class AdamW(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + no_deprecation_warning: bool = False, + ): + if not no_deprecation_warning: + warnings.warn( + "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch" + " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this" + " warning", + FutureWarning, + ) + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # GaLore Projection + if "rank" in group: + if "projector" not in state: + state["projector"] = GaLoreProjector( + group["rank"], + update_proj_gap=group["update_proj_gap"], + scale=group["scale"], + proj_type=group["proj_type"], + ) + + grad = state["projector"].project(grad, state["step"]) + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + # compute norm gradient + norm_grad = exp_avg / denom + + # GaLore Projection Back + if "rank" in group: + norm_grad = state["projector"].project_back(norm_grad) + + p.add_(norm_grad, alpha=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + return loss + + +class AdamW8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + overflows = [] + + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + + # if self.is_paged: self.page_mng.prefetch_all() + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p.grad is None: + continue + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # GaLore Projection + if "rank" in group: + if "projector" not in state: + state["projector"] = GaLoreProjector( + group["rank"], + update_proj_gap=group["update_proj_gap"], + scale=group["scale"], + proj_type=group["proj_type"], + ) + + if "weight_decay" in group and group["weight_decay"] > 0: + # ensure that the weight decay is not applied to the norm grad + group["weight_decay_saved"] = group["weight_decay"] + group["weight_decay"] = 0 + + grad = state["projector"].project(p.grad, state["step"]) + + # suboptimal implementation + p.saved_data = p.data.clone() + p.data = grad.clone().to(p.data.dtype).to(p.data.device) + p.data.zero_() + p.grad = grad + + if "state1" not in state: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # GaLore Projection Back + if "rank" in group: + p.data = p.saved_data.add_(state["projector"].project_back(p.data)) + + # apply weight decay + if "weight_decay_saved" in group: + p.data.add_( + p.data, alpha=-group["lr"] * group["weight_decay_saved"] + ) + group["weight_decay"] = group["weight_decay_saved"] + del group["weight_decay_saved"] + + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + + return loss diff --git a/torchao/prototype/galore/utils.py b/torchao/prototype/galore/utils.py new file mode 100644 index 0000000000..41242cbd83 --- /dev/null +++ b/torchao/prototype/galore/utils.py @@ -0,0 +1,111 @@ +import torch + + +def get_orthogonal_matrix(weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + + # make the smaller matrix always to be orthogonal matrix + if type == "right": + # A = U[:, :rank] @ torch.diag(s[:rank]) + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type == "left": + A = U[:, :rank] + # B = torch.diag(s[:rank]) @ Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type == "full": + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError("type should be left, right or full") + + +class TestGaLoreProjector: + def __init__( + self, + rank=128, + scale=1.0, + proj_type="std", + ): + self.rank = rank + self.scale = scale + + if proj_type != "std": + raise ("Only std projection is supported") + + self.proj_type = proj_type + + self.ortho_matrix = None + + def update_orthogonal_matrix(self, full_rank_grad): + + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + self.ortho_matrix = get_orthogonal_matrix( + full_rank_grad, self.rank, type="right" + ) + else: + self.ortho_matrix = get_orthogonal_matrix( + full_rank_grad, self.rank, type="left" + ) + + def project(self, full_rank_grad): + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + + return low_rank_grad + + def project_back(self, low_rank_grad): + + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + else: + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + + return full_rank_grad * self.scale + + +def make_copy(*args): + return [t.detach().clone() for t in args] + + +# def adam_step( +# exp_avg, +# exp_avg2, +# grad, +# galore_proj, +# params, +# step_size=1e-4, +# beta1=BETA1, +# beta2=BETA2, +# eps=EPS, +# ): +# grad = galore_proj.project(grad) +# exp_avg = beta1 * exp_avg + (1 - beta1) * grad +# exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad) +# denom = exp_avg2.sqrt() + eps +# norm_grad = exp_avg / denom +# norm_grad = galore_proj.project_back(norm_grad) +# # params = params - step_size * norm_grad +# return exp_avg, exp_avg2, denom, norm_grad