[not for land yet] hack max and abs out of ops eligible for AC #580
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
For now, this is not for land and just saving work and starting a discussion.
We need to calculate max(abs(tensor)) for each float8 gemm input when using per-tensor scaling. I realized that today this does not work efficiently with AC, because max(abs(tensor)) is usually recomputed. Since the output size is 1, it's more efficient to save it and never recompute.
For now, just hack these ops into the do-not-recompute list to get a perf measurement. Seems to save ~1% on LLaMa 3B on 8 H100 GPUs if using "op" selective AC mode, which is the easiest mode to hack this onto. I'd expect the benefit to be less if float8 all-gather is turned on. I verified in the pre-post traces that the redundant triton kernels to calculate max(abs(activation)) and max(abs(weight)) are gone with this hack.
Heading to PTC but we should get a measurement on a larger model, and figure out a better way to land this.
Test Plan:
https://gist.github.com/vkuzo/375230e30e1cb599ad31a87e0be25d75
Reviewers:
Subscribers:
Tasks:
Tags: