Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,55 @@ struct DistributionInfo {
bool vectorizable = false;
};

// Generally parallel loops are partitionable, however if the dims for it
// are not present in a producer of the compute op within the dispatch and the
// results of that producer op are also returned from the dispatch
// then that dim is not partitioned as codegen for this is unsupported.
static SmallVector<unsigned int>
getSupportedPartitionableLoops(linalg::LinalgOp linalgOp) {
SmallVector<unsigned int> partitionableLoops;
linalgOp.getParallelDims(partitionableLoops);
SmallVector<OpOperand *> producerOperands;
for (auto operand : linalgOp.getDpsInputOperands()) {
auto producerOp = operand->get().getDefiningOp<linalg::LinalgOp>();
if (!producerOp) {
continue;
}

for (Operation *user : producerOp->getUsers()) {
if (isa<IREE::Codegen::StoreToBufferOp>(user)) {
producerOperands.push_back(operand);
break;
}
}
}
if (producerOperands.empty()) {
return partitionableLoops;
}
// If we have producer operands then we need to confirm that all of them
// also have the the partitionableLoop dims if not we skip that dim.
SmallVector<unsigned int> finalPartitionableLoops;
for (auto dim : partitionableLoops) {
bool dimFound = false;
for (auto operand : producerOperands) {
AffineMap IndexingMap = linalgOp.getMatchingIndexingMap(operand);
if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
auto dimExpr = dyn_cast<AffineDimExpr>(expr);
return dimExpr && dimExpr.getPosition() == dim;
})) {
dimFound = true;
} else {
dimFound = false;
break;
}
}
if (dimFound) {
finalPartitionableLoops.push_back(dim);
}
}
return finalPartitionableLoops;
}

static FailureOr<DistributionInfo> collectOpDistributionInfo(Operation *op) {
DistributionInfo distInfo;
// MapScatterOp doesn't fit the LinalgOp interface, so use special case logic
Expand Down Expand Up @@ -1054,19 +1103,17 @@ static FailureOr<DistributionInfo> collectOpDistributionInfo(Operation *op) {
}

auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
// Bail out on multi result cases as consumer fusion currently does not
// support multi result ops.
if (!linalgOp || linalgOp.getNumDpsInits() != 1) {
if (!linalgOp) {
return failure();
}

// This pipeline requires tensor semantics. Also fail for gather semantics
// for now to simplify tile + fuse.
if (!linalgOp.hasPureTensorSemantics() || linalgOp.hasIndexSemantics()) {
if (!linalgOp.hasPureTensorSemantics() ||
LinalgExt::isGatherlikeOp(linalgOp)) {
return failure();
}

linalgOp.getParallelDims(distInfo.partitionableLoops);
distInfo.partitionableLoops = getSupportedPartitionableLoops(linalgOp);

// Bail out if op is not tilable.
if (distInfo.partitionableLoops.empty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,120 @@ func.func @dyn_parallel_reduction(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
// CHECK-SAME: reduction = [0, 4]
// CHECK-SAME: thread = [1, 0]
// CHECK-SAME: workgroup = [64, 0]

// -----
func.func @multi_result_index_generic_with_scatterfusion(%arg0: tensor<4x?x32x8xf16>, %arg1: tensor<4x?xi64>) -> (tensor<?x8x32xf8E4M3FNUZ>, tensor<4x?x32x8xf8E4M3FNUZ>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c1 : tensor<4x?x32x8xf16>
%0 = tensor.empty(%dim) : tensor<4x?x32x8xf8E4M3FNUZ>
%1 = tensor.empty(%dim) : tensor<4x?x8x32xf8E4M3FNUZ>
%2 = tensor.empty(%dim) : tensor<?x8x32xf8E4M3FNUZ>
%3:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<4x?x32x8xf16>)
outs(%0, %1 : tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) {
^bb0(%in: f16, %out: f8E4M3FNUZ, %out_0: f8E4M3FNUZ):
%5 = linalg.index 1 : index
%6 = arith.mulf %in, %in : f16
%7 = arith.cmpi eq, %5, %c0 : index
%8 = arith.select %7, %6, %in : f16
%9 = arith.truncf %8 : f16 to f8E4M3FNUZ
linalg.yield %9, %9 : f8E4M3FNUZ, f8E4M3FNUZ
} -> (tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>)
%4 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%3#1, %arg1 : tensor<4x?x8x32xf8E4M3FNUZ>, tensor<4x?xi64>) outs(%2 : tensor<?x8x32xf8E4M3FNUZ>) {
^bb0(%arg2: f8E4M3FNUZ, %arg3: f8E4M3FNUZ):
iree_linalg_ext.yield %arg2 : f8E4M3FNUZ
} -> tensor<?x8x32xf8E4M3FNUZ>
return %4, %3#0 : tensor<?x8x32xf8E4M3FNUZ>, tensor<4x?x32x8xf8E4M3FNUZ>
}

// CHECK-LABEL: func.func @multi_result_index_generic_with_scatterfusion
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1, 1, 1, 4]
// CHECK-SAME: workgroup = [1, 1, 32, 8]

// -----
func.func @producer_broadcasted(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>) -> tensor<4x8xi64> {
%c59_i64 = arith.constant 59 : i64
%c2_i64 = arith.constant 2 : i64
%c8_i64 = arith.constant 8 : i64
%c32_i64 = arith.constant 32 : i64
%0 = tensor.empty() : tensor<4x8xi64>
%1 = tensor.empty() : tensor<4xi64>
%2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : tensor<4xi64>) outs(%1 : tensor<4xi64>) {
^bb0(%in: i64, %out: i64):
%4 = arith.addi %in, %c59_i64 : i64
%5 = arith.muli %4, %c2_i64 : i64
linalg.yield %5 : i64
} -> tensor<4xi64>
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%2, %arg1 : tensor<4xi64>, tensor<4xi64>) outs(%0 : tensor<4x8xi64>) {
^bb0(%in: i64, %in_0: i64, %out: i64):
%4 = linalg.index 1 : index
%5 = arith.index_cast %4 : index to i64
%6 = arith.muli %in, %c8_i64 : i64
%7 = arith.addi %6, %5 : i64
%8 = arith.muli %7, %c32_i64 : i64
%9 = arith.addi %8, %in_0 : i64
linalg.yield %9 : i64
} -> tensor<4x8xi64>
return %3 : tensor<4x8xi64>
}

// CHECK-LABEL: func.func @producer_broadcasted
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK: thread = [1, 1]
// CHECK-SAME: workgroup = [4, 16]

// -----
func.func @producer_broadcasted_and_stored_to_buffer(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>, %arg2 : memref<4xi64>) -> tensor<4x8xi64> {
%c59_i64 = arith.constant 59 : i64
%c2_i64 = arith.constant 2 : i64
%c8_i64 = arith.constant 8 : i64
%c32_i64 = arith.constant 32 : i64
%0 = tensor.empty() : tensor<4x8xi64>
%1 = tensor.empty() : tensor<4xi64>
%2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : tensor<4xi64>) outs(%1 : tensor<4xi64>) {
^bb0(%in: i64, %out: i64):
%4 = arith.addi %in, %c59_i64 : i64
%5 = arith.muli %4, %c2_i64 : i64
linalg.yield %5 : i64
} -> tensor<4xi64>
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%2, %arg1 : tensor<4xi64>, tensor<4xi64>) outs(%0 : tensor<4x8xi64>) {
^bb0(%in: i64, %in_0: i64, %out: i64):
%4 = linalg.index 1 : index
%5 = arith.index_cast %4 : index to i64
%6 = arith.muli %in, %c8_i64 : i64
%7 = arith.addi %6, %5 : i64
%8 = arith.muli %7, %c32_i64 : i64
%9 = arith.addi %8, %in_0 : i64
linalg.yield %9 : i64
} -> tensor<4x8xi64>
iree_codegen.store_to_buffer %2, %arg2 : tensor<4xi64> into memref<4xi64>
return %3 : tensor<4x8xi64>
}

// CHECK-LABEL: func.func @producer_broadcasted
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK: reduction = [0, 4]
// CHECK-SAME: thread = [1, 0]
// CHECK-SAME: workgroup = [64, 0]
Original file line number Diff line number Diff line change
Expand Up @@ -1390,3 +1390,64 @@ hal.executable public @main {
// CHECK: scf.forall {{.*}} in (2048, 8) {
// CHECK: %[[READ:.+]] = vector.transfer_read{{.*}}: memref<2048x2048xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
// CHECK: vector.scatter {{.*}} %[[READ]] : memref<4194304xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xindex>, vector<4xi1>, vector<4xf32>


// -----
#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
#lowering_config = #iree_gpu.lowering_config<{thread = [1, 1, 1, 4], workgroup = [1, 1, 32, 8]}>

#pipeline_layout = #hal.pipeline.layout<bindings =
[#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
#hal.pipeline.binding<storage_buffer, ReadOnly>,
#hal.pipeline.binding<storage_buffer, Indirect>,
#hal.pipeline.binding<storage_buffer, Indirect>
]>

hal.executable public @multi_result_index_generic_with_scatter_fusion {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @multi_result_index_generic_with_scatter_fusion ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%arg1, %arg2, %arg3)
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @multi_result_index_generic_with_scatter_fusion() attributes {translation_info = #translation_info} {
%c32_i64 = arith.constant 32 : i64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%24 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c1) flags("ReadOnly|Indirect") : memref<4x?x32x8xf16, #hal.descriptor_type<storage_buffer>>{%c1}
%25 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c1) flags("ReadOnly|Indirect") : memref<4x?xi64, #hal.descriptor_type<storage_buffer>>{%c1}
%26 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<?x8x32xf8E4M3FNUZ, #hal.descriptor_type<storage_buffer>>{%c1}
%27 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c1) flags(Indirect) : memref<4x?x32x8xf8E4M3FNUZ, strided<[?, 256, 8, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%c1}
%28 = iree_codegen.load_from_buffer %24 : memref<4x?x32x8xf16, #hal.descriptor_type<storage_buffer>> -> tensor<4x?x32x8xf16>
%29 = iree_codegen.load_from_buffer %25 : memref<4x?xi64, #hal.descriptor_type<storage_buffer>> -> tensor<4x?xi64>
%30 = tensor.empty(%c1) : tensor<?x8x32xf8E4M3FNUZ>
%31 = tensor.empty(%c1) : tensor<4x?x8x32xf8E4M3FNUZ>
%32 = tensor.empty(%c1) : tensor<4x?x32x8xf8E4M3FNUZ>
%33:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%28 : tensor<4x?x32x8xf16>) outs(%32, %31 : tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>) attrs = {lowering_config = #lowering_config} {
^bb0(%in: f16, %out: f8E4M3FNUZ, %out_0: f8E4M3FNUZ):
%35 = linalg.index 1 : index
%36 = arith.mulf %in, %in : f16
%37 = arith.cmpi eq, %35, %c0 : index
%38 = arith.select %37, %36, %in : f16
%39 = arith.truncf %38 : f16 to f8E4M3FNUZ
linalg.yield %39, %39 : f8E4M3FNUZ, f8E4M3FNUZ
} -> (tensor<4x?x32x8xf8E4M3FNUZ>, tensor<4x?x8x32xf8E4M3FNUZ>)
%34 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%33#1, %29 : tensor<4x?x8x32xf8E4M3FNUZ>, tensor<4x?xi64>) outs(%30 : tensor<?x8x32xf8E4M3FNUZ>) {
^bb0(%arg0: f8E4M3FNUZ, %arg1: f8E4M3FNUZ):
iree_linalg_ext.yield %arg0 : f8E4M3FNUZ
} -> tensor<?x8x32xf8E4M3FNUZ>
iree_codegen.store_to_buffer %34, %26 : tensor<?x8x32xf8E4M3FNUZ> into memref<?x8x32xf8E4M3FNUZ, #hal.descriptor_type<storage_buffer>>
iree_codegen.store_to_buffer %33#0, %27 : tensor<4x?x32x8xf8E4M3FNUZ> into memref<4x?x32x8xf8E4M3FNUZ, strided<[?, 256, 8, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
}
}
}
// CHECK-LABEL: func @multi_result_index_generic_with_scatter_fusion
// CHECK: scf.forall (%arg0, %arg1) in (4, 1) {
// CHECK: vector.transfer_read
// CHECK: arith.mulf {{.*}} vector<4xf16>
// CHECK: arith.truncf {{.*}} vector<4xf16> to vector<4xf8E4M3FNUZ>
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: iree_linalg_ext.scatter
Loading