
Fix for OOB smem access for the dot operand B #2198

Closed
alexander-zinoviev wants to merge 2 commits into triton-lang:main from alexander-zinoviev:shmem-oob-access-fix

Conversation

@alexander-zinoviev
Contributor

This fixes an out-of-bounds (OOB) access when ldmatrix loads operand B from shared memory.

A scenario where I faced the problem:

  1. There is a tl.dot(a, b) operation where the b tensor is fp16 and its shape is 16x16
  2. The thread block has 4 warps
  3. The tensor b is placed at the very end of shared memory, so the end of the tensor touches the boundary of the required shared-memory amount that Triton computes and reports in the shared-memory property
  4. We instruct CUDA to use the reported amount of shared memory and no more
  5. On sm75 the program then crashes with an illegal memory access. On sm80 it does not crash, for unknown reasons, but pointer analysis with prints confirms that ldmatrix reads shared memory beyond the tensor b region
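For context on step 3, the footprint of the b tile is easy to compute by hand (a sketch only; the actual amount Triton reports may include padding or swizzling, which I have not verified here):

```python
# Shared-memory footprint of the 16x16 fp16 operand-B tile.
ROWS, COLS = 16, 16
BYTES_PER_FP16 = 2

b_tile_bytes = ROWS * COLS * BYTES_PER_FP16
print(b_tile_bytes)  # 512

# If this tile ends exactly at the reported shared-memory size, any ldmatrix
# address past the end of the tile reads beyond the dynamic allocation.
```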

Debugging suggests that the offsets computed by the computeLdmatrixMatOffs() function are incorrect for warps #1 and #3. Both warps #1 and #3 are considered logically equivalent, and the warpId value passed to the function is 1. The problem occurs for the last 8 lanes of the warp. For example, for the very last lane (31) the math looks like:

rowInMat = 7
matIndex = 3 (bottom right 8x8 matrix of the whole 16x16)
s0 = s1 = 1
kOrder = 0
kMatArr = 1
nkMatArr = 1

warpId (logical) = 1
warpMatOffset = 1
inWarpMatOffset = 1

continuousMatIndex = 1 * 1 + 1 * 1 = 2 (points to a non-existent column of the matrix b)
contiguousTileNumMats = 16/8 = 2

then in the loop
contiguousIndex = 2 (for i = 0) and 3 (for i = 1)

Then the guard against OOB does not help, because both of its conditions are false:

if (warpsPerCTA[order[0]] > contiguousTileNumMats ||    // 2 > 2: false
    contiguousTileNumMats % warpsPerCTA[order[0]] != 0) // 2 % 2 != 0: false
  // wrap around

I propose making the wrap-around unconditional. With that change, all pointers used for ldmatrix stay within range, and sm75 no longer crashes.
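The walkthrough above can be condensed into a small sketch (the variable names mirror computeLdmatrixMatOffs(), but the formula is reassembled from the numbers above, so treat it as illustrative rather than the exact implementation):

```python
# Index math for the last lane (31) of logical warp 1, from the walkthrough.
warpId = 1            # warps #1 and #3 both map to logical warpId 1
warpMatOffset = 1
inWarpMatOffset = 1
s0 = s1 = 1

continuousMatIndex = warpId * warpMatOffset * s0 + inWarpMatOffset * s1  # = 2
contiguousTileNumMats = 16 // 8  # two 8x8 matrices along the contiguous dim

# Valid 8x8-matrix indices are 0 and 1, so index 2 points past the b tile.
assert continuousMatIndex >= contiguousTileNumMats

# Proposed fix: wrap around unconditionally instead of only behind the guard.
wrapped = continuousMatIndex % contiguousTileNumMats
print(wrapped)  # 0, back inside the valid range
```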

@Jokeren
Contributor

Jokeren commented Aug 29, 2023

@zahimoud

@zahimoud
Contributor


Wondering why inWarpMatOffset is 1, and whether the issue is somewhere else. Can you share the CTALayout of the shared encoding of the tensor?

@alexander-zinoviev
Contributor Author

The warpsPerCTA is [2,2] for the tensor. Is that what you're looking for?
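To illustrate why warps #1 and #3 end up with the same logical warpId for operand B with warpsPerCTA = [2, 2]: operand B is replicated across the M dimension, so only the N coordinate of a warp matters. The decomposition below assumes a row-major linearization of warp ids (an assumption on my side, not copied from the Triton code):

```python
# Decompose a linear warp id into (m, n) for warpsPerCTA = [2, 2].
WARPS_PER_CTA = [2, 2]  # [M, N]

def warp_coords(linear_warp_id):
    # Assumed linearization: id = m * WARPS_PER_CTA[1] + n
    m = linear_warp_id // WARPS_PER_CTA[1]
    n = linear_warp_id % WARPS_PER_CTA[1]
    return m, n

# Operand B only sees the N coordinate, so warps 1 and 3 collapse to id 1.
logical_b_ids = [warp_coords(w)[1] for w in range(4)]
print(logical_b_ids)  # [0, 1, 0, 1]
```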

@alexander-zinoviev
Contributor Author

alexander-zinoviev commented Aug 29, 2023

The simplest reproducer is (just a block inside my kernel)

(a_ptr and b_ptr are pointers to fp16 tensors)

a_offsets = tl.arange(0, 32)[:, None] * 16 + tl.arange(0, 16)[None, :]
a = tl.load(a_ptr + a_offsets)
b_offsets = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
b = tl.load(b_ptr + b_offsets)
ab = tl.dot(a, b)

and ldmatrix for b goes OOB to fetch the data from shmem
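For reference, the intended semantics of the block can be checked on the host with NumPy (shapes and dtypes match the reproducer; this is only a semantic sketch and does not exercise the shared-memory/ldmatrix path where the bug lives):

```python
import numpy as np

# Host-side reference for the reproducer: a is 32x16 fp16, b is 16x16 fp16.
rng = np.random.default_rng(0)
a = rng.standard_normal((32, 16)).astype(np.float16)
b = rng.standard_normal((16, 16)).astype(np.float16)

# tl.dot accumulates in fp32, so the reference matmul is done in fp32 too.
ab = a.astype(np.float32) @ b.astype(np.float32)
print(ab.shape)  # (32, 16)
```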

@zahimoud
Contributor


@alexander-zinoviev thanks. I will take a look after I finish the current task.

@zahimoud
Contributor

@alexander-zinoviev can you paste the full ttgir?

@alexander-zinoviev
Contributor Author

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 1, warpsPerCTA = [2, 2]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @multi_head_attention_kernel_0d1d2d3d4d5d6d7d8d9d10d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<32x16xf32, #mma>
%cst_0 = arith.constant dense<16> : tensor<32x1xi32, #blocked>
%cst_1 = arith.constant dense<16> : tensor<16x1xi32, #blocked1>
%0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
%2 = arith.muli %1, %cst_0 : tensor<32x1xi32, #blocked>
%3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%5 = tt.expand_dims %3 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x16xi32, #blocked>
%6 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x16xi32, #blocked1>
%7 = tt.broadcast %2 : (tensor<32x1xi32, #blocked>) -> tensor<32x16xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<1x16xi32, #blocked>) -> tensor<32x16xi32, #blocked>
%9 = arith.addi %7, %8 : tensor<32x16xi32, #blocked>
%10 = tt.splat %arg1 : (!tt.ptr) -> tensor<32x16x!tt.ptr, #blocked>
%11 = tt.addptr %10, %9 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked>
%12 = tt.load %11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x16xf16, #blocked>
%13 = triton_gpu.convert_layout %12 : (tensor<32x16xf16, #blocked>) -> tensor<32x16xf16, #shared>
%14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%15 = tt.expand_dims %14 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<16x1xi32, #blocked1>
%16 = arith.muli %15, %cst_1 : tensor<16x1xi32, #blocked1>
%17 = tt.broadcast %16 : (tensor<16x1xi32, #blocked1>) -> tensor<16x16xi32, #blocked1>
%18 = tt.broadcast %6 : (tensor<1x16xi32, #blocked1>) -> tensor<16x16xi32, #blocked1>
%19 = arith.addi %17, %18 : tensor<16x16xi32, #blocked1>
%20 = tt.splat %arg4 : (!tt.ptr) -> tensor<16x16x!tt.ptr, #blocked1>
%21 = tt.addptr %20, %19 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1>
%22 = tt.load %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #blocked1>
%23 = triton_gpu.convert_layout %22 : (tensor<16x16xf16, #blocked1>) -> tensor<16x16xf16, #shared>
%24 = triton_gpu.convert_layout %13 : (tensor<32x16xf16, #shared>) -> tensor<32x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%25 = triton_gpu.convert_layout %23 : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%26 = tt.dot %24, %25, %cst {allowTF32 = true} : tensor<32x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x16xf32, #mma>
%27 = triton_gpu.convert_layout %26 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked>
%28 = arith.truncf %27 : tensor<32x16xf32, #blocked> to tensor<32x16xf16, #blocked>
%29 = tt.splat %arg10 : (!tt.ptr) -> tensor<32x16x!tt.ptr, #blocked>
%30 = tt.addptr %29, %9 : tensor<32x16x!tt.ptr, #blocked>, tensor<32x16xi32, #blocked>
tt.store %30, %28 {cache = 1 : i32, evict = 1 : i32} : tensor<32x16xf16, #blocked>
tt.return
}
}

@zahimoud
Contributor


Is this generated from the top of the main branch?

@alexander-zinoviev
Contributor Author

I think we are currently at fb265d3

@alexander-zinoviev
Contributor Author

The bug does not reproduce at the OSS head.
It does reproduce at OSS commit 5df9042, which is where our current version was integrated from.

I am trying to bisect to find when it was fixed.

@alexander-zinoviev
Contributor Author

Confirmed: the bug was fixed by @qliu93 on 8/15 with #2109, so no change is required this time.

I'm closing the PR.

@alexander-zinoviev alexander-zinoviev deleted the shmem-oob-access-fix branch September 21, 2023 19:03