[not for land yet] hack max and abs out of ops eligible for AC #580

Open
wants to merge 1 commit into
base: main

vkuzo
Contributor

@vkuzo vkuzo commented Sep 17, 2024

Summary:

This is not meant to land yet; it saves work in progress and starts a discussion.

With per-tensor scaling, we need to calculate max(abs(tensor)) for each float8 gemm input. I realized that today this does not work efficiently with activation checkpointing (AC), because max(abs(tensor)) is usually recomputed in the backward pass. Since the output is a single element, it's cheaper to save it and never recompute it.
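
For readers new to per-tensor scaling, here is a minimal, framework-free sketch of why max(abs(tensor)) is needed. It is a toy on plain Python lists, not the float8 code this PR touches: the helper names (`amax`, `per_tensor_scale`, `scale_and_clamp`) and the `eps` guard are hypothetical, and rounding to real float8 values is left out. The only hard fact it relies on is that the largest finite float8 e4m3fn value is 448.

```python
# Toy, framework-free sketch of per-tensor float8 scaling (hypothetical helpers).
E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


def amax(values):
    """Return max(abs(v)) over a flat list of floats: one scalar per tensor."""
    return max(abs(v) for v in values)


def per_tensor_scale(values, eps=1e-12):
    """Scale that maps the largest magnitude in `values` onto E4M3_MAX.

    `eps` is an assumed guard against dividing by zero for all-zero inputs.
    """
    return E4M3_MAX / max(amax(values), eps)


def scale_and_clamp(values, scale):
    """Multiply by the scale and clamp to the representable range.

    Rounding to actual float8 values is omitted in this sketch.
    """
    return [min(max(v * scale, -E4M3_MAX), E4M3_MAX) for v in values]


activation = [0.5, -2.0, 1.25, 0.0]
scale = per_tensor_scale(activation)         # a single scalar for the whole tensor
scaled = scale_and_clamp(activation, scale)  # values now span the e4m3 range
```

The reduction reads the whole tensor but produces one number, which is why keeping that number around costs almost no memory compared with recomputing it.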

For now, just hack these ops into the do-not-recompute list to get a perf measurement. This seems to save ~1% on LLaMa 3B on 8 H100 GPUs when using the "op" selective AC mode, which is the easiest mode to hack this onto. I'd expect the benefit to be smaller if float8 all-gather is turned on. I verified in the before/after traces that the redundant Triton kernels that calculate max(abs(activation)) and max(abs(weight)) are gone with this hack.
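
As a rough illustration of what a do-not-recompute list changes, the toy model below counts op executions for one forward pass plus the recomputation that checkpointing triggers. It does not use PyTorch and is not the diff in this PR: the op names are plain illustrative strings, the op sequence is made up, and it assumes matmuls are already on the save list. It also counts ops rather than fused kernels, so the numbers say nothing about the measured ~1%.

```python
# Toy model of op-level selective activation checkpointing (no PyTorch needed).
# Op names are illustrative strings, not real operator handles.
FORWARD_OPS = ["abs", "max", "mm", "relu", "abs", "max", "mm"]


def op_executions(forward_ops, save_list):
    """Count op executions for one forward pass plus recomputation.

    Ops in `save_list` keep their outputs, so they run only once.
    Every other op runs again when activations are recomputed.
    """
    forward = len(forward_ops)
    recompute = sum(1 for op in forward_ops if op not in save_list)
    return forward + recompute


# Assumed baseline: only matmul outputs are saved.
baseline = op_executions(FORWARD_OPS, save_list={"mm"})
# With the hack: abs and max are also saved, so they are never recomputed.
hacked = op_executions(FORWARD_OPS, save_list={"mm", "abs", "max"})
```

In this made-up sequence the count drops from 12 to 8 executions, and the extra memory for the saved max outputs is one element each. Note that in a real run abs is elementwise, so saving its output is only cheap if it is fused into the reduction.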

Heading to PTC, but we should get a measurement on a larger model and figure out a better way to land this.

Test Plan:

https://gist.github.com/vkuzo/375230e30e1cb599ad31a87e0be25d75

