Skip to content

Commit

Permalink
[LLVMCPU] Don't limit the distribution tile sizes (#19417)
Browse files Browse the repository at this point in the history
Removes the `kNumMaxParallelDims` constraint on the partitionable loops
while determining the distribution tile sizes. Adds `affine.linearize`
ops which needs to be resolved after reconcile translation info pass.
  • Loading branch information
pashu123 authored Dec 10, 2024
1 parent 415a0f6 commit 8cdcb7e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -866,8 +866,7 @@ getDefaultDistributedLevelTileSizes(Operation *op,
SmallVector<int64_t> adjustedMaxTileSizes(numLoops, 0);
SmallVector<int64_t> adjustedVectorSizeHints(numLoops, 1);
SmallVector<unsigned> partitionableLoops =
cast<PartitionableLoopsInterface>(op).getPartitionableLoops(
kNumMaxParallelDims);
cast<PartitionableLoopsInterface>(op).getPartitionableLoops(std::nullopt);
for (auto i : partitionableLoops) {
adjustedMinTileSizes[i] =
config.minTileSizes.empty() ? 1 : config.minTileSizes[i];
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "iree/compiler/Utils/PassUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
Expand Down Expand Up @@ -818,6 +819,7 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager,
}

variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(createLowerAffinePass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());

// Run conversion to LLVM at `ModuleOp` granularity.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func.func @add4D() attributes {hal.executable.target = #executable_target_embedd
return
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 64, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
// CHECK: func.func @add4D()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -224,7 +224,7 @@ func.func @add_static() attributes {hal.executable.target = #executable_target_e
return
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 2, 32, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 16, 32, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
// CHECK: func.func @add_static()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -457,7 +457,7 @@ func.func @conv_dynamic() attributes {hal.executable.target = #executable_target
return
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 64, 64, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 64, 64, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUConvTileAndDecomposeExpert>
// CHECK: func.func @conv_dynamic()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -849,7 +849,7 @@ func.func @generic_unit_dims_dynamic() attributes {hal.executable.target = #exec
flow.dispatch.tensor.store %8, %5, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, %0, 1, 1, %1, %2, 1, %3], strides = [1, 1, 1, 1, 1, 1, 1, 1] : tensor<1x?x1x1x?x?x1x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x?x1x1x?x?x1x?xf32>>{%0, %1, %2, %3}
return
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 0, 0, 0, 64, 64, 0, 64], [1, 1, 1, 1, 1, 1, 1, 4], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 64, 0, 0, 64, 64, 0, 64], [1, 1, 1, 1, 1, 1, 1, 4], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
// CHECK: func.func @generic_unit_dims_dynamic()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -1632,7 +1632,7 @@ func.func @pad_only() attributes {hal.executable.target = #executable_target_emb
return
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 6, 57, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 6, 57, 64], [1, 1, 1, 4], [0, 0, 0, 0], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
// CHECK: func.func @pad_only()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -1660,7 +1660,7 @@ func.func @winograd_output_transform() attributes {hal.executable.target = #exec
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [2, 36, 36, 128], strides = [1, 1, 1, 1] : tensor<2x36x36x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x36x36x128xf16>>
return
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 2, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_output_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand All @@ -1687,7 +1687,7 @@ func.func @winograd_input_transform() attributes {hal.executable.target = #execu
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 8, 2, 6, 6, 128], strides = [1, 1, 1, 1, 1, 1] : tensor<8x8x2x6x6x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<8x8x2x6x6x128xf16>>
return
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 2, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_input_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
Expand Down Expand Up @@ -1766,6 +1766,78 @@ func.func @attention() attributes {hal.executable.target = #executable_target_em

// -----

#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {
cpu = "generic", cpu_features = "",
data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128",
native_vector_size = 64 : index, target_triple = "x86_64-none-elf"}>
func.func @attention_transpose_distribute_4d() attributes {hal.executable.target = #executable_target_embedded_elf_x86_64_} {
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 8.837890e-02 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
%4 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
%5 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
%6 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
%7 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
%8 = hal.interface.constant.load layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(8) : i32
%9 = arith.extui %0 : i32 to i64
%10 = arith.extui %1 : i32 to i64
%11 = arith.shli %10, %c32_i64 : i64
%12 = arith.ori %9, %11 : i64
%13 = arith.index_castui %12 : i64 to index
%14 = arith.extui %2 : i32 to i64
%15 = arith.extui %3 : i32 to i64
%16 = arith.shli %15, %c32_i64 : i64
%17 = arith.ori %14, %16 : i64
%18 = arith.index_castui %17 : i64 to index
%19 = arith.extui %4 : i32 to i64
%20 = arith.extui %5 : i32 to i64
%21 = arith.shli %20, %c32_i64 : i64
%22 = arith.ori %19, %21 : i64
%23 = arith.index_castui %22 : i64 to index
%24 = arith.extui %6 : i32 to i64
%25 = arith.extui %7 : i32 to i64
%26 = arith.shli %25, %c32_i64 : i64
%27 = arith.ori %24, %26 : i64
%28 = arith.index_castui %27 : i64 to index
%29 = arith.index_castui %8 : i32 to index
%30 = util.assume.int %29<umin = 16, umax = 131056, udiv = 16> : index
%31 = flow.dispatch.workload.ordinal %30, 0 : index
%32 = hal.interface.binding.subspan layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%13) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x4x?x128xf16>>{%31}
%33 = hal.interface.binding.subspan layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%18) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x4x?x1x1x128xf16>>{%31}
%34 = hal.interface.binding.subspan layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%23) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x4x?x1x1x128xf16>>{%31}
%35 = hal.interface.binding.subspan layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4x4x?x?x1x1xf16>>{%31, %31}
%36 = hal.interface.binding.subspan layout(<constants = 9, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%28) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<4x?x4x128xf16>>{%31}
%37 = flow.dispatch.tensor.load %32, offsets = [0, 0, 0, 0], sizes = [4, 4, %31, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x4x?x128xf16>>{%31} -> tensor<4x4x?x128xf16>
%38 = flow.dispatch.tensor.load %33, offsets = [0, 0, 0, 0, 0, 0], sizes = [4, 4, %31, 1, 1, 128], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x4x?x1x1x128xf16>>{%31} -> tensor<4x4x?x1x1x128xf16>
%39 = flow.dispatch.tensor.load %34, offsets = [0, 0, 0, 0, 0, 0], sizes = [4, 4, %31, 1, 1, 128], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x4x?x1x1x128xf16>>{%31} -> tensor<4x4x?x1x1x128xf16>
%40 = flow.dispatch.tensor.load %35, offsets = [0, 0, 0, 0, 0, 0], sizes = [4, 4, %31, %31, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x4x?x?x1x1xf16>>{%31, %31} -> tensor<4x4x?x?x1x1xf16>
%41 = tensor.empty(%31) : tensor<4x?x4x128xf16>
%42 = tensor.empty(%31) : tensor<4x4x?x128xf16>
%43 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d6, d7, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>]} ins(%37, %38, %39, %cst, %40 : tensor<4x4x?x128xf16>, tensor<4x4x?x1x1x128xf16>, tensor<4x4x?x1x1x128xf16>, f16, tensor<4x4x?x?x1x1xf16>) outs(%42 : tensor<4x4x?x128xf16>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<4x4x?x128xf16>
%44 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%43 : tensor<4x4x?x128xf16>) outs(%41 : tensor<4x?x4x128xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<4x?x4x128xf16>
flow.dispatch.tensor.store %44, %36, offsets = [0, 0, 0, 0], sizes = [4, %31, 4, 128], strides = [1, 1, 1, 1] : tensor<4x?x4x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<4x?x4x128xf16>>{%31}
return
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 64, 64, 0, 0, 0, 0], [1, 1, 1, 32, 0, 0, 0, 0], [0, 0, 0, 0, 0, 32, 1, 1]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = CPULinalgExtTileAndVectorize>
// CHECK: func.func @attention_transpose_distribute_4d
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.attention
// CHECK-SAME: lowering_config = #[[CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
Expand Down

0 comments on commit 8cdcb7e

Please sign in to comment.