Description
Some users reported a substantial performance regression (apparently in the NVPTX backend), attributed to register splitting/merging after recent LLVM changes.
[kernel] keeps the WGMMA accumulator registers as <2 x float> and unpacks them right before passing them into WGMMA:
```llvm
...
%571 = extractelement <2 x float> %458, i64 0, !dbg !10
%572 = extractelement <2 x float> %458, i64 1, !dbg !10
%573 = extractelement <2 x float> %457, i64 0, !dbg !10
%574 = extractelement <2 x float> %457, i64 1, !dbg !10
%575 = call { float, float, ... } asm sideeffect "{ .reg .pred p; setp.ne.b32 p, $130, 0; wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {$0,$1,$2,$3,$4,...}, $128, $129, p, $131, $132, $133, $134; }\0A", "=f,=f,=f,=f,=f,...,0,1,2,3,4,...,l,l,n,n,n,n,n"(float %511, float %512, float %513, float %514, ..., float %568, float %569, float %570, float %571, float %572, float %573, float %574, i64 %506, i64 %510, i32 1, i32 1, i32 1, i32 0, i32 1) #5, !dbg !10
...
```
Before those changes, NVPTX would always represent the <2 x float> as two 32-bit registers, and the extractelement ops would not be visible in the generated PTX. However, if I look at what happens now, I see this:
```
...
mov.b64 {%r1533, %r1534}, %rd281;
mov.b64 {%r1535, %r1536}, %rd280;
mov.b64 {%r1537, %r1538}, %rd279;
mov.b64 {%r1539, %r1540}, %rd278;
mov.b64 {%r1541, %r1542}, %rd277;
// begin inline asm
{
.reg .pred p;
setp.ne.b32 p, 1, 0;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%r1479,%r1480,%r1481,%r1482,%r1483,%r1484,%r1485,%r1486,%r1487,%r1488,%r1489,%r1490,%r1491,%r1492,%r1493,%r1494,%r1495,%r1496,%r1497,%r1498,%r1499,%r1500,%r1501,%r1502,%r1503,%r1504,%r1505,%r1506,%r1507,%r1508,%r1509,%r1510,%r1511,%r1512,%r1513,%r1514,%r1515,%r1516,%r1517,%r1518,%r1519,%r1520,%r1521,%r1522,%r1523,%r1524,%r1525,%r1526,%r1527,%r1528,%r1529,%r1530,%r1531,%r1532,%r1533,%r1534,%r1535,%r1536,%r1537,%r1538,%r1539,%r1540,%r1541,%r1542}, %rd235, %rd220, p, 1, 1, 0, 1;
}
...
```

and this is a problem, because ptxas really doesn't like it when you touch registers that end up as WGMMA accumulators between a wgmma.fence and the wgmma instruction, even if it's for something as trivial as this. It emits a warning saying that it will insert additional waits to make sure this is safe, and this slows down the whole kernel.
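For readers less familiar with the PTX involved: each `mov.b64 {%rA, %rB}, %rdN` above splits one 64-bit register into its low and high 32-bit halves, which is how the two lanes of a packed `<2 x float>` get back into separate 32-bit registers. The following is only a bit-level model of that unpacking in Python, not anything from the NVPTX backend; the helper names (`pack_v2f32`, `mov_b64_unpack`, `bits_to_f32`) are hypothetical, and it assumes element 0 of the vector sits in the low 32 bits.

```python
import struct
from typing import Tuple


def pack_v2f32(x: float, y: float) -> int:
    """Model a <2 x float> held in one 64-bit register.

    Assumption: element 0 occupies the low 32 bits, element 1 the high 32 bits.
    """
    lo, hi = struct.unpack("<II", struct.pack("<ff", x, y))
    return lo | (hi << 32)


def mov_b64_unpack(rd: int) -> Tuple[int, int]:
    """Model `mov.b64 {lo, hi}, rd`: split a 64-bit value into two 32-bit halves."""
    return rd & 0xFFFFFFFF, (rd >> 32) & 0xFFFFFFFF


def bits_to_f32(bits: int) -> float:
    """Reinterpret a 32-bit pattern as an IEEE-754 single-precision float."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


rd = pack_v2f32(1.5, -2.0)
lo, hi = mov_b64_unpack(rd)
print(hex(rd), bits_to_f32(lo), bits_to_f32(hi))
# prints: 0xc00000003fc00000 1.5 -2.0
```

The values themselves are unchanged by the split; it is purely a reinterpretation of the same bits. The cost described above comes from where the instruction is placed, since it writes registers that are then used as WGMMA accumulators.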