diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 3488a686134e..1e6e1c1fd717 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -23,9 +23,6 @@ void lowerDistributedToShared( auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); - assert(srcTy.getShape().size() <= 2 || - (srcTy.getShape().size() == 3 && outOrd[2] == 0) && - "Unexpected rank of ConvertLayout(blocked->shared)"); auto elemTy = typeConverter->convertType(srcTy.getElementType()); auto smemBase = smemObj.getBase(); @@ -163,7 +160,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; // To be removed in https://github.com/triton-lang/triton/pull/5154 bool legacyLoweringIsBuggy = - (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere(); + (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || + dstTy.getRank() == 3) && + mma.isAmpere(); return (mma.isHopper() && !canUseLdmatrix) || (mma.isAmpere() && legacyLoweringIsBuggy); } @@ -220,7 +219,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); - assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) && + assert((!isa(dstTy.getEncoding()) || + isSupportedDotOpLayout(srcTy, dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 21b43ac02afb..91be7d217d84 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5383,6 +5383,94 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t assert torch.equal(z, x) +layouts_3d = [ + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0, + k_width=1), +] + +shared_layout_3d = [ + SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), +] + + +@pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) +@pytest.mark.parametrize("shared_layout", shared_layout_3d) +@pytest.mark.parametrize("dist_layout", layouts_3d) +def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist> + %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist> + %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist> + %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist> + %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist> + %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + mma_pairs = [ [ MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),