
Autotuner for int mm Triton kernels #41

Merged: 38 commits, Mar 21, 2024
Conversation

@cpuhrsch (Contributor) commented Mar 3, 2024

[ 8:58PM (nightly20240311py310) /scratch/cpuhrsch/dev/ao/benchmarks - git:intmmbenchmarks1]$ TORCHAO_AUTOTUNER_ENABLE=0 python intmm.py sam_shapes.csv
fn,m,k,n,fp_time,int_mm_time,ratio
<function run_int_mm_benchmark at 0x7f42252f29e0>,32768,3072,768,0.7926271820068359,3.6111566162109376,0.219493992159917
<function run_int_mm_benchmark at 0x7f42252f29e0>,32768,768,2304,0.4677939224243164,2.7787774658203124,0.16834522669710103
<function run_int_mm_benchmark at 0x7f42252f29e0>,32768,768,3072,0.6118912124633789,3.285841979980469,0.1862205231387956
<function run_int_mm_benchmark at 0x7f42252f29e0>,32768,768,768,0.17190912246704101,1.7368576049804687,0.09897709632274326
<function run_int_mm_benchmark at 0x7f42252f29e0>,39200,768,2304,0.579788818359375,3.1330712890625,0.1850544609002062
<function run_int_mm_benchmark at 0x7f42252f29e0>,39200,768,768,0.20057088851928712,1.8391961669921875,0.10905355944020902
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,32768,3072,768,0.7390310668945312,3.5631512451171874,0.20740940141322164
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,32768,768,2304,0.8094003295898438,3.315210266113281,0.24414750939426128
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,32768,768,3072,1.0603622436523437,4.030166931152344,0.2631062836271039
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,32768,768,768,0.28380159378051756,1.9123712158203126,0.14840298339189378
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,39200,768,2304,0.9600614166259765,3.763394470214844,0.2551051781109645
<function run_int_scaled_mm_benchmark at 0x7f430a506b00>,39200,768,768,0.3342233657836914,2.057533416748047,0.16243885181312626
[ 8:58PM (nightly20240311py310) /scratch/cpuhrsch/dev/ao/benchmarks - git:intmmbenchmarks1]$ TORCHAO_AUTOTUNER_ENABLE=1 python intmm.py sam_shapes.csv
fn,m,k,n,fp_time,int_mm_time,ratio
INFO:root:Trying to load configs for NVIDIA A100-SXM4-40GB from /scratch/cpuhrsch/dev/ao/torchao/kernel/configs/data_a100.pkl
INFO:root:Loading best configs from file /scratch/cpuhrsch/dev/ao/torchao/kernel/configs/data_a100.pkl
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,32768,3072,768,0.7970508575439453,0.39620609283447267,2.0117077247394524
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,32768,768,2304,0.5419622421264648,0.41020416259765624,1.32120122500571
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,32768,768,3072,0.7071027374267578,0.5336064147949219,1.3251391246833395
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,32768,768,768,0.18099199295043944,0.13851648330688476,1.3066458852369918
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,39200,768,2304,0.5925785446166992,0.45351936340332033,1.306622368160523
<function run_int_mm_benchmark at 0x7fbfe839a9e0>,39200,768,768,0.2048409652709961,0.1618739128112793,1.26543531143162
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,32768,3072,768,0.7551487731933594,0.36495361328125,2.069163712078135
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,32768,768,2304,0.8089702606201172,0.4003123092651367,2.0208478278001594
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,32768,768,3072,1.0633113861083985,0.530063362121582,2.006008077699405
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,32768,768,768,0.28652544021606446,0.17480703353881835,1.6390956039674311
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,39200,768,2304,0.9518489837646484,0.4832358551025391,1.9697399804132356
<function run_int_scaled_mm_benchmark at 0x7fc0cd6bab00>,39200,768,768,0.33665023803710936,0.17441791534423828,1.9301356593597783
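For reference, the `ratio` column in both tables is simply the floating-point baseline time divided by the int8 kernel time, so values above 1.0 mean the int8 path wins. A minimal sketch (hypothetical helper name, not code from this PR):

```python
def speedup(fp_time_ms: float, int_mm_time_ms: float) -> float:
    # The ratio column above: baseline matmul time over int8 matmul time.
    # > 1.0 means the int8 kernel is faster than the floating-point baseline.
    return fp_time_ms / int_mm_time_ms
```

Applied to the first autotuned row (32768,3072,768), this reproduces the reported ratio of about 2.01.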

@facebook-github-bot added the `CLA Signed` label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 3, 2024
@cpuhrsch cpuhrsch marked this pull request as ready for review March 14, 2024 20:47
@cpuhrsch cpuhrsch requested a review from HDCharles March 14, 2024 20:53
@facebook-github-bot:
@cpuhrsch has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cpuhrsch cpuhrsch requested a review from msaroufim March 14, 2024 20:55
@msaroufim (Member):

About to head into a meeting, but I'll give this a proper read. Any chance we could add a test?


@msaroufim (Member) left a comment:

some nits


Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default because it is still experimental and can take a long time to run.

Searching for a new config can take a long time, and we'll save the updated data in `data.pkl`. If you'd like to contribute updated configs for your hardware or shapes, please open a pull request.
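As a sketch of how such an environment-variable gate typically looks (hypothetical helper; the PR's actual flag parsing may differ):

```python
import os

def autotuner_enabled() -> bool:
    # Hypothetical sketch: treat any value other than "0" or empty as enabled,
    # mirroring "set this to a nonzero value" from the docs quoted above.
    return os.getenv("TORCHAO_AUTOTUNER_ENABLE", "0") not in ("0", "")
```

This matches the benchmark invocations above, where `TORCHAO_AUTOTUNER_ENABLE=0` runs the baseline and `TORCHAO_AUTOTUNER_ENABLE=1` loads the tuned configs.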
Member:

Presumably people won't contribute the pickle file since it's not human readable? It's also kind of a security issue for us to host pickle files.

cpuhrsch (Author):

I added https://github.com/pytorch-labs/ao/pull/41/files#diff-4986e1d3257adc0a73b17fd6f21ef9d3b2c0eaec9027381e6f2de89e5be0e6b5 to make it easier to inspect. It stores the Triton `Config` objects, so it's a bit more difficult to make them human readable by default.
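For illustration, loading and pretty-printing such a configs pickle could look like this (hypothetical helper, not the PR's actual inspection script; it relies on the unpickled objects having a usable `repr`):

```python
import pickle
import pprint

def load_and_show_configs(path):
    # Hypothetical inspection helper for a configs pickle like data_a100.pkl.
    # Triton Config objects are not human readable on disk, but their repr
    # is reasonable once unpickled.
    with open(path, "rb") as f:
        data = pickle.load(f)
    pprint.pprint(data)
    return data
```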




def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
Member:

put this in benchmark util instead?

cpuhrsch (Author):

Once I add the next benchmark for weight only

benchmarks/intmm_shapes.csv
@@ -0,0 +1,7 @@
m,k,n
Member:

Presumably you mean the shapes of matmuls in SAM?

cpuhrsch (Author):

Yes, SAM vit_b batch size 16 to be precise
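A sketch of reading such an `m,k,n` shape CSV (hypothetical reader for illustration, not the PR's actual code):

```python
import csv
import io

def read_shapes(csv_text: str):
    # Parse a CSV with an m,k,n header (e.g. sam_shapes.csv) into int tuples
    # describing the matmul problem sizes to benchmark.
    rows = csv.DictReader(io.StringIO(csv_text))
    return [(int(r["m"]), int(r["k"]), int(r["n"])) for r in rows]
```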

[
("cuda", torch.bfloat16),
("cuda", torch.bfloat16),
# ("cpu", torch.bfloat16),
Member:

nit: remove comments

cpuhrsch (Author):

I'll turn those into TODOs. It should also work on CPU.

@@ -0,0 +1,19 @@
## Autotuner and custom Triton kernels

### Use case
Member:

Is the intent to fill this out later? It might be better to open an issue if you'd like to do this later.


:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
Member:

warmup is not a time; it seems to be an int, so not in ms

cpuhrsch (Author):

Why can't milliseconds be given as an int?

cpuhrsch (Author):

I think this is just saying "run warmup until at least 25ms were spent"
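That reading can be sketched with wall-clock timers (hypothetical CPU-only version; a real GPU benchmark would presumably use CUDA events and device synchronization instead):

```python
import time

def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
    # Hypothetical sketch: call f repeatedly until at least `warmup` ms have
    # elapsed, then time `iters` measured calls and return the mean time per
    # call in milliseconds.
    start = time.perf_counter()
    while (time.perf_counter() - start) * 1e3 < warmup:
        f(*args, **kwargs)
    start = time.perf_counter()
    for _ in range(iters):
        f(*args, **kwargs)
    return (time.perf_counter() - start) * 1e3 / iters
```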


BEST_CONFIGS = None

AUTOTUNER_DATA_PATH = os.getenv('TORCHAO_AUTOTUNER_DATA_PATH', None)
Member:

put all global variables at the top of the file so they're easier to find

cpuhrsch (Author):

Good point

int8_powers_of_two,
int8_powers_of_two)], [])

# int8_mm_kernel_configs = [
Member:

delete?

cpuhrsch (Author):

I wanted to leave these as reference from core. I'll add a comment.

import triton.language as tl
import itertools
import os
int8_powers_of_two = [32, 64, 128, 256]
Member:

do you envision people wanting to add more options here and for int8 kernel configs?

cpuhrsch (Author):

Eventually, yes. Follow up here includes making it more extensible for other kernels. Adding support for mixed precision should help that.
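The `itertools.product` pattern in the quoted diff expands those powers of two into a block-size search grid for the autotuner; a hypothetical minimal version:

```python
import itertools

int8_powers_of_two = [32, 64, 128, 256]

# Hypothetical sketch of the (BLOCK_M, BLOCK_N, BLOCK_K) candidate grid:
# every combination of the powers of two above, i.e. 4**3 = 64 candidates
# for the autotuner to time and rank per shape.
block_size_grid = list(itertools.product(int8_powers_of_two, repeat=3))
```

Adding a value to `int8_powers_of_two` grows the grid cubically, which is part of why a fresh search "can take a long time" as the docs note.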

@msaroufim msaroufim self-requested a review March 19, 2024 20:00
@msaroufim (Member) left a comment:

Unblocking for now



@cpuhrsch cpuhrsch merged commit efb6514 into main Mar 21, 2024
2 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024