diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index 6a0310c84ede..fcb4ee128e87 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -788,7 +788,9 @@ void fillTDMDescriptor(RewriterBase &rewriter, Location loc, // Update tensor shapes based on offset for (size_t i = 0; i < numDims; ++i) { - tensorShape[i] = b.smax(b.i32_val(0), b.sub(tensorShape[i], offset[i])); + auto diff = b.sub(tensorShape[i], offset[i]); + Value inBounds = b.icmp_ule(diff, tensorShape[i]); + tensorShape[i] = b.select(inBounds, diff, b.i32_val(0)); } // TDM store does not support padding in general. However, if the padding diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index 81d6b8d089af..54a799ac6ac8 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -2944,6 +2944,35 @@ def kernel(a_ptr, b_ptr): assert torch.equal(a[:, 32:], b[:, 32:]) and not torch.equal(a[:, :32], b[:, :32]) +# Check that negative TDM offsets are treated as unsigned so they will mask (zero-fill) out the whole tile. 
+@pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
+def test_tdm_load_negative_offset():
+
+    @gluon.jit
+    def tdm_load_negative_offset_kernel(a_ptr, b_ptr):
+        shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])
+        reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0])
+        # block_shape matches the full (16, 64) tensor shape described by the descriptor.
+        desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=a_ptr, shape=(16, 64), strides=(64, 1),
+                                                           block_shape=(16, 64), layout=shared_layout)
+        smem = ttgl.allocate_shared_memory(desc.dtype, shape=desc.block_shape, layout=desc.layout)
+        # Issue the load at offset [-1, -1]; async_wait(0) presumably waits for it to complete before smem is read.
+        ttgl.amd.gfx1250.tdm.async_load(desc, [-1, -1], smem)
+        ttgl.amd.gfx1250.tdm.async_wait(0)
+        # Read the tile back from smem and store it row-major (row stride 64) at b_ptr.
+        b_offs_m = ttgl.arange(0, 16, layout=ttgl.SliceLayout(1, reg_layout))
+        b_offs_n = ttgl.arange(0, 64, layout=ttgl.SliceLayout(0, reg_layout))
+        b_ptrs = b_ptr + b_offs_m[:, None] * 64 + b_offs_n[None, :]
+        tile = smem.load(reg_layout)
+        ttgl.store(b_ptrs, tile)
+    # a (input) and b_device (output) are both filled with random values before the launch.
+    a = torch.randint(0x0, 0xFFFF, (16, 64), dtype=torch.uint16)
+    b_device = torch.randint(0x0, 0xFFFF, (16, 64), dtype=torch.uint16).cuda()
+    tdm_load_negative_offset_kernel[(1, )](a.cuda(), b_device)
+    # Every element the kernel wrote back to b_device must be zero.
+    assert torch.all(b_device.cpu() == 0)
+
+
 @pytest.mark.skipif(not is_hip_gfx1250(), reason="Requires GFX1250")
 @pytest.mark.parametrize("XBLOCK", [128])
 def test_ws_store_wait_load(XBLOCK):