BMM into BMM triggers internal assert #5211

@falkaer

Description

Describe the bug

Performing two back-to-back bmm calls on an NVIDIA GPU triggers an internal assert. On Triton 3.1.0:

python3: /project/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:84: mlir::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `dstShape.size() <= 2 && "Unexpected rank of ConvertLayout(shared->blocked)"' failed.

and on main (d5ba6ac):

python3: /project/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:210: llvm::LogicalResult {anonymous}::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, const mlir::LLVMTypeConverter*, mlir::ConversionPatternRewriter&) const: Assertion `(dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"' failed.
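
For reference, the pattern in isolation is roughly the sketch below: two chained 3D tl.dot calls, i.e. a bmm feeding directly into another bmm. The kernel name and shapes here are just an illustration and I have not verified that this stripped-down version hits the exact same assert, but it is the pattern my full kernel further down reduces to.

import torch
import triton
import triton.language as tl


@triton.jit
def _bmm_into_bmm_kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    # Flat offsets into contiguous (BLOCK, BLOCK, BLOCK) tensors
    offs = tl.arange(0, BLOCK)
    idx = (
        offs[:, None, None] * BLOCK * BLOCK
        + offs[None, :, None] * BLOCK
        + offs[None, None, :]
    )
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    c = tl.load(c_ptr + idx)
    ab = tl.dot(a, b)  # first batched matmul
    out = tl.dot(ab, c)  # second batched matmul, fed by the first
    tl.store(out_ptr + idx, out)


# a, b, c = (torch.randn(16, 16, 16, device="cuda") for _ in range(3))
# out = torch.empty_like(a)
# _bmm_into_bmm_kernel[(1,)](a, b, c, out, BLOCK=16)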

Below is my WIP kernel that triggers the issue. The kernel is supposed to perform two batched matrix multiplications plus some reshapes to compute the forward pass of a block tensor train, as in the reference einsum. The code gives the correct result when run under the interpreter. If this is a known issue with a workaround, I would appreciate some help :)

import os

import torch
import triton
import triton.language as tl
from einops import einsum, rearrange


def btt_fwd_ref(x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
    g, _, _, _, j = W1.shape
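    # Split the flat (g i j) feature dim, contract i, j and the rank index r
    # against W1 and W2, then flatten the (g k o) result back into one dim.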
    y = rearrange(x, "... (g i j) -> ... g i j", g=g, j=j)
    out = einsum(y, W1, W2, "... g i j, g k r i j, g k o r i -> ... g k o")
    return rearrange(out, "... g k o -> ... (g k o)")


@triton.jit
def _btt_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    w1_ptr,
    w2_ptr,
    out_ptr,
    # Matrix dimensions
    B,
    G,
    K,
    I,
    J,
    O,
    # Strides
    # x
    stride_xb,
    stride_xg,
    stride_xi,
    stride_xj,
    # W1
    stride_w1g,
    stride_w1k,
    stride_w1r,
    stride_w1i,
    stride_w1j,
    # W2
    stride_w2g,
    stride_w2k,
    stride_w2o,
    stride_w2r,
    stride_w2i,
    # out
    stride_ob,
    stride_og,
    stride_ok,
    stride_oo,
    # Meta-parameters
    RANK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_J: tl.constexpr,
):
    # Program ID
    pid1 = tl.program_id(0)
    pid2 = tl.program_id(1)

    # Compute batch and group indices
    bb = pid1 // G * BLOCK_B
    g = pid1 % G

    o_blocks = tl.cdiv(O, BLOCK_K)
    bk = pid2 // o_blocks * BLOCK_K
    bo = pid2 % o_blocks * BLOCK_K
    acc_out = tl.zeros((BLOCK_K, BLOCK_B, BLOCK_K), dtype=tl.float32)  # g: k b o

    for r in tl.range(0, RANK):
        indsb = tl.arange(0, BLOCK_B)
        indsj = tl.arange(0, BLOCK_J)
        indsk = tl.arange(0, BLOCK_K)
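        # Tile over I; for each I tile, accumulate x @ W1 over J tiles first,
        # then contract the partial result with W2 into acc_out.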
        for bi in tl.range(0, I, BLOCK_J):
            # First matmul: x × W1
            acc_inner = tl.zeros(
                (BLOCK_J, BLOCK_B, BLOCK_K), dtype=tl.float32
            )  # g r: i b k
            for bj in tl.range(0, J, BLOCK_J):
                # Load block of x
                x_block = tl.load(
                    x_ptr
                    + g * stride_xg
                    + (bb + indsb[:, None, None]) * stride_xb
                    + (bi + indsj[None, :, None]) * stride_xi
                    + (bj + indsj[None, None, :]) * stride_xj,
                    mask=(
                        bb + indsb[:, None, None] < B and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g: b i j
                x_block = tl.trans(x_block, 1, 0, 2)  # g: b i j -> i b j

                # Load block of W1 (k j)
                w1_block = tl.load(
                    w1_ptr
                    + g * stride_w1g
                    + r * stride_w1r
                    + (bk + indsk[:, None, None]) * stride_w1k
                    + (bi + indsj[None, :, None]) * stride_w1i
                    + (bj + indsj[None, None, :]) * stride_w1j,
                    mask=(
                        bk + indsk[:, None, None] < K and bi + indsj[None, :, None] < I
                    )
                    and bj + indsj[None, None, :] < J,
                    other=0,
                )  # g r: k i j
                w1_block = tl.trans(w1_block, 1, 2, 0)  # g r: k i j -> i j k

                # Accumulate
                acc_inner += tl.dot(
                    x_block, w1_block
                )  # g r: (i b j) x (i j k) -> i b k

            acc_inner = tl.trans(acc_inner, 2, 1, 0)  # g r: i b k -> k b i

            # Second matmul: acc × W2
            w2_block = tl.load(
                w2_ptr
                + g * stride_w2g
                + r * stride_w2r
                + (bk + indsk[:, None, None]) * stride_w2k
                + (bo + indsk[None, :, None]) * stride_w2o
                + (bi + indsj[None, None, :]) * stride_w2i,
                mask=(bk + indsk[:, None, None] < K and bo + indsk[None, :, None] < O)
                and bi + indsj[None, None, :] < I,
                other=0,
            )  # g r: k o i
            w2_block = tl.trans(w2_block, 0, 2, 1)  # g r: k o i -> k i o

            # Accumulate
            acc_out += tl.dot(acc_inner, w2_block)  # g r: (k b i) x (k i o) -> k b o

    indsb = tl.arange(0, BLOCK_B)
    indsk = tl.arange(0, BLOCK_K)

    # Store result
    tl.store(
        out_ptr
        + g * stride_og
        + (bb + indsb[:, None, None]) * stride_ob
        + (bk + indsk)[None, :, None] * stride_ok
        + (bo + indsk)[None, None, :] * stride_oo,
        tl.trans(acc_out, 1, 0, 2),  # g: k b o -> b k o
        mask=(bb + indsb[:, None, None] < B and bk + indsk[None, :, None] < K)
        and (bo + indsk)[None, None, :] < O,
    )


def _btt_fwd_triton(
    x: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor
) -> torch.Tensor:
    G, K, R, I, J = W1.shape
    _, _, O, _, _ = W2.shape
    out = x.new_empty(x.shape[:-1] + (G * K * O,))

    # ... (g i j) -> (...) g i j
    x_view = x.view(-1, G, I, J)
    out_view = out.view(-1, G, K, O)
    B = x.shape[0]

    BLOCK_B, BLOCK_K, BLOCK_J = 16, 16, 16
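    # Grid: axis 0 enumerates (batch tile, group) pairs, axis 1 enumerates
    # (K tile, O tile) pairs; the kernel decodes both with // and %.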
    grid = (
        triton.cdiv(B, BLOCK_B) * G,
        triton.cdiv(K, BLOCK_K) * triton.cdiv(O, BLOCK_K),
    )
    _btt_fwd_kernel[grid](
        x_view,
        W1,
        W2,
        out_view,
        B,
        G,
        K,
        I,
        J,
        O,
        *x_view.stride(),
        *W1.stride(),
        *W2.stride(),
        *out_view.stride(),
        BLOCK_B=BLOCK_B,
        BLOCK_K=BLOCK_K,
        BLOCK_J=BLOCK_J,
        RANK=R,
    )
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    triton_interpreting = os.environ.get("TRITON_INTERPRET", "0") == "1"
    print("Interpreting:", triton_interpreting, triton.__version__)
    device = "cpu" if triton_interpreting else "cuda"
    G, K, R, I, J, O = 2, 16, 1, 16, 16, 16

    bsize, in_features, out_features = 128, 256, 256
    x, W1, W2 = (
        torch.randn(bsize, G * in_features, device=device),
        torch.randn(G, K, R, I, J, device=device),
        torch.randn(G, K, O, R, I, device=device),
    )

    y = _btt_fwd_triton(x, W1, W2)
    y_ref = btt_fwd_ref(x, W1, W2)
    torch.testing.assert_close(y, y_ref)
    print("Succeeded")

Environment details

Triton: Tested on 3.1.0 and main (d5ba6ac)
GPU: A100 and 4070 Ti
