Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,7 @@ struct StoreOpConversion
vec = std::min(vec, maskAlign);
}

// numElements = 1 for scalar
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

Value mask = getMask(valueTy, rewriter, loc);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;
Expand Down
40 changes: 40 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,46 @@ class ConvertTritonGPUOpToLLVMPatternBase {
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this function also be used for atomicCAS and atomicRMW?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I just wanted to test that separately

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought maybe for LoadOp as well, currently we don't mask

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought maybe for LoadOp as well, currently we don't mask

Why should we apply it for load?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cache line gets evicted between two threads loading replicated data, that could be a perf hit

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, then how do you broadcast the data to multiple threads? e.g., in the case of a scalar pointer.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, let me think about that some more

Location loc) const {
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Value mask = int_val(1, 1);
auto tid = tid_val();
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTA[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
mask = and_(mask, icmp_slt(tid, i32_val(1)));
}
return mask;
}

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
Expand Down
47 changes: 47 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,53 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]


@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
tt.store %8, %4 : tensor<{M}x1xf32, #src>
tt.return
}}
}}
"""

import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
store_kernel = triton.compile(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, 1)).astype('float32')
y = np.zeros((M, 1), dtype='float32')
x_tri = torch.tensor(x, device=device)
y_tri = torch.tensor(y, device=device)

pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x

np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
delta = mean_2 - mean_1
Expand Down
1 change: 0 additions & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
// CHECK: llvm.inline_asm
Expand Down