Describe the bug
Performing two back-to-back bmm calls on an NVIDIA GPU triggers an internal assert. On Triton 3.1.0:
python3: /project/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:84: mlir::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `dstShape.size() <= 2 && "Unexpected rank of ConvertLayout(shared->blocked)"' failed.
and on main (d5ba6ac):
python3: /project/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:210: llvm::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `(dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"' failed.
Below is my WIP kernel that triggers the issue. The kernel is supposed to perform two batched matrix multiplies plus some reshapes to compute the forward pass of a block tensor train, as in the reference einsum. The code gives the correct result when run with the interpreter. If this is a known issue with a workaround, I would appreciate some help :)
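For clarity, the contraction that btt_fwd_ref below computes can be written out as follows (my own transcription of the einsum string, indices named as in the code):

$$\mathrm{out}_{\ldots,g,k,o} \;=\; \sum_{r}\sum_{i}\sum_{j} x_{\ldots,g,i,j}\; W1_{g,k,r,i,j}\; W2_{g,k,o,r,i}$$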
import os

import torch
import triton
import triton.language as tl
from einops import einsum, rearrange


def btt_fwd_ref(x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
    g, _, _, _, j = W1.shape
    y = rearrange(x, "... (g i j) -> ... g i j", g=g, j=j)
    out = einsum(y, W1, W2, "... g i j, g k r i j, g k o r i -> ... g k o")
    return rearrange(out, "... g k o -> ... (g k o)")

@triton.jit
def _btt_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    w1_ptr,
    w2_ptr,
    out_ptr,
    # Matrix dimensions
    B,
    G,
    K,
    I,
    J,
    O,
    # Strides
    # x
    stride_xb,
    stride_xg,
    stride_xi,
    stride_xj,
    # W1
    stride_w1g,
    stride_w1k,
    stride_w1r,
    stride_w1i,
    stride_w1j,
    # W2
    stride_w2g,
    stride_w2k,
    stride_w2o,
    stride_w2r,
    stride_w2i,
    # out
    stride_ob,
    stride_og,
    stride_ok,
    stride_oo,
    # Meta-parameters
    RANK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_J: tl.constexpr,
):
    # Program ID
    pid1 = tl.program_id(0)
    pid2 = tl.program_id(1)
    # Compute batch and group indices
    bb = pid1 // G * BLOCK_B
    g = pid1 % G
    o_blocks = tl.cdiv(O, BLOCK_K)
    bk = pid2 // o_blocks * BLOCK_K
    bo = pid2 % o_blocks * BLOCK_K
    acc_out = tl.zeros((BLOCK_K, BLOCK_B, BLOCK_K), dtype=tl.float32)  # g: k b o
    for r in tl.range(0, RANK):
        indsb = tl.arange(0, BLOCK_B)
        indsj = tl.arange(0, BLOCK_J)
        indsk = tl.arange(0, BLOCK_K)
        for bi in tl.range(0, I, BLOCK_J):
            # First matmul: x × W1
            acc_inner = tl.zeros(
                (BLOCK_J, BLOCK_B, BLOCK_K), dtype=tl.float32
            )  # g r: i b k
            for bj in tl.range(0, J, BLOCK_J):
                # Load block of x
                x_block = tl.load(
                    x_ptr
                    + g * stride_xg
                    + (bb + indsb[:, None, None]) * stride_xb
                    + (bi + indsj[None, :, None]) * stride_xi
                    + (bj + indsj[None, None, :]) * stride_xj,
                    mask=(
                        bb + indsb[:, None, None] < B and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g: i b j
                x_block = tl.trans(x_block, 1, 0, 2)  # g: b i j -> i b j
                # Load block of W1 (k j)
                w1_block = tl.load(
                    w1_ptr
                    + g * stride_w1g
                    + r * stride_w1r
                    + (bk + indsk[:, None, None]) * stride_w1k
                    + (bi + indsj[None, :, None]) * stride_w1i
                    + (bj + indsj[None, None, :]) * stride_w1j,
                    mask=(
                        bk + indsk[:, None, None] < K and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g r: k i j
                w1_block = tl.trans(w1_block, 1, 2, 0)  # g r: k i j -> i j k
                # Accumulate
                acc_inner += tl.dot(
                    x_block, w1_block
                )  # g r: (i b j) x (i j k) -> i b k
            acc_inner = tl.trans(acc_inner, 2, 1, 0)  # g r: i b k -> k b i
            # Second matmul: acc × W2
            w2_block = tl.load(
                w2_ptr
                + g * stride_w2g
                + r * stride_w2r
                + (bk + indsk[:, None, None]) * stride_w2k
                + (bo + indsk[None, :, None]) * stride_w2o
                + (bi + indsj[None, None, :]) * stride_w2i,
                mask=(bk + indsk[:, None, None] < K and bo + indsk[None, :, None] < O)
                and bi + indsj[None, None, :] < I,
                other=0,
            )  # g r: k o i
            w2_block = tl.trans(w2_block, 0, 2, 1)  # g r: k o i -> k i o
            # Accumulate
            acc_out += tl.dot(acc_inner, w2_block)  # g r: (k b i) x (k i o) -> k b o
    indsb = tl.arange(0, BLOCK_B)
    indsk = tl.arange(0, BLOCK_K)
    # Store result
    tl.store(
        out_ptr
        + g * stride_og
        + (bb + indsb[:, None, None]) * stride_ob
        + (bk + indsk)[None, :, None] * stride_ok
        + (bo + indsk)[None, None, :] * stride_oo,
        tl.trans(acc_out, 1, 0, 2),  # g: k b o -> b k o
        mask=(bb + indsb[:, None, None] < B and bk + indsk[None, :, None] < K)
        and (bo + indsk)[None, None, :] < O,
    )

def _btt_fwd_triton(
    x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor
) -> torch.Tensor:
    G, K, R, I, J = W1.shape
    _, _, O, _, _ = W2.shape
    out = x.new_empty(x.shape[:-1] + (G * K * O,))
    # ... (g i j) -> (...) g i j
    x_view = x.view(-1, G, I, J)
    out_view = out.view(-1, G, K, O)
    B = x.shape[0]
    BLOCK_B, BLOCK_K, BLOCK_J = 16, 16, 16
    grid = (
        triton.cdiv(B, BLOCK_B) * G,
        triton.cdiv(K, BLOCK_K) * triton.cdiv(O, BLOCK_K),
    )
    _btt_fwd_kernel[grid](
        x_view,
        W1,
        W2,
        out_view,
        B,
        G,
        K,
        I,
        J,
        O,
        *x_view.stride(),
        *W1.stride(),
        *W2.stride(),
        *out_view.stride(),
        BLOCK_B=BLOCK_B,
        BLOCK_K=BLOCK_K,
        BLOCK_J=BLOCK_J,
        RANK=R,
    )
    return out

if __name__ == "__main__":
    torch.manual_seed(0)
    triton_interpreting = os.environ.get("TRITON_INTERPRET", "0") == "1"
    print("Interpreting:", triton_interpreting, triton.__version__)
    device = "cpu" if triton_interpreting else "cuda"
    G, K, R, I, J, O = 2, 16, 1, 16, 16, 16
    bsize, in_features, out_features = 128, 256, 256
    x, W1, W2 = (
        torch.randn(bsize, G * in_features, device=device),
        torch.randn(G, K, R, I, J, device=device),
        torch.randn(G, K, O, R, I, device=device),
    )
    y = _btt_fwd_triton(x, W1, W2)
    y_ref = btt_fwd_ref(x, W1, W2)
    torch.testing.assert_close(y, y_ref)
    print("Succeeded")

Environment details
Triton: Tested on 3.1.0 and main (d5ba6ac)
GPU: A100 and 4070 Ti
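In case it helps triage, below is a stripped-down sketch of what I believe is the triggering pattern: two chained rank-3 (batched) tl.dot calls inside a single kernel, so the output of the first dot has to be converted and fed into the second. I have not verified that this exact kernel reproduces the assert, so treat the tile size and launch setup as placeholders rather than a confirmed minimal reproducer.

import torch
import triton
import triton.language as tl


@triton.jit
def _chained_bmm_kernel(x_ptr, a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Flat offsets for one contiguous (BLOCK, BLOCK, BLOCK) tile: (batch, row, col)
    idx = (
        offs[:, None, None] * BLOCK * BLOCK
        + offs[None, :, None] * BLOCK
        + offs[None, None, :]
    )
    x = tl.load(x_ptr + idx)
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    # Two back-to-back batched matmuls; the second consumes the first's result
    y = tl.dot(x, a)
    z = tl.dot(y, b)
    tl.store(out_ptr + idx, z)


if __name__ == "__main__":
    BLOCK = 16
    x, a, b = (torch.randn(BLOCK, BLOCK, BLOCK, device="cuda") for _ in range(3))
    out = torch.empty_like(x)
    _chained_bmm_kernel[(1,)](x, a, b, out, BLOCK=BLOCK)

If a single rank-3 tl.dot compiles but the chained version does not, that would be consistent with the assert firing on the shared->distributed conversion of the first dot's result.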