From 2e0a91783075bd443283b60c6e039870476bc714 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 29 Nov 2024 00:08:15 -0800 Subject: [PATCH] Add tests for 3D local_load local_alloc and relax asserts Also switch 3D dot_operand cases to use linear layout path, This may be suboptimal in some cases but that solves the functionality problems which is more important. There is ongoing work from Mario that should get the code quality to be good again soon. --- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 10 +-- python/test/unit/language/test_core.py | 88 +++++++++++++++++++ 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 3488a686134e..1e6e1c1fd717 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -23,9 +23,6 @@ void lowerDistributedToShared( auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); - assert(srcTy.getShape().size() <= 2 || - (srcTy.getShape().size() == 3 && outOrd[2] == 0) && - "Unexpected rank of ConvertLayout(blocked->shared)"); auto elemTy = typeConverter->convertType(srcTy.getElementType()); auto smemBase = smemObj.getBase(); @@ -163,7 +160,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; // To be removed in https://github.com/triton-lang/triton/pull/5154 bool legacyLoweringIsBuggy = - (kWidth >= 8 || (kWidth == 4 && bitwidth == 32)) && mma.isAmpere(); + (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || + dstTy.getRank() == 3) && + mma.isAmpere(); return (mma.isHopper() && !canUseLdmatrix) || (mma.isAmpere() && legacyLoweringIsBuggy); } @@ -220,7 +219,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); - assert((dstShape.size() <= 2 || isSupportedDotOpLayout(srcTy, dstTy)) && + assert((!isa(dstTy.getEncoding()) || + isSupportedDotOpLayout(srcTy, dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 21b43ac02afb..91be7d217d84 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5383,6 +5383,94 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t assert torch.equal(z, x) +layouts_3d = [ + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]), op_idx=0, + k_width=1), +] + +shared_layout_3d = [ + SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), +] + + +@pytest.mark.parametrize("M, N, K", [[8, 16, 32]]) +@pytest.mark.parametrize("shared_layout", shared_layout_3d) +@pytest.mark.parametrize("dist_layout", layouts_3d) +def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %cst_0 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_1 = arith.constant dense<{K*N}> : tensor<{M}x1x1xi32, #dist> + %cst_2 = arith.constant dense<{K}> : tensor<1x{N}x1xi32, #dist> + %0 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %1 = tt.expand_dims %0 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %4 = tt.addptr %3, %2 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %7 = tt.expand_dims %6 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %8 = arith.muli %7, %cst_2 : tensor<1x{N}x1xi32, #dist> + %9 = tt.broadcast %4 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %10 = tt.broadcast %8 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %11 = tt.addptr %9, %10 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %12 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %13 = tt.expand_dims %12 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %14 = tt.expand_dims %13 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %15 = arith.muli %14, %cst_1 : tensor<{M}x1x1xi32, #dist> + %16 = tt.broadcast %11 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist> + %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> + %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> + %25 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x{K}x!tt.ptr, #dist> + %26 = tt.addptr %25, %24 : tensor<1x1x{K}x!tt.ptr, #dist>, tensor<1x1x{K}xi32, #dist> + %27 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %28 = tt.expand_dims %27 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %29 = tt.expand_dims %28 {{axis = 2 : i32}} : tensor<1x{N}xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<1x{N}x1xi32, #dist> + %30 = arith.muli %29, %cst : tensor<1x{N}x1xi32, #dist> + %31 = tt.broadcast %26 : tensor<1x1x{K}x!tt.ptr, #dist> -> tensor<1x{N}x{K}x!tt.ptr, #dist> + %32 = tt.broadcast %30 : tensor<1x{N}x1xi32, #dist> -> tensor<1x{N}x{K}xi32, #dist> + %33 = tt.addptr %31, %32 : tensor<1x{N}x{K}x!tt.ptr, #dist>, tensor<1x{N}x{K}xi32, #dist> + %34 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> + %35 = tt.expand_dims %34 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #ttg.slice<{{dim = 2, parent = #dist}}>}}>> -> tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> + %36 = tt.expand_dims %35 {{axis = 2 : i32}} : tensor<{M}x1xi32, #ttg.slice<{{dim = 2, parent = #dist}}>> -> tensor<{M}x1x1xi32, #dist> + %37 = arith.muli %36, %cst_0 : tensor<{M}x1x1xi32, #dist> + %38 = tt.broadcast %33 : tensor<1x{N}x{K}x!tt.ptr, #dist> -> tensor<{M}x{N}x{K}x!tt.ptr, #dist> + %39 = tt.broadcast %37 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> + %40 = tt.addptr %38, %39 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> + tt.store %40, %21 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> + tt.return + }} +}} +""" + + x = torch.arange(0, M * N * K, device=device, dtype=torch.int32).reshape(M, N, K) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + mma_pairs = [ [ MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),