Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 20x slowdown of FP6 kernel due to device properties query #1092

Merged
merged 1 commit into from
Oct 16, 2024

Conversation

tobiasvanderwerff
Copy link
Contributor

While working on the FP6 kernel, I noticed a large slowdown of FP6 compared to FP16 (e.g. 1125 ms FP6 vs. 90 ms FP16 for m=1, k=8192, n=8192, as shown in the table below). This is not supposed to happen, since FP6 is expected to be faster than FP16 for at least most of the smaller matrix multiply shapes.

The slowdown occurs ever since the FP6 for SM75 PR (#942). This change was tested for SM75 (T4 GPU), but not for an A100 (SM80). Now that I've tried it on an A100, I noticed the slowdown. After digging for a bit, the problem seems to be the code that checks for the GPU compute capability to determine the launch parameters:

inline bool isSM75GPU() {
int device;
cudaError_t err = cudaGetDevice(&device);
if (err != cudaSuccess) {
return false;
}
cudaDeviceProp props;
err = cudaGetDeviceProperties(&props, device);
if (err != cudaSuccess) {
return false;
}
return (props.major == 7) && (props.minor == 5);
}

When I initially wrote this, I did a quick web search to check whether this code would incur any significant overhead from querying the device properties. The answer came back negative, but it turns out this was wrong: calling cudaGetDeviceProperties() can incur significant runtime overhead, as discussed in this NVIDIA forum post. The problem is that this function call returns all device properties, even those that require expensive PCIe reads. Instead, it is better to directly query a specific device property, using cudaDeviceGetAttribute(), which is what the current PR does.

Benchmark results

Results of benchmarks/benchmark_fp6.py

Tested on:

  • A100 80GB
  • Torch version: 2.4.1

Before:

m k n fp6_latency (ms) fp16_latency (ms) speedup (d/s) correct
1 8192 8192 1125.46 90.4927 0.0804052 1
1 8192 10240 1000.41 107.788 0.107743 1
1 8192 57344 1068.93 555.302 0.519496 1
1 28672 8192 1025.68 281.582 0.274532 1
2 8192 8192 1046.39 90.9669 0.0869344 1
2 8192 10240 1144.23 110.408 0.0964911 1
2 8192 57344 1069.79 566.677 0.529709 1
2 28672 8192 1007.03 277.57 0.275633 1
4 8192 8192 1000.59 90.9182 0.0908643 1
4 8192 10240 1009.25 111.002 0.109985 1
4 8192 57344 1073.72 570.032 0.530892 1
4 28672 8192 1039.15 278.076 0.267599 1
8 8192 8192 1032.79 91.5243 0.0886185 1
8 8192 10240 1000.24 112.728 0.1127 1
8 8192 57344 1077.23 576.736 0.53539 1
8 28672 8192 1012.57 278.99 0.275527 1
16 8192 8192 1001.91 92.3847 0.0922085 1
16 8192 10240 1000.84 115.095 0.114998 1
16 8192 57344 1083.48 587.268 0.542023 1
16 28672 8192 1012.82 281.066 0.277508 1
32 8192 8192 1092.25 92.1624 0.0843783 1
32 8192 10240 1022.55 111.275 0.108821 1
32 8192 57344 1166.29 590.005 0.505883 1
32 28672 8192 1023.64 289.989 0.283292 1
64 8192 8192 1020.78 93.5926 0.0916872 1
64 8192 10240 1021.62 110.602 0.108261 1
64 8192 57344 1145.79 584.46 0.510093 1
64 28672 8192 1042.67 296.394 0.284264 1
128 8192 8192 1036.82 114.424 0.11036 1
128 8192 10240 1047.11 130.439 0.12457 1
128 8192 57344 1271.59 690.167 0.542758 1
128 28672 8192 1073.64 316.786 0.295059 1
256 8192 8192 1078.36 170.566 0.158172 1
256 8192 10240 1067.59 215.302 0.20167 1
256 8192 57344 1492.74 1075.21 0.720291 1
256 28672 8192 1173.1 523.931 0.446622 1
512 8192 8192 1145.68 331.623 0.289456 1
512 8192 10240 1129.63 362.692 0.321071 1
512 8192 57344 2548.14 2005.06 0.786871 1
512 28672 8192 1435.92 1009.66 0.70315 1

After (this PR):

m k n fp6_latency (ms) fp16_latency (ms) speedup (d/s) correct
1 8192 8192 43.7752 89.2689 2.03926 1
1 8192 10240 48.1101 107.697 2.23855 1
1 8192 57344 217.711 553.976 2.54455 1
1 28672 8192 118.12 281.185 2.3805 1
2 8192 8192 45.3615 90.577 1.99678 1
2 8192 10240 48.4927 112.666 2.32337 1
2 8192 57344 220.861 567.734 2.57055 1
2 28672 8192 119.004 277.219 2.32949 1
4 8192 8192 44.1525 90.395 2.04734 1
4 8192 10240 50.0794 113.646 2.26932 1
4 8192 57344 225.299 570.694 2.53306 1
4 28672 8192 119.585 277.762 2.32272 1
8 8192 8192 44.8787 91.2936 2.03423 1
8 8192 10240 50.5082 114.593 2.26879 1
8 8192 57344 234.011 576.682 2.46434 1
8 28672 8192 121.05 279.179 2.30631 1
16 8192 8192 47.889 92.6836 1.93538 1
16 8192 10240 54.9626 117.292 2.13403 1
16 8192 57344 261.265 590.183 2.25895 1
16 28672 8192 131.584 280.924 2.13494 1
32 8192 8192 51.9708 92.1966 1.77401 1
32 8192 10240 59.8898 111.895 1.86836 1
32 8192 57344 307.97 585.129 1.89995 1
32 28672 8192 139.128 289.797 2.08294 1
64 8192 8192 69.7171 93.2411 1.33742 1
64 8192 10240 81.1042 110.693 1.36482 1
64 8192 57344 447.616 585.569 1.30819 1
64 28672 8192 195.54 296.719 1.51743 1
128 8192 8192 117.214 112.366 0.958636 1
128 8192 10240 150.636 136.439 0.905753 1
128 8192 57344 800.991 687.209 0.857948 1
128 28672 8192 364.892 315.869 0.865652 1
256 8192 8192 227.643 170.793 0.750265 1
256 8192 10240 272.679 217.37 0.797165 1
256 8192 57344 1485.88 1083.08 0.728916 1
256 28672 8192 673.345 519.406 0.771381 1
512 8192 8192 445.791 332.247 0.745299 1
512 8192 10240 496.402 368.659 0.742663 1
512 8192 57344 2564.5 2027.05 0.790427 1
512 28672 8192 1327.19 1013.2 0.763421 1

Copy link

pytorch-bot bot commented Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1092

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c1732e6 with merge base 0b71b8d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 16, 2024
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

@supriyar
Copy link
Contributor

awesome! possible to cherry-pick this into the next release? cc @drisspg

@msaroufim msaroufim merged commit c87cc9b into pytorch:main Oct 16, 2024
17 checks passed
@gau-nernst
Copy link
Collaborator

The numbers look good already. But I wonder if we can cache this result? (e.g. use static variable or global variable - I know, frowned upon by C++ devs 🤣). Will using PyTorch API be better? (i.e. torch.cuda.get_device_capability() - and does this function cache the result?)

(Again, the current state looks good already. Just something I wonder)

@msaroufim
Copy link
Member

@HDCharles do you have a good sense of the state of our benchmarking story? Right now perf regressions are only ever caught manually for llama in your spot checks. Are we in a place we can start failing if kernels like the fp6 one become 20x slower?

@tobiasvanderwerff
Copy link
Contributor Author

@gau-nernst just FYI, FlashAttention uses at::cuda::getCurrentDeviceProperties(); (source), which I also considered. However, I couldn't find documentation on whether it has the same problem (querying all device properties at once, which is slow), or that it handles this in some more optimized way.

@tobiasvanderwerff tobiasvanderwerff deleted the fp6-slowdown-fix branch October 17, 2024 06:12
drisspg pushed a commit that referenced this pull request Oct 17, 2024
Replace `cudaGetDeviceProperties` with `cudaDeviceGetAttribute`
atalman pushed a commit to atalman/ao that referenced this pull request Oct 21, 2024
…h#1092)

Replace `cudaGetDeviceProperties` with `cudaDeviceGetAttribute`
atalman added a commit that referenced this pull request Oct 21, 2024
Fix 20x slowdown of FP6 kernel due to device properties query (#1092)

Replace `cudaGetDeviceProperties` with `cudaDeviceGetAttribute`

Co-authored-by: Tobias van der Werff <[email protected]>
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
…at/ folder (pytorch#1076)

* [Hackability Refactor] Move known_model_params under torchchat (pytorch#1073)

* [Hackability Refactor] Migrate CLI call sites to explicitly go through torchchat.py (pytorch#1075)

* [Hackability Refactor] Move model.py underneath torchchat/ (pytorch#1077)

* Move model.py

* Clear out init to avoid package circular import

* [Hackability Refactor] Move select top level docs into folders within torchchat (pytorch#1080)

* [Hackability Refactor] Move the top level util folder into torchchat/utils (pytorch#1079)

* [Hackability Refactor] Move the top level util file into torchchat/utils/

* Cleared out init to avoid packing

* [Hackability Refactor] Collapse gguf_util into gguf_loader (pytorch#1078)

* [Hackability Refactor] Collapse gguf_util into gguf_loader

* Update bad import

* [Hackability Refactor] Move model_config into torchchat/model_config (pytorch#1082)

* [Hackability Refactor] Move cli related files under torchchat/cli (pytorch#1083)

* [Hackability Refactor] Move build/util into torchchat/utils (pytorch#1084)

* [Hackability Refactor] Easy Moves: eval, gguf_loader, quantize, model_dist (pytorch#1085)

* [Hackability Refactor] Easy Cheap Moves: eval, gguf_loader, quantize, model_dist

* Update eval.py call sites that slipped through the initial pass

* [Hackability Refactor] Update missed direct file calls to use torchchat.py (pytorch#1088)

* [Hackability Refactor] Move export and generate under torchchat/ (pytorch#1089)

* [Hackability Refactor] Move scripts under torchchat/utils (pytorch#1090)

* [Hackability Refactor] Move scripts under torchchat/utils

* Fix install script for AOTI

* Update referenced path in build_android

* Adding missing utils path

* Add another layer for torchchat

* Move the source command depending on if TC root is defined

* [Hackability Refactor] Move installation related files into install/ (pytorch#1081)

* [Hackability Refactor] Move installation related files into install/

* Fix install req path

* Test fix with install path for bash

* Debug messages

* Remove changes to install in et_python_libs

* Remove debug echo

* Fix pin path for et

* [Hackability Refactor] Restricted Lint (pytorch#1091)

* [Hackability Refactor] Removing __main__ from export/generate/eval (pytorch#1092)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants