Transposed read/writes #176

Closed · falkaer opened this issue Aug 2, 2021 · 2 comments
falkaer commented Aug 2, 2021

Hi,

Is there a way to perform efficient transposed reads/writes in Triton? As an exercise I wanted to write a transpose kernel in Triton, and taking inspiration from the matmul example in the docs I wrote the following kernel, which naively transposes the input by swapping the strides:

import triton
import triton.language as tl

@triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
            # ...
        ],
        key=['M', 'N']
)
@triton.jit
def _transpose_triton(A, B, stride_am, stride_an, stride_bn, stride_bm, M, N, **META):
    pid = tl.program_id(0)
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    GROUP_M = META['GROUP_M']
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    
    # swizzle program ids so blocks sweep GROUP_M rows of tiles at a time
    # (the grouped ordering from the matmul tutorial, for better L2 reuse)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    a = tl.load(A, mask=mask)
    
    # rematerialize to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(B, a, mask=mask)

def transpose_triton(input, out=None):
    M, N = input.shape
    if out is None:
        out = input.new_zeros(N, M)
    
    assert out.size(0) == N and out.size(1) == M
    assert input.stride(0) == 1 or input.stride(1) == 1
    assert out.stride(0) == 1 or out.stride(1) == 1
    
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    _transpose_triton[grid](input, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
    return out
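
As a quick sanity check (my addition, not part of the original post), the kernel can be compared against PyTorch's own transpose; this assumes transpose_triton is defined exactly as above:

import torch

X = torch.rand(1024, 2048, device='cuda')
Y = transpose_triton(X)
# the kernel should compute Y[n, m] == X[m, n] for every element
assert Y.shape == (2048, 1024)
assert torch.equal(Y, X.t())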

However, this performs suboptimally (likely because the stores become uncoalesced) unless the destination is itself already a lazily transposed view, in which case the writes become coalesced and the kernel approaches the speed of a plain copy.

...
>>> X = torch.rand(20000, 30000, device='cuda')
>>> out = X.new_empty(30000, 20000)
>>> %timeit transpose_triton(X, out=out); torch.cuda.synchronize()
51.6 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit X.transpose(0, 1).contiguous(); torch.cuda.synchronize()
56.2 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> out = X.new_empty(20000, 30000).T
>>> %timeit transpose_triton(X, out=out); torch.cuda.synchronize()
16 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit X.clone(); torch.cuda.synchronize()
13.6 ms ± 16.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
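
To put these timings in perspective, here is a rough effective-bandwidth estimate (my own addition, assuming float32 elements and that each element is read and written exactly once):

# 20000 x 30000 float32 elements, each read once and written once
bytes_moved = 2 * 20000 * 30000 * 4        # ~4.8 GB per transpose
print(bytes_moved / 51.6e-3 / 1e9)         # ~93 GB/s, uncoalesced-store path
print(bytes_moved / 16e-3 / 1e9)           # ~300 GB/s, lazily transposed destination
print(bytes_moved / 13.6e-3 / 1e9)         # ~353 GB/s, plain clone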

It would be very useful to have a generic way to express such transposed operations that avoids uncoalesced reads/writes to DRAM (similar to how in CUDA one can stage tiles through shared memory to transpose efficiently). If there is already a way to achieve this in Triton I would appreciate some hints; if not, perhaps a tl.transpose function could be introduced?

Best regards,
Kenny
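
As an aside (my addition, not something from the thread): newer Triton releases expose tl.trans, which transposes a block in registers. A minimal sketch of a tiled transpose kernel written against the newer tl.constexpr launch API, rather than the META-style API above, might look like this; whether the generated loads and stores end up coalesced is still decided by the compiler's layout passes.

import triton
import triton.language as tl

@triton.jit
def _transpose_tile(A, B, M, N,
                    stride_am, stride_an, stride_bn, stride_bm,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # load a (BLOCK_M, BLOCK_N) tile of A
    a = tl.load(A + rm[:, None] * stride_am + rn[None, :] * stride_an,
                mask=(rm[:, None] < M) & (rn[None, :] < N))
    # transpose the tile in registers and store it as a (BLOCK_N, BLOCK_M) tile of B
    tl.store(B + rn[:, None] * stride_bn + rm[None, :] * stride_bm, tl.trans(a),
             mask=(rn[:, None] < N) & (rm[None, :] < M))

# launched over a 2D grid, e.g.:
# grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
# _transpose_tile[grid](A, B, M, N, A.stride(0), A.stride(1),
#                       B.stride(0), B.stride(1), BLOCK_M=64, BLOCK_N=64)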

ptillet (Collaborator) commented Aug 2, 2021

This seems like a bug. Triton is supposed to order threads so that accesses stay coalesced for both reads and writes, and to figure out automatically that it should go through shared memory. It used to work properly. Thanks for reporting, I'll look into it.

ptillet (Collaborator) commented Aug 30, 2021

Hey! Sorry for the delay. I've just merged a bunch of fixes into master and this takes care of the issue :) I've also added a unit test that explicitly checks that transpositions use coalesced memory accesses for both reads and writes: https://github.com/openai/triton/blob/master/python/test/test_language.py#L403-L430. Please feel free to re-open the issue if you run into more problems with transpositions in Triton!
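
For context on what such a check can look like (this is my own sketch, not the linked test): one can compile a small copy kernel whose destination is a transposed, column-major tensor and then look for vectorized global loads and stores in the generated PTX. The way to get at the PTX differs between Triton versions; here I assume the launch handle exposes it as .asm['ptx'].

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_2d(X, Z, stride_xm, stride_xn, stride_zm, stride_zn,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    rm = tl.arange(0, BLOCK_M)
    rn = tl.arange(0, BLOCK_N)
    x = tl.load(X + rm[:, None] * stride_xm + rn[None, :] * stride_xn)
    tl.store(Z + rm[:, None] * stride_zm + rn[None, :] * stride_zn, x)

def check_transpose_is_coalesced(M=64, N=64):
    x = torch.randn(M, N, device='cuda')          # row-major source
    z = torch.empty(N, M, device='cuda').t()      # column-major destination, shape (M, N)
    pgm = _copy_2d[(1,)](x, z,
                         x.stride(0), x.stride(1), z.stride(0), z.stride(1),
                         BLOCK_M=M, BLOCK_N=N)
    ptx = pgm.asm['ptx']                          # assumed accessor for the generated PTX
    # 128-bit vectorized accesses are a reasonable proxy for coalesced reads/writes
    assert 'ld.global.v4' in ptx
    assert 'st.global.v4' in ptx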

ptillet closed this as completed Aug 30, 2021