From e436bf64b50ac7f18218b8168da13176457fa3aa Mon Sep 17 00:00:00 2001 From: Artem Kroviakov <71938912+akroviakov@users.noreply.github.com> Date: Thu, 30 Jan 2025 18:22:25 +0100 Subject: [PATCH] [MLIR][XeGPU] Add sg_map for scatter verification (#124300) This PR adds sg_map to the verification of scatter ops in XeGPU. The documentation says `chunk_size: indicates the number of continuous elements accessed for each offset`, it also mentions the fact that scatter ops are SG-level. Hence, if an operation is distributed to work-items, a 1-d load means a work-item reads one element, a 2-d load means a work-item loads `chunk-size` or second dimension of tdesc elements. The changes in this PR reflect the documentation with the presence of sg_map attribute (i.e., distributed case). --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 7 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 47 ++++++++++- mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 62 ++++++++++++++- mlir/test/Dialect/XeGPU/invalid.mlir | 77 +++++++++++++++++++ 4 files changed, 183 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index a2bfa721f251..c2335eecc378 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { let hasVerifier = 1; } -def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>, - AllElementTypesMatch<["value", "TensorDesc"]>, - AllElementCountsMatch<["value", "TensorDesc"]>]> { +def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> { let summary = "load a set of scattered data points from memory."; let description = [{ It (aka. load) load data per each work-item. 
The output @@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"] let hasVerifier = 1; } -def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>, - AllElementTypesMatch<["value", "TensorDesc"]>]> { +def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81f46f941785..cd883baa986b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -454,7 +454,22 @@ LogicalResult CreateDescOp::verify() { if (shape != tdescShape) return emitOpError("Incorrect TensorDesc shape. ") << "Expected is " << makeString(shape) << "\n"; - + if (auto sgMap = tdescTy.getSGMapAttr()) { + // A work-item's slice of the TensorDesc with shape [sg_size] or + // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively, + // the mapping should reflect that. + if (sgMap.getWiData()[0] > 1) + return emitOpError("TensorDesc's SG map only supports multiple elements " + "contiguous along rows."); + if (chunkSize != static_cast<int64_t>(sgMap.getWiData()[1])) + return emitOpError( + "TensorDesc's chunkSize must match WI's data mapping."); + if (int rank = tdescTy.getRank(); + (sgMap.getWiLayout()[2 - rank] != tdescShape[0])) + return emitOpError("Detected a conflict between SG map's work-item " + "layout and TensorDesc shape. 
Check the index of " + "`subgroup_size` in WI layout map."); + } return success(); } @@ -513,10 +528,23 @@ LogicalResult LoadGatherOp::verify() { if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) - return emitOpError("load_gather has to be transposed."); + return emitOpError("load of rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } + if (auto sgMap = tdescTy.getSGMapAttr()) { + auto valueVecTy = cast<VectorType>(valueTy); + const int32_t wiData = + sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1]; + // All represent the same concept: a number of row elements to store. + if (valueVecTy.getNumElements() != wiData || + valueVecTy.getNumElements() != tdescTy.getChunkSize()) { + return emitOpError("Chunk size, vector size and wi_data must match."); + } + // Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size]. + tdescShape[tdescTy.getRank() - 1] = 1; + } + if (valueShape != tdescShape) return emitOpError("Unexpected result shape") << "(Expected shape: " << makeString(tdescShape) @@ -552,10 +580,23 @@ LogicalResult StoreScatterOp::verify() { if (tdescTy.getRank() == 2) { if (!getTransposeAttr()) - return emitOpError("load_gather has to be transposed."); + return emitOpError("Store of a rank-2 tensor has to be transposed."); transpose({1, 0}, tdescShape); } + if (auto sgMap = tdescTy.getSGMapAttr()) { + auto valueVecTy = cast<VectorType>(valueTy); + const int32_t wiData = + sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1]; + // All represent the same concept: a number of row elements to store. + if (valueVecTy.getNumElements() != wiData || + valueVecTy.getNumElements() != tdescTy.getChunkSize()) { + return emitOpError("Chunk size, vector size and wi_data must match."); + } + // Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size]. 
+ tdescShape[tdescTy.getRank() - 1] = 1; + } + if (valueShape != tdescShape) return emitOpError("Unexpected value shape") << "(Expected shape: " << makeString(tdescShape) diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index d7174a489888..dcd6b01974cf 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -163,11 +163,69 @@ gpu.func @test_create_tdesc_vc_1(%src: memref) { gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> - //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> gpu.return } +// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) { +gpu.func @test_load_with_sg_map(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint 
= #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2x1xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2x1xf32> + gpu.return +} + +// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) { +gpu.func @test_load_with_sg_map_2(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + gpu.return +} + +// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) { +gpu.func @test_store_with_sg_map(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32> + %2 = arith.constant dense<2.9>: vector<2x1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : 
ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + gpu.return +} + +// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) { +gpu.func @test_store_with_sg_map_2(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[cst1:.*]] = arith.constant dense : vector<4xi1> + %1 = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + %2 = arith.constant dense<2.9>: vector<1xf32> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map>, vector<4xi1> + gpu.return +} + + + // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) { gpu.func @test_prefetch_vc(%src: ui64) { //CHECK: 
%[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 7816bff0582f..201f72120cf2 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -170,6 +170,83 @@ func.func @test_prefetch_vc_2(%src: ui64) { return } +// ----- +func.func @test_create_tdesc_sg_map_1(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map> + return +} + +// ----- +func.func @test_create_tdesc_sg_map_2(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + return +} + +// ----- +func.func @test_create_tdesc_sg_map_3(%src: ui64) { + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}} + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + return +} + +// ----- +func.func @test_load_gather_sg_map_1(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : 
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<1x2xf32> + return +} + +// ----- +func.func @test_load_gather_sg_map_2(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<2xf32> + return +} + +// ----- +func.func @test_load_gather_sg_map_3(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Chunk size, vector size and wi_data must match}} + %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> -> vector<1xf32> + return +} + + +// ----- +func.func @test_store_scatter_sg_map_1(%src: ui64) { + %0 = arith.constant dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %val = arith.constant dense<2.9>: vector<1x2xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}} + xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + return +} + +// ----- +func.func @test_store_scatter_sg_map_2(%src: ui64) { + %0 = arith.constant 
dense<1>: vector<4xi1> + %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %val = arith.constant dense<2.9>: vector<2xf32> + %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map> + // expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}} + xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr, #xegpu.sg_map>, vector<4xi1> + return +} + // ----- func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) { %0 = arith.constant dense<1>: vector<4xi1>