[MLIR][XeGPU] Add sg_map for scatter verification (#124300)
This PR adds sg_map to the verification of scatter ops in XeGPU.

The documentation says `chunk_size: indicates the number of continuous
elements accessed for each offset`; it also notes that scatter ops are
SG-level. Hence, when an operation is distributed to work-items, a 1-D
load means a work-item reads one element, while a 2-D load means a
work-item loads `chunk_size` elements (i.e., the second dimension of the
tdesc). The changes in this PR enforce the documented behavior when the
sg_map attribute is present (i.e., the distributed case).
akroviakov authored Jan 30, 2025
1 parent 4985804 commit e436bf6
Showing 4 changed files with 183 additions and 10 deletions.
7 changes: 2 additions & 5 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -548,9 +548,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
let hasVerifier = 1;
}

-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
-                                           AllElementTypesMatch<["value", "TensorDesc"]>,
-                                           AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "load a set of scattered data points from memory.";

let description = [{ It (aka. load) load data per each work-item. The output
@@ -620,8 +618,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
let hasVerifier = 1;
}

-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
-                                              AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
47 changes: 44 additions & 3 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -454,7 +454,22 @@ LogicalResult CreateDescOp::verify() {
if (shape != tdescShape)
return emitOpError("Incorrect TensorDesc shape. ")
<< "Expected is " << makeString(shape) << "\n";

if (auto sgMap = tdescTy.getSGMapAttr()) {
// A work-item's slice of the TensorDesc with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively;
// the mapping should reflect that.
if (sgMap.getWiData()[0] > 1)
return emitOpError("TensorDesc's SG map only supports multiple elements "
"contiguous along rows.");
if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
return emitOpError(
"TensorDesc's chunkSize must match WI's data mapping.");
if (int rank = tdescTy.getRank();
(sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
return emitOpError("Detected a conflict between SG map's work-item "
"layout and TensorDesc shape. Check the index of "
"`subgroup_size` in WI layout map.");
}
return success();
}
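As a concrete walk-through of the three new checks (a sketch; the type is taken from the tests added below, and %src and %offsets are placeholders): for the descriptor below, wi_data[0] = 1 passes the contiguity check, chunkSize (2) equals wi_data[1] (2), and with rank = 2 the verifier compares wi_layout[2 - 2] = 4 against tdescShape[0] = 4, so verification succeeds.

%td = xegpu.create_tdesc %src, %offsets : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>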

@@ -513,10 +528,23 @@ LogicalResult LoadGatherOp::verify() {

if (tdescTy.getRank() == 2) {
if (!getTransposeAttr())
return emitOpError("load_gather has to be transposed.");
return emitOpError("load of rank-2 tensor has to be transposed.");
transpose({1, 0}, tdescShape);
}

if (auto sgMap = tdescTy.getSGMapAttr()) {
auto valueVecTy = cast<VectorType>(valueTy);
const int32_t wiData =
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
// All three represent the same concept: the number of row elements to load.
if (valueVecTy.getNumElements() != wiData ||
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
return emitOpError("Chunk size, vector size and wi_data must match.");
}
// Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size].
tdescShape[tdescTy.getRank() - 1] = 1;
}

if (valueShape != tdescShape)
return emitOpError("Unexpected result shape")
<< "(Expected shape: " << makeString(tdescShape)
@@ -552,10 +580,23 @@ LogicalResult StoreScatterOp::verify() {

if (tdescTy.getRank() == 2) {
if (!getTransposeAttr())
return emitOpError("load_gather has to be transposed.");
return emitOpError("Store of a rank-2 tensor has to be transposed.");
transpose({1, 0}, tdescShape);
}

if (auto sgMap = tdescTy.getSGMapAttr()) {
auto valueVecTy = cast<VectorType>(valueTy);
const int32_t wiData =
sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
// All three represent the same concept: the number of row elements to store.
if (valueVecTy.getNumElements() != wiData ||
valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
return emitOpError("Chunk size, vector size and wi_data must match.");
}
// Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size].
tdescShape[tdescTy.getRank() - 1] = 1;
}

if (valueShape != tdescShape)
return emitOpError("Unexpected value shape")
<< "(Expected shape: " << makeString(tdescShape)
62 changes: 60 additions & 2 deletions mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -163,11 +163,69 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
-%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
+%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
gpu.return
}

// CHECK: gpu.func @test_load_with_sg_map(%[[arg0:.*]]: ui64) {
gpu.func @test_load_with_sg_map(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2x1xf32>
gpu.return
}

// CHECK: gpu.func @test_load_with_sg_map_2(%[[arg0:.*]]: ui64) {
gpu.func @test_load_with_sg_map_2(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
%2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
//CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
%3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1> -> vector<1xf32>
gpu.return
}

// CHECK: gpu.func @test_store_with_sg_map(%[[arg0:.*]]: ui64) {
gpu.func @test_store_with_sg_map(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x1xf32>
%2 = arith.constant dense<2.9>: vector<2x1xf32>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x1xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
gpu.return
}

// CHECK: gpu.func @test_store_with_sg_map_2(%[[arg0:.*]]: ui64) {
gpu.func @test_store_with_sg_map_2(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
//CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
%1 = arith.constant dense<1>: vector<4xi1>
//CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
%2 = arith.constant dense<2.9>: vector<1xf32>
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
%3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>
//CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf32>, !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [1, 1]>>, vector<4xi1>
gpu.return
}
// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
gpu.func @test_prefetch_vc(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
77 changes: 77 additions & 0 deletions mlir/test/Dialect/XeGPU/invalid.mlir
@@ -170,6 +170,83 @@ func.func @test_prefetch_vc_2(%src: ui64) {
return
}

// -----
func.func @test_create_tdesc_sg_map_1(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
return
}

// -----
func.func @test_create_tdesc_sg_map_2(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
return
}

// -----
func.func @test_create_tdesc_sg_map_3(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
return
}

// -----
func.func @test_load_gather_sg_map_1(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [1, 2])}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1x2xf32>
return
}

// -----
func.func @test_load_gather_sg_map_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
// expected-error@+1 {{Unexpected result shape(Expected shape: [2, 1], Given shape: [2])}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<2xf32>
return
}

// -----
func.func @test_load_gather_sg_map_3(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
// expected-error@+1 {{Chunk size, vector size and wi_data must match}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1> -> vector<1xf32>
return
}
// -----
func.func @test_store_scatter_sg_map_1(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%val = arith.constant dense<2.9>: vector<1x2xf32>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [1, 2])}}
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<1x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
return
}

// -----
func.func @test_store_scatter_sg_map_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%val = arith.constant dense<2.9>: vector<2xf32>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
// expected-error@+1 {{Unexpected value shape(Expected shape: [2, 1], Given shape: [2])}}
xegpu.store %val, %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, transpose}> : vector<2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>, vector<4xi1>
return
}

// -----
func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) {
%0 = arith.constant dense<1>: vector<4xi1>
