diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index e95be646ac0f..82ecd24a75cc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -104,6 +104,7 @@ iree_compiler_cc_library( "FissionTransferOpsInControlFlow.cpp", "FlattenMemRefSubspanPass.cpp", "FlattenMemRefs.cpp", + "FlattenSwizzleHintAllocs.cpp", "FoldAffineMinInDistributedLoops.cpp", "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp", "FoldTensorExtractOpPass.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 2d7e7b6ea2ec..0cf3e51afb15 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -97,6 +97,7 @@ iree_cc_library( "FissionTransferOpsInControlFlow.cpp" "FlattenMemRefSubspanPass.cpp" "FlattenMemRefs.cpp" + "FlattenSwizzleHintAllocs.cpp" "FoldAffineMinInDistributedLoops.cpp" "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp" "FoldTensorExtractOpPass.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp new file mode 100644 index 000000000000..235ab5176d23 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp @@ -0,0 +1,87 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_FLATTENSWIZZLEHINTALLOCSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { +struct FlattenSwizzleHintAllocsPass final + : impl::FlattenSwizzleHintAllocsPassBase<FlattenSwizzleHintAllocsPass> { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +/// This pass flattens swizzle hint ops that operate on allocations of rank > 1. +/// This is required since swizzle hint op indices require flat memrefs. +/// +/// Example: +/// ``` +/// %0 = memref.alloc() : memref<512x32xf4E2M1FN> +/// %1 = iree_codegen.swizzle_hint %0 : memref<512x32xf4E2M1FN> +/// ``` +/// +/// is flattened to: +/// ``` +/// %0 = memref.alloc() : memref<16384xf4E2M1FN> +/// %1 = iree_codegen.swizzle_hint %0 : memref<16384xf4E2M1FN> +/// %2 = memref.expand_shape %1 : memref<16384xf4E2M1FN> into +/// memref<512x32xf4E2M1FN> +/// ``` +static void flattenSwizzleHintAllocs(RewriterBase &rewriter, + IREE::Codegen::SwizzleHintOp hintOp) { + auto allocOp = hintOp.getOperand().getDefiningOp<memref::AllocOp>(); + if (!allocOp || !allocOp->hasOneUse()) { + return; + } + MemRefType resultType = allocOp.getType(); + if (resultType.getRank() == 1 || !resultType.getLayout().isIdentity() || + !memref::isStaticShapeAndContiguousRowMajor(resultType)) { + return; + } + + SmallVector<int64_t> newResultShape = {resultType.getNumElements()}; + auto newResultType = + MemRefType::get(newResultShape, resultType.getElementType(), AffineMap(), + resultType.getMemorySpace()); + rewriter.setInsertionPoint(hintOp); + ReassociationIndices reassoc = + llvm::to_vector(llvm::seq(resultType.getRank())); + 
auto newAllocOp = + memref::AllocOp::create(rewriter, hintOp.getLoc(), newResultType); + auto newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create( + rewriter, hintOp.getLoc(), newAllocOp.getResult(), hintOp.getSwizzle()); + auto expandShape = memref::ExpandShapeOp::create(rewriter, hintOp.getLoc(), + resultType.getShape(), + newSwizzleHintOp, {reassoc}); + rewriter.replaceOp(hintOp, expandShape); +} + +void FlattenSwizzleHintAllocsPass::runOnOperation() { + FunctionOpInterface funcOp = getOperation(); + // Collect all swizzle hint ops that operate on allocations. + // Flatten all allocs of rank > 1. + SmallVector<IREE::Codegen::SwizzleHintOp> hintOps; + funcOp.walk( + [&](IREE::Codegen::SwizzleHintOp hint) { hintOps.push_back(hint); }); + + IRRewriter rewriter(funcOp->getContext()); + for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + flattenSwizzleHintAllocs(rewriter, hintOp); + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index c36e063de06b..2ef20e4eab60 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -20,6 +20,7 @@ iree_lit_test_suite( [ "amdgpu_lower_coalesced_dma_to_gather_lds.mlir", "decompose_horizontally_fused_gemms.mlir", + "flatten_swizzle_hint_allocs.mlir", "gpu_alloc_private_memory_for_dps_ops.mlir", "gpu_apply_derived_thread_config.mlir", "gpu_apply_padding_online_attention.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 002a332570b4..a041d22d8286 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "amdgpu_lower_coalesced_dma_to_gather_lds.mlir" "decompose_horizontally_fused_gemms.mlir" + 
"flatten_swizzle_hint_allocs.mlir" "gpu_alloc_private_memory_for_dps_ops.mlir" "gpu_apply_derived_thread_config.mlir" "gpu_apply_padding_online_attention.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir new file mode 100644 index 000000000000..b571927fa818 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir @@ -0,0 +1,96 @@ +// RUN: iree-opt --allow-unregistered-dialect --pass-pipeline="builtin.module(func.func(iree-codegen-flatten-swizzle-hint-allocs))" \ +// RUN: --mlir-print-local-scope %s | FileCheck %s + +// Test: 1D alloc should NOT be flattened (already 1D). +func.func @skip_1d_alloc() { + %alloc = memref.alloc() : memref<2048xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> + "test.use"(%0) : (memref<2048xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_1d_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: 2D alloc with swizzle hint should be flattened to 1D. 
+func.func @flatten_2d_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> () + return +} + +// CHECK-LABEL: func @flatten_2d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space<workgroup>> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space<workgroup>> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : memref<2048xf32, #gpu.address_space<workgroup>> into memref<32x64xf32, #gpu.address_space<workgroup>> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<32x64xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<32x64xf32 +// CHECK: return + +// Test: 3D alloc with swizzle hint should be flattened to 1D. +func.func @flatten_3d_alloc() { + %alloc = memref.alloc() : memref<4x8x16xf32, #gpu.address_space<workgroup>> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<4x8x16xf32, #gpu.address_space<workgroup>> + "test.use"(%0) : (memref<4x8x16xf32, #gpu.address_space<workgroup>>) -> () + return +} + +// CHECK-LABEL: func @flatten_3d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<512xf32, #gpu.address_space<workgroup>> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<512xf32, #gpu.address_space<workgroup>> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1, 2{{\]\]}} output_shape [4, 8, 16] : memref<512xf32, #gpu.address_space<workgroup>> into memref<4x8x16xf32, #gpu.address_space<workgroup>> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<4x8x16xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<4x8x16xf32 +// CHECK: return + +// Test: Non-alloc operand should NOT be affected. 
+func.func @skip_non_alloc(%arg0: memref<32x64xf32, #gpu.address_space<workgroup>>) { + %0 = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> () + return +} + +// CHECK-LABEL: func @skip_non_alloc +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: memref<32x64xf32, #gpu.address_space<workgroup>> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: Alloc with multiple uses should NOT be flattened. +func.func @skip_multi_use_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>> + "test.use"(%alloc) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> () + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> () + return +} + +// CHECK-LABEL: func @skip_multi_use_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[ALLOC]]) +// CHECK: "test.use"(%[[HINT]]) + +// Test: XOR shuffle swizzle attribute. 
+func.func @flatten_xor_shuffle() { + %alloc = memref.alloc() : memref<16x128xi8, #gpu.address_space<workgroup>> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<16x128xi8, #gpu.address_space<workgroup>> + "test.use"(%0) : (memref<16x128xi8, #gpu.address_space<workgroup>>) -> () + return +} + +// CHECK-LABEL: func @flatten_xor_shuffle +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xi8, #gpu.address_space<workgroup>> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.xor_shuffle<128, 16>] : memref<2048xi8, #gpu.address_space<workgroup>> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [16, 128] : memref<2048xi8, #gpu.address_space<workgroup>> into memref<16x128xi8, #gpu.address_space<workgroup>> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<16x128xi8 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<16x128xi8 +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 3590144b0337..5ee6d21ccab9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -476,6 +476,11 @@ def FlattenMemRefSubspanPass : Pass<"iree-codegen-flatten-memref-subspan", "Modu }]; } +def FlattenSwizzleHintAllocsPass : + InterfacePass<"iree-codegen-flatten-swizzle-hint-allocs", "mlir::FunctionOpInterface"> { + let summary = "Flattens allocations associated with iree_codegen.swizzle_hint ops"; +} + def FoldAffineMinInDistributedLoopsPass : InterfacePass<"iree-codegen-fold-affinemin-in-distributed-loops", "mlir::FunctionOpInterface"> { let summary = "Fold `affine.min` ops in distributed loops"; diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp index bc2f6aef0c0c..b6f5994755e5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -157,6 +158,22 @@ static void swizzleGatherToLDS(RewriterBase &rewriter, }); } +static LogicalResult +verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) { + auto memrefType = cast<MemRefType>(hintOp.getOperand().getType()); + // Swizzle hints require flat (rank 1) memrefs. + // For rank 1, allow dynamic memrefs or static contiguous row-major memrefs. + if ((memrefType.getRank() != 1 || !memrefType.getLayout().isIdentity()) || + (memrefType.hasStaticShape() && + !memref::isStaticShapeAndContiguousRowMajor(memrefType))) { + hintOp.emitError() + << "swizzle hint operand must be a contiguous flat memref, got " + << hintOp.getOperand().getType(); + return failure(); + } + return success(); +} + /// Resolves all hints. Walks all direct users and splits them into loads and /// stores. If any user is not a swizzle-able load or store, bail out and /// silently drop the optimization hint. @@ -189,7 +206,7 @@ static void resolveHintOp(RewriterBase &rewriter, } if (auto gatherToLDSOp = dyn_cast<amdgpu::GatherToLDSOp>(user)) { // Ignore swizzleHint on Dst Operand. Gather_to_lds writes elements of a - // subgroup contiguously in order of lane ID + // subgroup contiguously in order of lane ID. if (gatherToLDSOp.getDst() == hintOp) { continue; } @@ -242,6 +259,9 @@ void ResolveSwizzleHintsPass::runOnOperation() { // silently pass through for that hint. 
IRRewriter rewriter(funcOp->getContext()); for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + if (failed(verifyFlatContiguousSwizzleHintOp(hintOp))) { + return signalPassFailure(); + } resolveHintOp(rewriter, hintOp); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir index 4a77f8d8d2f3..6b6b9030a3ba 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir @@ -322,3 +322,24 @@ func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) { // CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index // CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>> // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]] + +// ----- + +// Verify that swizzle_hint fails on non-flat (rank > 1) memrefs. +func.func @swizzle_hint_non_flat_memref_error(%src: memref<32x64xf32>) -> vector<4xf32> { + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32>, vector<4xf32> + return %1: vector<4xf32> +} + +// Verify that swizzle_hint fails on non-contiguous memrefs. 
+func.func @swizzle_hint_non_contiguous_memref_error() -> vector<4xf32> { + %src = memref.alloc() : memref<32x64xf32, strided<[2, 1], offset: 0>> + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32, strided<[2, 1]>>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, strided<[2, 1], offset: 0>> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32, strided<[2, 1], offset: 0>>, vector<4xf32> + return %1: vector<4xf32> +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td index 8c2bdab9cff7..a59e341f9657 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td @@ -173,9 +173,9 @@ def IREECodegen_SwizzleHintOp : - let arguments = (ins RankedTensorOrMemRefOf<[AnyType], [1]>:$operand, + let arguments = (ins AnyRankedTensorOrMemRef:$operand, IREECodegen_AnySwizzleAttr:$swizzle); - let results = (outs RankedTensorOrMemRefOf<[AnyType], [1]>:$result); + let results = (outs AnyRankedTensorOrMemRef:$result); let assemblyFormat = [{ $operand `[` $swizzle attr-dict `]` `:` type($result) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 5590d03f71a9..757a1f2142ef 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" #include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" #include 
"iree/compiler/Codegen/LLVMGPU/Passes.h" @@ -579,6 +580,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); // Step 9. Remaining post-bufferization optimizations/lowerings. + funcPassManager.addPass(createFlattenSwizzleHintAllocsPass()); funcPassManager.addPass(createPropagateDispatchSizeBoundsPass()); funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass()); funcPassManager.addPass(createUnrollAnnotatedLoopsPass());