[NVPTX] performance regression caused by register splitting/merging combined with wgmma instructions. #151580

@Artem-B

Some users reported a substantial performance regression (apparently in NVPTX), attributed to register splitting/merging after recent LLVM changes.

[kernel] keeps the WGMMA accumulator registers as <2 x float> and unpacks them right before passing them into WGMMA:

...
  %571 = extractelement <2 x float> %458, i64 0, !dbg !10
  %572 = extractelement <2 x float> %458, i64 1, !dbg !10
  %573 = extractelement <2 x float> %457, i64 0, !dbg !10
  %574 = extractelement <2 x float> %457, i64 1, !dbg !10
  %575 = call { float, float, ... } asm sideeffect "{ .reg .pred p; setp.ne.b32 p, $130, 0; wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {$0,$1,$2,$3,$4,...}, $128, $129, p, $131, $132, $133, $134; }\0A", "=f,=f,=f,=f,=f,...,0,1,2,3,4,...,l,l,n,n,n,n,n"(float %511, float %512, float %513, float %514, ..., float %568, float %569, float %570, float %571, float %572, float %573, float %574, i64 %506, i64 %510, i32 1, i32 1, i32 1, i32 0, i32 1) #5, !dbg !10
  ...

Before those changes, NVPTX would always represent the <2 x float> as two 32-bit registers, and the extractelement ops would not be visible in the generated PTX. However, if I look at what happens now, I see this:

        ...
	mov.b64 	{%r1533, %r1534}, %rd281;
	mov.b64 	{%r1535, %r1536}, %rd280;
	mov.b64 	{%r1537, %r1538}, %rd279;
	mov.b64 	{%r1539, %r1540}, %rd278;
	mov.b64 	{%r1541, %r1542}, %rd277;
	// begin inline asm
	{ .reg .pred p; setp.ne.b32 p, 1, 0; wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%r1479,%r1480,%r1481,%r1482,%r1483,%r1484,%r1485,%r1486,%r1487,%r1488,%r1489,%r1490,%r1491,%r1492,%r1493,%r1494,%r1495,%r1496,%r1497,%r1498,%r1499,%r1500,%r1501,%r1502,%r1503,%r1504,%r1505,%r1506,%r1507,%r1508,%r1509,%r1510,%r1511,%r1512,%r1513,%r1514,%r1515,%r1516,%r1517,%r1518,%r1519,%r1520,%r1521,%r1522,%r1523,%r1524,%r1525,%r1526,%r1527,%r1528,%r1529,%r1530,%r1531,%r1532,%r1533,%r1534,%r1535,%r1536,%r1537,%r1538,%r1539,%r1540,%r1541,%r1542}, %rd235, %rd220, p, 1, 1, 0, 1; }
	...

This is a problem because ptxas really doesn't like it when registers that end up as WGMMA accumulators are touched between a wgmma.fence and the wgmma instruction, even by something as trivial as this. It emits a warning saying that it will insert additional waits to make sure this is safe, and those waits slow down the whole kernel.
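To check whether a given PTX dump shows this pattern, a rough text scan may be enough. The sketch below is a hypothetical helper, not part of LLVM or ptxas. It assumes the unpacks appear as `mov.b64 {%rA, %rB}, %rdN;` and that the first brace-enclosed list after the `wgmma.mma_async` mnemonic is the accumulator list, as in the output above. It does not track `wgmma.fence` placement, so treat any hit as a hint to investigate rather than proof of the slowdown.

```python
import re

# Matches an NVPTX 64-bit unpack such as: mov.b64 {%r1533, %r1534}, %rd281;
UNPACK_RE = re.compile(
    r"mov\.b64\s+\{\s*(%\w+)\s*,\s*(%\w+)\s*\}\s*,\s*(%\w+)\s*;"
)
# Matches the first register list following a wgmma mnemonic (assumed to be
# the accumulator list, as in the PTX quoted in this issue).
WGMMA_RE = re.compile(r"wgmma\.mma_async[\w.]*\s*\{([^}]*)\}")


def find_unpacked_accumulators(ptx):
    """Return (acc_reg, src_reg, mov_line, wgmma_line) tuples for every wgmma
    accumulator register that was produced by an earlier mov.b64 unpack.

    Hypothetical heuristic: purely textual, with no control-flow or
    wgmma.fence analysis.
    """
    unpacked = {}  # destination register -> (line number, 64-bit source)
    hits = []
    for line_no, line in enumerate(ptx.splitlines(), start=1):
        for m in UNPACK_RE.finditer(line):
            for dst in (m.group(1), m.group(2)):
                unpacked[dst] = (line_no, m.group(3))
        w = WGMMA_RE.search(line)
        if w:
            for acc in (r.strip() for r in w.group(1).split(",")):
                if acc in unpacked:
                    mov_line, src = unpacked[acc]
                    hits.append((acc, src, mov_line, line_no))
    return hits
```

Passing the text of a PTX file to `find_unpacked_accumulators` returns one entry per affected accumulator register. An empty list means no accumulator was fed directly by a `mov.b64` unpack in the scanned text.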
