Skip to content

[AMD] Move MFMA shortcut check to not compute scratch buffer shape if it is not needed#4161

Merged
ptillet merged 2 commits intotriton-lang:release/3.0.xfrom
amjames:cherry-pick-3816-3790
Jun 18, 2024
Merged

[AMD] Move MFMA shortcut check to not compute scratch buffer shape if it is not needed#4161
ptillet merged 2 commits intotriton-lang:release/3.0.xfrom
amjames:cherry-pick-3816-3790

Conversation

@amjames
Copy link
Copy Markdown
Contributor

@amjames amjames commented Jun 18, 2024

These need to be in this order to apply cleanly.

Cherry pick #3816 and #3790 for release/3.0.x.

binarman and others added 2 commits June 18, 2024 21:31
… it is not needed (triton-lang#3790)

This PR:
- moves shortcut check earlier, to not compute scratch buffer shape if
it is not needed
- raise priority of AMD specific over common conversions to eliminate
uncertainty which pattern to apply.
- add regression test for MFMA to Dot Op shortcut
This PR enables denorm flushing for `tl.math.exp2` and preserves denorms
for `tl.math.exp`, which match their behaviors on Nvidia backend.

More specifically, 
- denorm flushing for tl.math.exp2 with f32 inputs is controlled by
`__CUDA_FTZ` or `__HIP_FTZ` and the default is set to flushing denorm.
These flags can be set by developers, but are not exposed as kernel
argument.

tl.math.exp2(f32) | NV | NV | AMD | AMD
-- | -- | -- | -- | --
control flag | __CUDA_FTZ=1 (default) | __CUDA_FTZ=0 | __HIP_FTZ=1
(default) | __HIP_FTZ=0
device lib | __nv_exp2f | __nv_exp2f |  | 
llvm intrinsics | llvm.nvvm.ex2.approx.ftz.f | llvm.nvvm.ex2.approx.f |
llvm.amdgcn.exp2.f32 | llvm.exp2.f32
ptx | ex2.approx.ftz.f32 | ex2.approx.f32 |   |  
sass/amdgcn | MUFU.EX2 | MUFU.EX2<br>and instructions to<br>check and
adjust for<br>denorms | v_exp_f32 | v_exp_f32<br>and instructions<br>to
check and<br>adjust for<br>denorms
- denorms are preserved for tl.math.exp2 with f64 inputs

tl.math.exp2(f64) | NV | AMD
-- | -- | --
device lib | __nv_exp2 | __ocml_exp2_f64
- denorms are preserved for tl.math.exp with both f32 and f64 inputs.
Note that tl.math.exp(f32) on nv path is lowered with inline ptx
directly without the `.ftz` flag.

tl.math.exp(f32) | NV | AMD
-- | -- | --
llvm intrinsics |   | llvm.exp2.f32
ptx | ex2.approx.f32 |  


tl.math.exp(f64) | NV | AMD
-- | -- | --
device lib | __nv_exp | __ocml_exp_f64
@amjames amjames changed the base branch from main to release/3.0.x June 18, 2024 21:39
@ptillet ptillet merged commit 3d1da66 into triton-lang:release/3.0.x Jun 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants