Fix for OOB smem access for the dot operand B #2198
alexander-zinoviev wants to merge 2 commits into triton-lang:main from
Conversation
Wondering why
The warpsPerCTA is [2, 2] for the tensor. Is that what you're looking for?
The simplest reproducer is just a block inside my kernel (a_ptr and b_ptr are pointers to fp16 tensors):
    a_offsets = tl.arange(0, 32)[:, None] * 16 + tl.arange(0, 16)[None, :]
and ldmatrix for b goes OOB when fetching the data from shmem.
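For readers without Triton at hand, the offset expression in the reproducer can be emulated in plain Python. This is only a sketch of the broadcasting arithmetic; `tile_offsets` is an illustrative helper name, not part of the Triton API.

```python
def tile_offsets(rows, cols, row_stride):
    """Emulate tl.arange(0, rows)[:, None] * row_stride + tl.arange(0, cols)[None, :]."""
    return [[r * row_stride + c for c in range(cols)] for r in range(rows)]


a_offsets = tile_offsets(32, 16, 16)
flat = [off for row in a_offsets for off in row]

print(len(a_offsets), len(a_offsets[0]))  # 32 16
print(min(flat), max(flat))               # 0 511
```

The tile is 32x16 with a row stride of 16, so the offsets cover a dense row-major range with no gaps.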
@alexander-zinoviev thanks. I will take a look after I finish the current task.
@alexander-zinoviev can you paste the full TTGIR?
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
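As a quick sanity check on the layout above, the sketch below multiplies the three per-dimension factors. It assumes the usual reading of a blocked layout, where the elements covered along each dimension are sizePerThread * threadsPerWarp * warpsPerCTA; `blocked_tile_shape` is a hypothetical helper, not a Triton function.

```python
def blocked_tile_shape(size_per_thread, threads_per_warp, warps_per_cta):
    """Elements covered per dimension by one CTA under a blocked layout (assumed model)."""
    return [s * t * w
            for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]


shape = blocked_tile_shape([1, 4], [8, 4], [4, 1])
threads_per_warp = 8 * 4
num_warps = 4 * 1

print(shape)             # [32, 16]
print(threads_per_warp)  # 32
print(num_warps)         # 4
```

Under that assumption the layout covers a 32x16 tile with 4 warps, which matches the 32x16 offsets in the reproducer.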
Is this generated from top of
I think we are currently at fb265d3
The bug does not reproduce at the OSS head. I am trying to bisect to find when it was fixed. |
This fixes an OOB access when ldmatrix loads operand B from shmem.
A scenario where I faced the problem:
Debugging suggests that the offsets computed by computeLdmatrixMatOffs() are incorrect for warps #1 and #3. Both warps are treated as logically equivalent, so the warpId value passed to the function is 1 for each of them. The problem affects the last 8 lanes of the warp. For example, for the very last lane (31) the math looks like:
rowInMat = 7
matIndex = 3 (bottom right 8x8 matrix of the whole 16x16)
s0 = s1 = 1
kOrder = 0
kMatArr = 1
nkMatArr = 1
warpId (logical) = 1
warpMatOffset = 1
inWarpMatOffset = 1
contiguousMatIndex = 1 * 1 + 1 * 1 = 2 (points to a non-existent column of matrix b)
contiguousTileNumMats = 16/8 = 2
then in the loop
contiguousIndex = 2 (for i = 0) and 3 (for i = 1)
then the guard against OOB does not help, because both of its conditions evaluate to false:
    if (warpsPerCTA[order[0]] > contiguousTileNumMats ||      // 2 > 2 is false
        contiguousTileNumMats % warpsPerCTA[order[0]] != 0)   // 2 % 2 != 0 is false
      // wrap around
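The lane-31 arithmetic above can be replayed in a few lines of Python. This is a sketch of the index math only, not the actual C++ lowering: the function name is hypothetical, the choice of s1 as the non-k index for kOrder = 0 is an assumption (both s0 and s1 are 1 here, so it does not affect the result), and a per-pointer step of 1 is inferred from the listed values 2 and 3.

```python
def lane_contiguous_indices(lane, warp_id, warp_mat_offset=1,
                            in_warp_mat_offset=1, num_ptrs=2,
                            contiguous_load_mat_offset=1):
    """Replay the index arithmetic for one lane (kOrder = 0 case only)."""
    row_in_mat = lane % 8          # row inside the 8x8 matrix
    mat_index = lane // 8          # which of the four 8x8 matrices
    s0, s1 = mat_index % 2, mat_index // 2
    nk_mat_arr = s1                # assumption: non-k index is s1 when kOrder = 0
    contiguous_mat_index = (warp_id * warp_mat_offset
                            + nk_mat_arr * in_warp_mat_offset)
    indices = [contiguous_mat_index + i * contiguous_load_mat_offset
               for i in range(num_ptrs)]
    return row_in_mat, mat_index, indices


contiguous_tile_num_mats = 16 // 8   # 2 column matrices in the tile
warps_per_cta_contig = 2             # warpsPerCTA[order[0]]

row_in_mat, mat_index, indices = lane_contiguous_indices(lane=31, warp_id=1)
guard = (warps_per_cta_contig > contiguous_tile_num_mats
         or contiguous_tile_num_mats % warps_per_cta_contig != 0)

print(row_in_mat, mat_index)   # 7 3
print(indices)                 # [2, 3]
print(guard)                   # False, so no wrap-around is applied
print([i >= contiguous_tile_num_mats for i in indices])  # [True, True]
```

Both indices fall outside the valid range [0, 2) while the guard stays false, which is the situation described above.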
I propose making the wrap-around unconditional. With that change, all pointers used for ldmatrix stay within range and sm75 no longer crashes.
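A minimal Python sketch of the difference between the guarded and the unconditional wrap-around, using the values from the lane-31 case. The function name and its `unconditional` flag are illustrative only; the real change lives in the C++ lowering.

```python
def wrap_contiguous_index(index, warps_per_cta_contig,
                          contiguous_tile_num_mats, unconditional):
    """Apply the wrap-around either behind the original guard or always."""
    guard = (warps_per_cta_contig > contiguous_tile_num_mats
             or contiguous_tile_num_mats % warps_per_cta_contig != 0)
    if unconditional or guard:
        return index % contiguous_tile_num_mats
    return index


# Values from the description: indices 2 and 3, two warps, two column matrices.
before = [wrap_contiguous_index(i, 2, 2, unconditional=False) for i in (2, 3)]
after = [wrap_contiguous_index(i, 2, 2, unconditional=True) for i in (2, 3)]

print(before)  # [2, 3]  guard is false, indices stay out of range
print(after)   # [0, 1]  always wrapped into [0, 2)
```

Since a remainder modulo contiguousTileNumMats is always below contiguousTileNumMats, the unconditional form keeps every index in range regardless of the warp configuration.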