
GaLore and fused kernel prototypes #95

Merged: 28 commits, Apr 16, 2024

Conversation

jeromeku
Collaborator

Prototype Kernels and Utils

Currently:

  • GaLore
    • Initial implementation of fused kernels for GaLore memory efficient training.

TODO:

  • triton
    • Composable triton kernels for quantized training and inference
  • cutlass
    • Pythonic utils for defining custom cutlass kernels and other quant ops

@msaroufim

@facebook-github-bot

Hi @jeromeku!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Member

@msaroufim msaroufim left a comment

A few nits on packaging and minor questions on the kernels; I'll do another pass to review the kernels properly. Let's ensure the tests run in CI, and if a T4 machine is not enough then we need to get a beefier GPU ASAP

#### TODO

- Common quant / dequant kernels for popular quantization frameworks
- [ ] GPTQ
Member

We already have code for GPTQ https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/GPTQ.py
bitsandbytes is also interesting, granted we also have kernels for QLoRA that are codegened here as well
HQQ would be nice to add

Collaborator Author

@jeromeku jeromeku Mar 28, 2024

Ok -- will work on kernels for bitsandbytes AdamW8bit then onto HQQ.

HQQ's 4 and 8-bit should be easy to adapt; 1, 2, and 3-bit might require some additional preprocessing to optimize.

Their current CUDA dequant implementations can definitely be optimized. Will work on re-implementing in CUDA and triton.

Also looking into how to decompose the Marlin kernel design into reusable building blocks for optimized quant inference.

@@ -0,0 +1,20 @@
from setuptools import setup
Member

I think we can rework the existing setup.py to package your kernels into the core package - happy to credit you in the files directly and/or the README

Collaborator Author

np -- whatever is easiest

@@ -0,0 +1,3 @@
# Cutlass Quant

### Pythonic tools for defining `cutlass` kernels and quantization ops
Member

cc @andrewor14 who has also been thinking about CUTLASS in the context of #86

Collaborator Author

Cutlass is very neat.

  • Cutlass 3.x and the CuTe framework that it introduces have many useful primitives and patterns for defining bespoke kernels of relevance (mixed type GEMM, MoE, etc.), though they are targeted primarily at sm90+ architectures.
  • The 2.x api has limited support for sub-byte mixed type quant kernels (without preprocessing weights to custom format -- I believe pytorch already has this integrated under torch.quantization._quantized_conversions).

Currently working on using Cutlass 3.x / CuTe to adapt / improve pre-Hopper kernels useful for quant ops. Would love to also test on Hopper but unfortunately don't have access to H100.


- [ ] Implement `FusedGaLoreOptimizer`
- [ ] `Cutlass` - given fixed GEMM shape, experiment with `Cutlass` GEMMs (`split-k`, `stream-k`, fast `tensorops`). Interestingly, profiling `torch.matmul` for down projection shows that `cuBlas` dispatches to a `Cutlass` kernel of shape `128x128x16`.
- [ ] Repeat with `AdamW8bit` - pure `triton` implementation of `bitsandbytes` `AdamW8bit`
Member

yes this would be very helpful

#### Installation

```
pip install --editable .
Member

mentioned this already but we can package prototype under its own namespace in ao as opposed to its own package

3. normalized `grad` is projected to full rank --> additional matmul
4. `params` are updated with the normalized full rank grad

#### Implementation
Member

appreciated this nice simple explanation
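For readers following the thread, a minimal plain-PyTorch sketch of the update described in the quoted steps above (names are illustrative, the left/right projection choice is simplified, and bias correction is omitted; this is not the fused implementation in the PR):

```python
import torch

def galore_adam_update(param, grad, exp_avg, exp_avg_sq, proj,
                       lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    # 1. project the full-rank grad into the low-rank space (one matmul)
    low_rank_grad = grad @ proj                      # (M, N) @ (N, r) -> (M, r)
    # 2. run the Adam moment updates on the low-rank grad
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    norm_grad = exp_avg / (exp_avg_sq.sqrt() + eps)
    # 3. project the normalized grad back to full rank (additional matmul)
    full_rank_update = norm_grad @ proj.T            # (M, r) @ (r, N) -> (M, N)
    # 4. update params with the normalized full-rank grad
    param.add_(full_rank_update, alpha=-lr)
    return param
```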

print(f"Finished benchmark, results saved to {save_path}")


if __name__ == "__main__":
Member

let's use the unittest format similarly to other tests in the repo, lmk if you need help here
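A minimal sketch of what that unittest-style test could look like (the kernel call is a placeholder; the real tests in this PR compare the fused Triton kernels against eager references):

```python
import unittest
import torch

class TestFusedGaLoreKernels(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_matches_eager_reference(self):
        a = torch.randn(256, 128, device="cuda")
        b = torch.randn(128, 64, device="cuda")
        ref = a @ b            # eager reference
        out = a @ b            # placeholder: call the fused kernel launcher here
        torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)

if __name__ == "__main__":
    unittest.main()
```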

logger = logging.getLogger(__file__)


class Autotuner(KernelInterface):
Member

@cpuhrsch we have a generic kernel auto tuner now right?

Collaborator Author

I tweaked the Autotuner to print additional info such as pruned configs, best config, cache hit, etc.


#### Next Steps

- [ ] Implement `FusedGaLoreOptimizer`
Member

Nit: Generally for next steps I'd rather they get mentioned in a github issue vs docs

a = a.to(AB_DTYPE)
b = b.to(AB_DTYPE)
if fp8_fast_accum:
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
Member

maybe I'm missing something dumb but how is this an fp8 accum?

Collaborator Author

This is directly from the triton matmul implementation.

In the launcher, the constexpr AB_DTYPE gets set to None if a and b are fp8 type. I'm guessing triton's underlying tl.dot implementation is overloaded to handle this case, which is probably why the signature differs slightly from the non-fp8 case: tl.dot(a, b, acc, ...) vs. tl.dot(a, b, ...) where acc is passed as an additional arg in the former. Need to dig a bit further to confirm.
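To make the two accumulation forms concrete, here is a stripped-down single-tile Triton GEMM sketch (not the kernel from this PR; it assumes A is (BLOCK_M, K) and B is (K, BLOCK_N), both row-major, K a multiple of BLOCK_K, and a recent Triton version where tl.dot accepts an accumulator argument):

```python
import triton
import triton.language as tl

@triton.jit
def tile_matmul_kernel(a_ptr, b_ptr, c_ptr, K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
        # fp8 path in the upstream template: the accumulator is threaded
        # through tl.dot as an extra argument
        acc = tl.dot(a, b, acc)
        # non-fp8 path accumulates outside tl.dot instead:
        # acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```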


# make the smaller matrix always to be orthogonal matrix
if type == "right":
A = U[:, :rank] @ torch.diag(s[:rank])
Contributor

This is unused; there's a PR upstream: jiaweizzhao/GaLore#18
Same comment on line 28 as well.

@@ -0,0 +1,65 @@
import logging
Member

note to self: both tests pass but we need to decide whether we want benchmarks as tests or just accuracy checks

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 29, 2024
@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@msaroufim
Member

msaroufim commented Mar 31, 2024

Hi @jeromeku sorry for the delay, it took me a while to figure out the right way to review this but here's what I'm thinking we need to do to merge this. And feel free to reach out to me directly on Discord if you want to pair program any of this

  1. Delete the cutlass and triton folders, start a new github issue with your future plans
  2. Get rid of your setup.py make sure that all your kernels are accessible via torchao.prototype.galore
  3. For now keep your existing autotuner, I'll work with @cpuhrsch in a separate PR on making this more generic
  4. Right now you have both tests and benchmarks underneath your tests folder, move them to https://github.com/pytorch-labs/ao/tree/main/test and https://github.com/pytorch-labs/ao/tree/main/benchmarks respectively - these will act like good unit tests to make sure the kernels are accurate and fast
  5. We're missing a memory reduction test; the reason people use GaLore is to save memory, so some test that validates a significant memory reduction would make this easier to merge (see the sketch after this list)
  6. We're missing a regression test, as well as some tutorial for how people are expected to actually run the GaLore algorithm using your kernels. These can go in a README or be a standalone test, up to you - this will be helpful as we work with other fine-tuning libraries like axolotl or torchtune to get them to integrate your code
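For point 5 above, a rough sketch of the kind of peak-memory check that could run in CI (the step closures are hypothetical placeholders; the threshold is illustrative):

```python
import torch

def peak_cuda_memory_mb(step_fn) -> float:
    """Run step_fn once and return the peak CUDA memory allocated, in MB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e6

def assert_memory_reduction(baseline_step, galore_step, min_reduction=0.05):
    # baseline_step / galore_step are placeholder closures that build a model plus
    # optimizer and run a single forward/backward/optimizer step.
    baseline = peak_cuda_memory_mb(baseline_step)
    galore = peak_cuda_memory_mb(galore_step)
    assert galore <= (1 - min_reduction) * baseline, (
        f"expected at least a {min_reduction:.0%} peak-memory reduction, "
        f"got {galore:.1f} MB vs {baseline:.1f} MB"
    )
```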

@janeyx99
Contributor

janeyx99 commented Apr 1, 2024

Hi @jeromeku, glad to see a GaLore folder popping up in torchao :D I'm brainstorming ideas about supporting GaLore better from the core PyTorch side (e.g., making it work with distributed + checkpointing, using torch.compile instead, seeing if there's a way to generalize the GaLore technique across all optimizers without needing to create a separate optimizer for each). I see your PR introduces performant kernels, which is awesome, and I'd like to know if you want to collaborate on implementing something together. I notice your next steps are about adding more optimizers and doing analysis with torch.compile so there may be some opportunity to jam together there--I'm happy to discuss more on Slack or discord if this is something you're interested in.

BTW, review-wise, I would push for @msaroufim points 6 + 5 especially!

@jeromeku
Collaborator Author

jeromeku commented Apr 1, 2024

Good to meet you @janeyx99! Great CUDA Mode lecture on optimizers btw :)

I'm brainstorming ideas about supporting GaLore better from the core PyTorch side (e.g., making it work with distributed + checkpointing, using torch.compile instead, seeing if there's a way to generalize the GaLore technique across all optimizers without needing to create a separate optimizer for each).

Yes - agree that it makes sense to generalize GaLore across all optimizers with minimal duplication. Part of the reason I focused first on creating fused kernels that implement various parts of the GaLore optimizer step is to enable these modular pieces to be plugged into any optimizer type and to be able to be composed with torch.compile.

I have most of the pieces of a pure triton / torch-native AdamW8bit implementation in place (no dependency on bitsandbytes) -- I think we should be able to use lessons learned from the initial Adam impl (this PR) and AdamW8bit as a reasonable starting point for generalizing across different parts of the torch stack.

Working on points 5 & 6 per review though need to take care of some other stuff first.

cc @msaroufim

@msaroufim msaroufim requested a review from lessw2020 April 2, 2024 23:27
@jeromeku
Collaborator Author

jeromeku commented Apr 4, 2024

@msaroufim
cc @janeyx99

Updates:

  • Refactored prototype to now fall under the torchao namespace
  • Added README under torchao/prototype to outline the general motivation for things under the prototype umbrella
  • Removed unnecessary folders (e.g., cutlass)
  • Updated docs -- see torchao/prototype/galore/docs -- which contains implementation notes for galore_adam and galore_adam8bit (ongoing)
  • Added triton implementations of the bitsandbytes 8-bit blockwise quant / dequant ops -- these are necessary components for implementing GaLore AdamW8bit optimizers (a plain-PyTorch sketch of the blockwise scheme follows this list); also added tests (see ao/test/quantization)
  • Updated torchao dev-requirements.txt and setup.py to include additional libs necessary for testing triton quant against bitsandbytes and triton benchmarking requirements.
  • Moved fused kernel tests under ao/test/kernels and separate benchmark under ao/benchmarks
  • Checked that all components / tests pass after the refactor.
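For context, the blockwise 8-bit scheme those quant / dequant kernels implement looks roughly like this in plain PyTorch (a simplified sketch with a linear code; bitsandbytes uses a dynamic non-linear codebook, so this is not numerically equivalent):

```python
import torch

def quantize_blockwise(x: torch.Tensor, block_size: int = 2048):
    # Split into blocks, scale each block by its absmax, round to int8.
    flat = x.reshape(-1, block_size)
    absmax = flat.abs().amax(dim=1, keepdim=True).clamp_(min=1e-8)
    q = torch.round(flat / absmax * 127).to(torch.int8)
    return q, absmax

def dequantize_blockwise(q: torch.Tensor, absmax: torch.Tensor, shape):
    return (q.to(torch.float32) / 127 * absmax).reshape(shape)

x = torch.randn(4096, 4096)
q, absmax = quantize_blockwise(x)
x_hat = dequantize_blockwise(q, absmax, x.shape)
print((x - x_hat).abs().max())  # per-element error is bounded by the block scale
```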

Next steps:

  • Create a tutorial showing how to implement GaLore AdamW8bit using the kernels implemented under prototype/galore
  • Microbenchmark triton quant / dequant against bitsandbytes CUDA impl
  • E2E benchmark of full AdamW8bit update step using kernels from prototype/galore
  • Memory profiling to validate GaLore claims of memory efficiency
  • nsys analysis of fused kernel optimizers + ncu deep dive on kernel performance
  • HQQ, QuaRot, and other promising (and proven) techniques for efficient training / inference; see prototype/README.md for longer-term aims.

@zou3519 zou3519 self-requested a review April 5, 2024 17:22
@msaroufim
Member

Hey @jeromeku I do wanna make sure we merge something of yours, the roadmap is ambitious and the right one but I'd suggest breaking it apart this way

For this PR

  • Memory profiling to validate GaLore claims of memory efficiency
  • Create a test showing how to implement GaLore AdamW8bit using the kernels implemented under prototype/galore

Next PR

  • Create a tutorial showing how to implement GaLore AdamW8bit using the kernels implemented under prototype/galore
  • Microbenchmark triton quant / dequant against bitsandbytes CUDA impl
  • E2E benchmark of full AdamW8bit update step using kernels from prototype/galore
  • nsys analysis of fused kernel optimizers + ncu deep dive on kernel performance
  • Hqq, quarot, and other promising (and proven) techniques for efficient training / inference; see prototype/README.md for longer-term aims.

@jeromeku
Collaborator Author

@msaroufim

  • Added a GaLore memory testing script test/galore/test_memory_usage.py which profiles memory usage (using torch.profiler) across optimizers for various Llama model sizes (no weights need to be downloaded; the configs are tuned to the various model sizes). A minimal sketch of this kind of profiling follows this list.
  • Added jupyter notebook that analyzes the output of the torch.profiler to provide summary stats on memory usage.
  • Note that this only includes reference (torch-only) implementations of the GaLore optimizers per original repo which are admittedly not highly optimized.
  • Will implement and test fused implementations of AdamW and AdamW8bit next.
  • Included instructions and some preliminary analysis -- see test/galore/README.md -- take a look and let me know what you think.
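A minimal sketch of the torch.profiler-based memory profiling described in the first bullet (generic model and optimizer stand-ins rather than the actual Llama configs in the script):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(4096, 4096, device="cuda")   # stand-in for the Llama configs
optim = torch.optim.AdamW(model.parameters())

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True, record_shapes=True) as prof:
    for _ in range(3):
        out = model(torch.randn(8, 4096, device="cuda"))
        out.sum().backward()
        optim.step()
        optim.zero_grad()

print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))
```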

| median | 516.3 | 403.6 | 0.0 | 0.0 | 75.7 | 272.8 | 0.0 | 18.1 |
| max | 595.0 | 403.6 | 0.3 | 6.6 | 1,336.0 | 395.3 | 312.9 | 173.6 |

- The `optimizer state` is indeed smaller for the `GaLoreAdamW` optimizer. Interestingly, the `Parameter` sizes balloon in the `GaLore` optimizer, likely due to extra data copies. Admittedly, the implementation is only a reference (per the original repo) and leaves much room for optimization.
Member

@msaroufim msaroufim Apr 11, 2024

thank you for noting this - how are you thinking about next steps here? One hope I had for merging this is for other repos, like https://github.com/pytorch/torchtune, to take a dependency on it

Also maybe @janeyx99 has some idea of what might be going wrong

"transformers_version": "4.28.1",
"use_cache": true,
"vocab_size": 32000
}
Member

just a heads up, we can't merge binary files into the git repo - this includes the html, pt files, and most of the json files that were checked in

@@ -0,0 +1,70 @@
import pandas as pd
Member

nit: this is a nice utility - we could make it standalone, outside of the galore context



@contextmanager
def nsys_profiler():
Member

I thought this would use the nvidia nsys profiler

@@ -0,0 +1,72 @@
from functools import partial
Member

if possible I'd rather we not have notebooks merged in, just turn the code into a python file and let's check that in

@msaroufim
Member

msaroufim commented Apr 11, 2024

PR is looking good, I think we're close to merging this - main things that need to be fixed

  • CI failure - you might need some skip markers so the GPU tests are skipped on CPU-only CI
  • Get rid of all the binary files you checked in
  • The kernel looks good, if the numerics and memory checks are running in CI this becomes quite safe to land

@jeromeku do you mind if I make changes to your PR directly? It might help make the review process smoother

@jeromeku
Collaborator Author

jeromeku commented Apr 11, 2024

PR is looking good, I think we're close to merging this - main things that need to be fixed

  • CI failure - you might need some skip markers so the GPU tests are skipped on CPU-only CI
  • Get rid of all the binary files you checked in
  • The kernel looks good, if the numerics and memory checks are running in CI this becomes quite safe to land

@jeromeku do you mind if I make changes to your PR directly? It might help make the review process smoother

@msaroufim No worries - feel free to make whatever changes necessary.

@jeromeku
Collaborator Author

jeromeku commented Apr 11, 2024

@msaroufim

Made the following changes:

  • removed all binary files (model configs, sample data, profiler outputs, etc.)
  • reworked the profiler to work without these binaries
  • added instructions to README for how to use memory_analysis_utils.py to generate the analyses in lieu of the jupyter notebook
  • added preliminary memory stats for torch.optim.AdamW, bitsandbytes AdamW8bit, GaLoreAdamW, and GaLoreAdamW8bit (see updated README in test/galore). Note that the GaLore optimizers are the reference (pure-torch) implementations from the original repo. Still to add the fused versions.
  • marked tests that require a GPU to be skipped when one is not available (see the sketch after this list)
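The GPU skip marking mentioned in the last bullet typically looks something like the following (a generic pytest sketch, not necessarily the exact markers used in this PR):

```python
import pytest
import torch

# Skip the whole module when CUDA is unavailable (e.g. on CPU-only CI runners).
if not torch.cuda.is_available():
    pytest.skip("CUDA not available, skipping GaLore kernel tests", allow_module_level=True)

# Or mark individual tests:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_fused_adam_step_runs():
    x = torch.randn(16, 16, device="cuda")
    assert x.is_cuda
```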

Additional conditions for skipping tests to avoid CI failure.
Rename files as they are not actual tests but profiling tools to avoid
triggering CI runs.
@msaroufim
Member

msaroufim commented Apr 16, 2024

@jeromeku

TL;DR

  • Let's expose a Galore optimizer like torchao.prototype.GaLoreOptimizer and make it clear that's the main entry point for this work. GaloreProjector is also available for users that want more control
  • Move remaining benchmarks under test folder to benchmark folder
  • It'll be challenging for us to accept bitsandbytes as an additional dependency, so we'll need to make it an optional dependency - as in, the code needs to work in CI with or without bitsandbytes installed

Benchmark notes

Confirming the benchmark script works; on an H100 on my end I get:

Adam Kernel Comparison Grad shape: 4096x4096, dtype: torch.float32, allow_tf32: False
Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  0.222944  0.224960  0.225728  0.196896
1   64.0  0.255264  0.386480  0.352864  0.238752
2  128.0  0.344352  0.363776  0.464224  0.330336
3  256.0  0.510464  0.416960  0.856896  0.515584
4  512.0  0.856352  0.786656  1.169344  0.899936

So indeed things are fastest for the hybrid approach; fused seems slower than eager, and compile is fast but not fastest, probably because I didn't enable tensor cores

I made a minor change to the way the flag is set there, since this is the recommended approach:

- torch.backends.cuda.matmul.allow_tf32 = allow_tf32
+ if allow_tf32:
+     torch.set_float32_matmul_precision('high')

On nightlies this gives some cuda graph errors we can fix at a later time - not urgent. But it does highlight the importance of running these benchmark scripts in CI regularly. I'll make a PR myself to run everything in benchmarks/ nightly

SingleProcess AUTOTUNE benchmarking takes 3.9803 seconds and 0.0000 seconds precompiling
skipping cudagraphs due to mutation on input. Found from : 
   File "/home/marksaroufim/ao/benchmarks/fused_benchmark_utils.py", line 60, in _ref_op
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1))

If we allow tf32 I instead get torch.compile being universally faster - how should I read this? That it's best to express GaLore in python code and run torch.compile? I'm fine if the answer ends up being that we need the fused kernels so we can also be fast in eager

Adam Kernel Comparison Grad shape: 4096x4096, dtype: torch.float32, allow_tf32: True
Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  0.190016  0.240576  0.216256  0.176992
1   64.0  0.196032  0.235648  0.208256  0.180480
2  128.0  0.217760  0.238624  0.213696  0.200032
3  256.0  0.220176  0.250784  0.240864  0.200224
4  512.0  0.260224  0.250896  0.249728  0.236128
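For context on how such eager-vs-compiled numbers are usually produced, here is a generic timing sketch using triton.testing.do_bench (not the PR's benchmark harness; the step function is a trivial stand-in):

```python
import torch
import triton

def bench_ms(fn, *args):
    # triton.testing.do_bench returns a runtime estimate in milliseconds
    return triton.testing.do_bench(lambda: fn(*args))

def eager_step(p, g):
    return p.add_(g, alpha=-1e-3)

compiled_step = torch.compile(eager_step)

p = torch.randn(4096, 4096, device="cuda")
g = torch.randn_like(p)
compiled_step(p, g)  # warm up so compilation time is excluded from the measurement
print("eager   :", bench_ms(eager_step, p, g))
print("compiled:", bench_ms(compiled_step, p, g))
```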

API notes

I did confirm first that import torchao.prototype.galore works and that this is the full public API

dir(torchao.prototype.galore)
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'adam_downproj_fused', 'adam_step', 'custom_autotune', 'fused_adam_mm_launcher', 'kernels', 'matmul', 'quant', 'triton_adam_launcher', 'triton_dequant_blockwise', 'triton_mm_launcher', 'triton_quantize_blockwise']

Which was making me wonder how someone would use these kernels exactly; the answer was in fused_benchmark_utils.py

Specifically here - the function is named make_data but it's closer to make_model. We could figure out how to make this work on a real model, but for now that's fine

def make_data(M, N, rank, dtype):
    grad = torch.randn(M, N, device="cuda", dtype=dtype)
    params = torch.randn(M, N, device="cuda", dtype=dtype)

    galore_proj = GaLoreProjector(rank=rank)
    galore_proj.update_orthogonal_matrix(grad)

    if M >= N:
        exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype)
    else:
        exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype)
    exp_avg2 = exp_avg**2

    return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params
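For instance, continuing from the snippet above (dimensions and rank chosen arbitrarily):

```python
exp_avg, exp_avg2, grad, ortho_matrix, params = make_data(4096, 4096, rank=128, dtype=torch.float32)
print(grad.shape, ortho_matrix.shape, exp_avg.shape)  # full-rank grad, projection matrix, low-rank state
```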

Regardless, as I kept going I noticed:

Tests

test/galore mostly has benchmarks, so let's move those to benchmarks/. By tests I mean some unit tests that could help us verify correctness of the code - for example, a single run from the benchmark script to make sure the code is not broken would go a long way

I was however able to confirm that the memory reductions are there

Galore Adam W
Max Memory Allocated: 2,332.1 MB
Max Memory Reserved: 2,654.0 MB

AdamW 
Max Memory Allocated: 2,564.4 MB
Max Memory Reserved: 2,814.0 MB

I also feel like it'll be challenging for us to accept bitsandbytes as a core dependency so we can relegate a lot of the 8 bit work to a tutorial for now as opposed to core functionality in the library
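A common pattern for keeping such a dependency optional in tests (a generic sketch, not necessarily how it ended up in the repo):

```python
import importlib.util
import pytest

HAS_BNB = importlib.util.find_spec("bitsandbytes") is not None

@pytest.mark.skipif(not HAS_BNB, reason="bitsandbytes not installed")
def test_quant_matches_bitsandbytes():
    import bitsandbytes as bnb  # imported lazily so CI without bnb can still collect this file
    ...  # compare the triton quant / dequant output against bnb.functional here
```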

I also see some code duplication for example make_data() exists both in tests and benchmarks

test_fused_kernels.py and test_galore_downproj are very useful and easy to read thank you!

I could verify on my end that the tests work and they are picked up in CI

(ao) [[email protected] ~/ao/test/kernel (galore_fused)]$ pytest .

test_autotuner.py ....                                                                                                                                     [ 57%]
test_fused_kernels.py .                                                                                                                                    [ 71%]
test_galore_downproj.py ..                                                                                                                                 [100%]

Docs

I appreciate the roadmap discussion here torchao/prototype/README.md but for now let's keep it in a GitHub issue - it's easier to discuss longer-term work there

Thank you for clarifying the differences between the fused and hybrid implementations

The GaLore 8-bit Adam is a cherry on top, but as previously mentioned I don't think we're ready to take on a new dependency like bitsandbytes, so you can keep this as a tutorial - I would suggest removing bitsandbytes-specific code from the PR

I will say the most important function of the docs will be to communicate how you want people to use this work. So far it seems like make_data(), but then when I read your optimizer code it seems like you want people to inherit your custom GaLore optimizer? But then I also read test/galore/profile_memory_usage.py where you took a real-world model and made that work. So just be explicit in the docs about how you expect people to consume your work

Kernels

You're much better than me at Triton lol so will let you decide how you want to surface things or change things here for better perf. As long as the kernels are correct I think it's fine to merge the code as is because we can always go through rounds of profiling and improvements.

There's some minor nits I have here mostly around removing commented code

I also think there's some dead code like class TestGaLoreProjector in utils.py which is maybe intended as a unit test or was maybe duplicate code for the optimizer?

Optimizer

This code was a joy to read, felt like a tutorial. Makes me wonder why you don't expose a public API for a GaLore optimizer as the main way for people to consume your work. EDIT: You do in the memory profile scripts

Member

@msaroufim msaroufim left a comment

Amazing work, there's a bunch of cleanup I'll follow up on later but nothing that should block landing this

@msaroufim msaroufim merged commit b0a649e into pytorch:main Apr 16, 2024
7 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* initial commit

* add placeholders for cutlass and triton

* update readme

* fix versions

* minor text edits

* clean up

* add triton bnb quant kernel and test

* add notes on triton quant kernel

* refactor code structure

* add galore downproj test

* refactor test utils

* add fused kernel tests

* add fused benchmark

* add dequant kernel

* update docs

* add galore memory test

* add adamw8bit

* fix README

* clean up binaries

* remove notebook, add instructions to README

* remove sample data

* Update galore tests

Skip tests if no GPU

* rename galore docs

* More test edits

Additional conditions for skipping tests to avoid CI failure.
Rename files as they are not actual tests but profiling tools to avoid
triggering CI runs.

* decrease fused matmul parametrizations

* remove long-running tests

* remove tf32 test for now

---------

Co-authored-by: Mark Saroufim <[email protected]>
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
5 participants