diff --git a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir index b21079b6923e..eda79c211117 100644 --- a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir @@ -195,11 +195,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tdm_2d_with_padding tt.func public @tdm_2d_with_padding( - %tensorDesc: !tt.tensordesc<128x64xf16>, + %tensorDesc: !tt.tensordesc<128x64xf16, #shared>, %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<128x64xf16> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<128x64xf16, #shared> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () tt.return } @@ -212,11 +212,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tdm_5d_with_padding tt.func public @tdm_5d_with_padding( - %tensorDesc: !tt.tensordesc<8x8x8x16x16xf16>, + %tensorDesc: !tt.tensordesc<8x8x8x16x16xf16, #shared_5d>, %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> ) { %c0_i32 = arith.constant 0 : i32 - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> -> !tt.tensordesc<8x8x8x16x16xf16> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<8x8x8x16x16xf16, #shared_5d, #smem_5d, mutable> -> !tt.tensordesc<8x8x8x16x16xf16, #shared_5d> // CHECK: "llvm.amdgcn.tensor.store.from.lds"({{.+}}) : (vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>, vector<8xi32>, i32) -> () tt.return } diff --git a/test/TritonGPU/amd/amd-consan.mlir b/test/TritonGPU/amd/amd-consan.mlir index a85b1709a149..45da871eaa71 100644 --- a/test/TritonGPU/amd/amd-consan.mlir +++ b/test/TritonGPU/amd/amd-consan.mlir @@ -464,7 +464,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local - tt.func public @async_tdm_copy_global_to_local(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @async_tdm_copy_global_to_local(%desc: !tt.tensordesc<32x32xf32, #shared>) { // CHECK-DAG: %[[BUFFERS:.*]] = tti.experimental_buffer_descriptors [0], [{{.*}}], shared_mem : tensor<1xi64 // CHECK-DAG: %[[WRITE_VISIBILITY_GLOB:.*]] = ttg.global_scratch_alloc {alignment = 16 : i32, nbytes = 8 : i32, shared_cluster_state, third_party_allocation, tt.divisibility = 16 : i64} : !tt.ptr @@ -496,7 +496,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -510,8 +510,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local_two_bufs_one_barrier tt.func public @async_tdm_copy_global_to_local_two_bufs_one_barrier( - %a: !tt.tensordesc<32x32xf32>, - %b: !tt.tensordesc<32x32xf32>) { + %a: !tt.tensordesc<32x32xf32, #shared>, + %b: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 @@ -533,7 +533,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %0 = amdg.async_tdm_copy_global_to_local %a[%c0_i32, %c0_i32] into %a_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %0 = amdg.async_tdm_copy_global_to_local %a[%c0_i32, %c0_i32] into %a_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // Second TDM copy: same full instrumentation // CHECK: tt.call @__triton_consan_verify_write_visibility @@ -546,7 +546,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_track_visible_writes // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state - %1 = amdg.async_tdm_copy_global_to_local %b[%c0_i32, %c0_i32] into %b_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %b[%c0_i32, %c0_i32] into %b_smem, pred = %pred, barrier = %bar : !tt.tensordesc<32x32xf32, #shared>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %c0_phase = arith.constant 0 : i32 amdg.wait_barrier %bar, %c0_phase : !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -564,7 +564,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_global_to_local_no_barrier - tt.func public @async_tdm_copy_global_to_local_no_barrier(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @async_tdm_copy_global_to_local_no_barrier(%desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -575,7 +575,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses // CHECK-NOT: tt.call @__triton_consan_verify_barrier_arrive - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32, #shared> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> tt.return } } @@ -587,7 +587,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_local_to_global - tt.func public @async_tdm_copy_local_to_global(%desc: !tt.tensordesc<32x32xf32>, %ptr: tensor<128x128x!tt.ptr, #blocked>) { + tt.func public @async_tdm_copy_local_to_global(%desc: !tt.tensordesc<32x32xf32, #shared>, %ptr: tensor<128x128x!tt.ptr, #blocked>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %shmem = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -598,7 +598,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32, #shared> tt.return } } @@ -609,7 +609,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_load_store_no_barrier - tt.func public @async_tdm_load_store_no_barrier(%in_desc: !tt.tensordesc<32x32xf32>, %out_desc: !tt.tensordesc<32x32xf32>) { + tt.func public @async_tdm_load_store_no_barrier(%in_desc: !tt.tensordesc<32x32xf32, #shared>, %out_desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> @@ -617,11 +617,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - %1 = amdg.async_tdm_copy_global_to_local %in_desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %in_desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32, #shared> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // CHECK: tt.call @__triton_consan_check_outstanding_commits_excl_self_noalias // CHECK: tt.call @__triton_consan_stage_access_for_commit // CHECK: tt.call @__triton_consan_commit_accesses - amdg.async_tdm_copy_local_to_global %out_desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> + amdg.async_tdm_copy_local_to_global %out_desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32, #shared> tt.return } } @@ -633,7 +633,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @async_tdm_copy_local_to_global_with_barrier - tt.func public @async_tdm_copy_local_to_global_with_barrier(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @async_tdm_copy_local_to_global_with_barrier(%desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> %bar = ttg.local_alloc {allocation.offset = 65536 : i32} : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> @@ -648,7 +648,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar // CHECK: tt.call @__triton_consan_verify_barrier_arrive // CHECK: tt.call @__triton_consan_update_barrier_state // CHECK-NOT: tt.call @__triton_consan_stage_access_for_commit - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0, barrier = %bar : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !tt.tensordesc<32x32xf32> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0, barrier = %bar : !ttg.memdesc<32x32xf32, #shared, #smem, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !tt.tensordesc<32x32xf32, #shared> tt.return } } @@ -716,11 +716,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_load_no_barrier_wait - tt.func public @tdm_load_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @tdm_load_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32, #shared> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> @@ -735,10 +735,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_store_no_barrier_wait - tt.func public @tdm_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @tdm_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32, #shared> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> @@ -753,12 +753,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shar #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 8 : i32} { // CHECK-LABEL: @tdm_load_store_no_barrier_wait - tt.func public @tdm_load_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32>) { + tt.func public @tdm_load_store_no_barrier_wait(%desc: !tt.tensordesc<32x32xf32, #shared>) { %c0_i32 = arith.constant 0 : i32 %pred = arith.constant 1 : i32 %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32> + %1 = amdg.async_tdm_copy_global_to_local %desc[%c0_i32, %c0_i32] into %0, pred = %pred : !tt.tensordesc<32x32xf32, #shared> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %desc[%c0_i32, %c0_i32] from %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> !tt.tensordesc<32x32xf32, #shared> // CHECK: tt.call @__triton_consan_clear_outstanding_commits_transfer_both amdg.async_tdm_wait {num = 0 : i32} ttg.local_load %0 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #blocked> diff --git a/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir index 601064d5c664..3f44773195c5 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count-without-token.mlir @@ -575,7 +575,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_gather_scatter_multiple_instructions tt.func public @tdm_gather_scatter_multiple_instructions( %memDesc: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, - %tensorDesc: !tt.tensordesc<64x128xf16>, + %tensorDesc: !tt.tensordesc<64x128xf16, #shared>, %row_indices_i32: tensor<64xi32, #ttg.slice<{dim = 0, parent = #idx_i32_parent}>>, %row_indices_i16: tensor<256xi16, #ttg.slice<{dim = 0, parent = #idx_i16_parent}>>, %pred: i32 @@ -583,13 +583,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 // Gather with i32 indices: sizePerThread=16, 4 warps, maxPerInstr=8 => 2 instructions - amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32, #ttg.slice<{dim = 0, parent = #idx_i32_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + amdg.async_tdm_gather %tensorDesc[%row_indices_i32, %c0_i32] to %memDesc, pred = %pred : tensor<64xi32, #ttg.slice<{dim = 0, parent = #idx_i32_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> // Scatter with i32 indices: 2 instructions - amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32, #ttg.slice<{dim = 0, parent = #idx_i32_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + amdg.async_tdm_scatter %tensorDesc[%row_indices_i32, %c0_i32] from %memDesc : tensor<64xi32, #ttg.slice<{dim = 0, parent = #idx_i32_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> // Gather with i16 indices: sizePerThread=64, 4 warps, maxPerInstr=16 => 4 instructions - amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16, #ttg.slice<{dim = 0, parent = #idx_i16_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + amdg.async_tdm_gather %tensorDesc[%row_indices_i16, %c0_i32] to %memDesc, pred = %pred : tensor<256xi16, #ttg.slice<{dim = 0, parent = #idx_i16_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> // Scatter with i16 indices: 4 instructions - amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16, #ttg.slice<{dim = 0, parent = #idx_i16_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + amdg.async_tdm_scatter %tensorDesc[%row_indices_i16, %c0_i32] from %memDesc : tensor<256xi16, #ttg.slice<{dim = 0, parent = #idx_i16_parent}>>, !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> // i32 ops emit 2 instructions each, i16 ops emit 4 each => total 12 instructions // CHECK: amdg.async_tdm_intrinsic_wait {count = 0 @@ -616,15 +616,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_load_store_single_instruction tt.func public @tdm_load_store_single_instruction( %memDesc: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>, - %tensorDesc: !tt.tensordesc<64x128xf16>, + %tensorDesc: !tt.tensordesc<64x128xf16, #shared>, %pred: i32 ) { %c0_i32 = arith.constant 0 : i32 - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16, #shared> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %pred : !tt.tensordesc<64x128xf16, #shared> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !tt.tensordesc<64x128xf16, #shared> // CHECK: amdg.async_tdm_intrinsic_wait {count = 0 amdg.async_tdm_wait {num = 0 : i32} diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index 2d5e12632dbc..88b63b7fbbc7 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -406,13 +406,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: simple_tdm_waitcnt - tt.func public @simple_tdm_waitcnt(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x16xf16>, %mask: i32 + tt.func public @simple_tdm_waitcnt(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x16xf16, #shared>, %mask: i32 ) { %c0_i32 = arith.constant 0 : i32 // Each async_tdm_copy only emits a single instruction (-> counts 1) - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> - %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16, #shared> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16, #shared> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> // Do not wait on the second tdm => waitcnt 1 // CHECK: amdg.async_tdm_intrinsic_wait {{.*}} {count = 1 @@ -538,17 +538,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mix_async_copy_and_async_tdm_copy - tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x8xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x8xf16>, %mask: i32, %ptr: tensor<128x8x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>} + tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x8xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc<128x8xf16, #shared>, %mask: i32, %ptr: tensor<128x8x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>} ) { %c0_i32 = arith.constant 0 : i32 // Each async_tdm_copy only emits a single instruction (-> counts 1) - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16, #shared> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> %2 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> %21 = ttg.async_commit_group tokens %2 - %3 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> + %3 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x8xf16, #shared> -> !ttg.memdesc<128x8xf16, #shared, #smem, mutable> %4 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> %5 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x8x!tt.ptr, #blocked> -> <128x8xf16, #shared, #smem, mutable> @@ -698,7 +698,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: tdm_partitioned_shared_waitcnt tt.func public @tdm_partitioned_shared_waitcnt( %memDesc: !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable>, - %tensorDesc: !tt.tensordesc<128x16xf16>, + %tensorDesc: !tt.tensordesc<128x16xf16, #partitioned>, %mask: i32 ) { %c0_i32 = arith.constant 0 : i32 @@ -706,8 +706,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // numLogicalPieces = numPartitions * numGroups = 2 * 4 = 8 // warpsAlongPartition = gcd(numWarps=4, numLogicalPieces=8) = 4 // Each async_tdm_copy emits divideCeil(8, 4) = 2 instructions - %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable> - %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable> + %1 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16, #partitioned> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable> + %2 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, pred = %mask : !tt.tensordesc<128x16xf16, #partitioned> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable> // Skip second copy (2 instructions) => count = 2 // CHECK: amdg.async_tdm_intrinsic_wait {{.*}} {count = 2 diff --git a/test/TritonGPU/amd/invalid.mlir b/test/TritonGPU/amd/invalid.mlir index f36fe402b841..a35293248576 100644 --- a/test/TritonGPU/amd/invalid.mlir +++ b/test/TritonGPU/amd/invalid.mlir @@ -206,25 +206,36 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #shared_32 = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [128, 64]}> #shared_2_intervals = #ttg.padded_shared<[64:+4, 128:+4] {order = [1, 0], shape = [128, 64]}> +#shared_partitioned = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 0, partitionLayout = #shared_32}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { tt.func public @interval_not_matching_innermost_block_dimension( - %tensorDesc: !tt.tensordesc<128x64xf16>, + %tensorDesc: !tt.tensordesc<128x64xf16, #shared_32>, %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 // expected-error @+1 {{TDM store padding is only supported when padding interval equals the innermost block dimension}} - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> -> !tt.tensordesc<128x64xf16> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_32, #smem, mutable> -> !tt.tensordesc<128x64xf16, #shared_32> tt.return } tt.func public @tdm_store_two_padding_intervals( - %tensorDesc: !tt.tensordesc<128x64xf16>, + %tensorDesc: !tt.tensordesc<128x64xf16, #shared_2_intervals>, %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> ) { %c0_i32 = arith.constant 0 : i32 // expected-error @+1 {{TDM store only supports single interval paddings}} - amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> -> !tt.tensordesc<128x64xf16> + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> -> !tt.tensordesc<128x64xf16, #shared_2_intervals> + tt.return + } + + tt.func public @tdm_store_encoding_mismatch( + %tensorDesc: !tt.tensordesc<128x64xf16, #shared_partitioned>, + %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> + ) { + %c0_i32 = arith.constant 0 : i32 + // expected-error @+1 {{Mismatch between TDM descriptor and source smem encodings}} + amdg.async_tdm_copy_local_to_global %tensorDesc[%c0_i32, %c0_i32] from %memDesc: !ttg.memdesc<128x64xf16, #shared_2_intervals, #smem, mutable> -> !tt.tensordesc<128x64xf16, #shared_partitioned> tt.return } } @@ -258,6 +269,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ #shared_scatter_32 = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [8, 64]}> // Scatter with two padding intervals (only single interval is supported). #shared_scatter_2_intervals = #ttg.padded_shared<[64:+4, 128:+4] {order = [1, 0], shape = [8, 64]}> +#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [32, 32]}> +#partitioned = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 0, partitionLayout = #shared}> #smem_scatter = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { tt.func public @scatter_interval_not_matching_innermost_block_dimension( @@ -291,13 +304,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // hint == 0 has no active warps; rejected. tt.func @warp_used_hint_zero( - %tensorDesc: !tt.tensordesc<256x64xf16>, + %tensorDesc: !tt.tensordesc<256x64xf16, #shared_wb>, %memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{warp_used_hint must have at least one bit set}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 0 : i32} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 0 : i32} : !tt.tensordesc<256x64xf16, #shared_wb> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> tt.return } @@ -305,52 +318,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // active set spans 3 warpId bit positions, not log2(K) = 2 -- a // non axis-aligned pattern is not supported. tt.func @warp_used_hint_non_axis_aligned( - %tensorDesc: !tt.tensordesc<256x64xf16>, + %tensorDesc: !tt.tensordesc<256x64xf16, #shared_wb>, %memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{is not axis-aligned}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 105 : i32} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 105 : i32} : !tt.tensordesc<256x64xf16, #shared_wb> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> tt.return } // popcount must be a power of two. 0x07 has K=3 -- rejected even // though warps 0..2 are otherwise contiguous. tt.func @warp_used_hint_non_pow2_k( - %tensorDesc: !tt.tensordesc<256x64xf16>, + %tensorDesc: !tt.tensordesc<256x64xf16, #shared_wb>, %memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{popcount(warp_used_hint) = 3 must be a power of two}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 7 : i32} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 7 : i32} : !tt.tensordesc<256x64xf16, #shared_wb> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> tt.return } // hint sets all 16 low bits but num_warps = 8 so bits 8..15 don't // correspond to any warp. Reported by the bits-beyond check. tt.func @warp_used_hint_exceeds_num_warps( - %tensorDesc: !tt.tensordesc<256x64xf16>, + %tensorDesc: !tt.tensordesc<256x64xf16, #shared_wb>, %memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{warp_used_hint = 0xffff sets bits beyond num_warps = 8}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 65535 : i32} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 65535 : i32} : !tt.tensordesc<256x64xf16, #shared_wb> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> tt.return } // Bits outside [0, num_warps) must be zero. K=2 is otherwise valid, // but warp index 9 is not in [0, 8). tt.func @warp_used_hint_bits_beyond_num_warps( - %tensorDesc: !tt.tensordesc<256x64xf16>, + %tensorDesc: !tt.tensordesc<256x64xf16, #shared_wb>, %memDesc: !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{sets bits beyond num_warps = 8}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 513 : i32} : !tt.tensordesc<256x64xf16> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 513 : i32} : !tt.tensordesc<256x64xf16, #shared_wb> -> !ttg.memdesc<256x64xf16, #shared_wb, #smem_wb, mutable> tt.return } } @@ -378,13 +391,31 @@ module attributes {"ttg.target" = "hip:gfx950", "ttg.num-ctas" = 1 : i32, "ttg.n #smem_mi = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { tt.func @warp_used_hint_partitioned_insufficient( - %tensorDesc: !tt.tensordesc<128x16xf16>, + %tensorDesc: !tt.tensordesc<128x16xf16, #partitioned_mi>, %memDesc: !ttg.memdesc<128x16xf16, #partitioned_mi, #smem_mi, mutable>, %pred: i32 ) { %c0 = arith.constant 0 : i32 // expected-error @+1 {{warp_used_hint with a partitioned shared encoding must select K active warps}} - %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 3 : i32} : !tt.tensordesc<128x16xf16> -> !ttg.memdesc<128x16xf16, #partitioned_mi, #smem_mi, mutable> + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred {warp_used_hint = 3 : i32} : !tt.tensordesc<128x16xf16, #partitioned_mi> -> !ttg.memdesc<128x16xf16, #partitioned_mi, #smem_mi, mutable> + tt.return + } +} + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#partitioned = #ttg.partitioned_shared<{numPartitions = 2, numGroups = 1, partitionDim = 0, partitionLayout = #shared}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func @tdm_load_encoding_mismatch( + %tensorDesc: !tt.tensordesc<128x16xf16, #shared>, + %memDesc: !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable>, + %pred: i32 + ) { + %c0 = arith.constant 0 : i32 + // expected-error @+1 {{Mismatch between TDM descriptor and destination smem encodings}} + %0 = amdg.async_tdm_copy_global_to_local %tensorDesc[%c0, %c0] into %memDesc, pred = %pred : !tt.tensordesc<128x16xf16, #shared> -> !ttg.memdesc<128x16xf16, #partitioned, #smem, mutable> tt.return } } diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index fa5cadf22b73..6f9834b32683 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -649,6 +649,10 @@ LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() { if (failed(verifyResult)) return verifyResult; + if (tensorDescTy.getSharedLayout() != smemTy.getEncoding()) + return emitOpError( + "Mismatch between TDM descriptor and destination smem encodings"); + auto swizzledEnc = llvm::dyn_cast(smemTy.getEncoding()); if (swizzledEnc && swizzledEnc.getMaxPhase() != 1) @@ -750,6 +754,10 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() { if (failed(verifyResult)) return verifyResult; + if (tensorDescTy.getSharedLayout() != smemTy.getEncoding()) + return emitOpError( + "Mismatch between TDM descriptor and source smem encodings"); + auto swizzledEnc = llvm::dyn_cast(smemTy.getEncoding()); if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)