Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ iree_compiler_cc_library(
"FissionTransferOpsInControlFlow.cpp",
"FlattenMemRefSubspanPass.cpp",
"FlattenMemRefs.cpp",
"FlattenSwizzleHintAllocs.cpp",
"FoldAffineMinInDistributedLoops.cpp",
"FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp",
"FoldTensorExtractOpPass.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ iree_cc_library(
"FissionTransferOpsInControlFlow.cpp"
"FlattenMemRefSubspanPass.cpp"
"FlattenMemRefs.cpp"
"FlattenSwizzleHintAllocs.cpp"
"FoldAffineMinInDistributedLoops.cpp"
"FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp"
"FoldTensorExtractOpPass.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2026 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_FLATTENSWIZZLEHINTALLOCSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {
/// Pass shell generated from Passes.td. Rewrites rank > 1 `memref.alloc` ops
/// that feed `iree_codegen.swizzle_hint` into a flat (rank-1) alloc followed
/// by a `memref.expand_shape` back to the original shape (see
/// flattenSwizzleHintAllocs below).
struct FlattenSwizzleHintAllocsPass final
    : impl::FlattenSwizzleHintAllocsPassBase<FlattenSwizzleHintAllocsPass> {
  using Base::Base;
  void runOnOperation() override;
};
} // namespace

/// Flattens swizzle hint ops that operate on allocations of rank > 1.
/// This is required since swizzle hint op indices require flat memrefs.
///
/// Example:
/// ```
/// %0 = memref.alloc() : memref<512x32xf4E2M1FN>
/// %1 = iree_codegen.swizzle_hint %0 : memref<512x32xf4E2M1FN>
/// ```
///
/// is flattened to:
/// ```
/// %0 = memref.alloc() : memref<16384xf4E2M1FN>
/// %1 = iree_codegen.swizzle_hint %0 : memref<16384xf4E2M1FN>
/// %2 = memref.expand_shape %1 : memref<16384xf4E2M1FN> into
///      memref<512x32xf4E2M1FN>
/// ```
static void flattenSwizzleHintAllocs(RewriterBase &rewriter,
                                     IREE::Codegen::SwizzleHintOp hintOp) {
  // Only rewrite when the hint is the sole user of a `memref.alloc`; any
  // other user would still expect the original multi-dimensional type.
  auto allocOp = hintOp.getOperand().getDefiningOp<memref::AllocOp>();
  if (!allocOp || !allocOp->hasOneUse()) {
    return;
  }
  MemRefType resultType = allocOp.getType();
  // Rank 0/1 allocs are already flat. Non-identity layouts or shapes that are
  // not statically contiguous row-major cannot be linearized by a simple
  // collapse, so leave them untouched.
  if (resultType.getRank() <= 1 || !resultType.getLayout().isIdentity() ||
      !memref::isStaticShapeAndContiguousRowMajor(resultType)) {
    return;
  }

  SmallVector<int64_t> newResultShape = {resultType.getNumElements()};
  auto newResultType =
      MemRefType::get(newResultShape, resultType.getElementType(), AffineMap(),
                      resultType.getMemorySpace());
  rewriter.setInsertionPoint(hintOp);
  // Single reassociation group collapsing all dims: [0, 1, ..., rank-1].
  ReassociationIndices reassoc =
      llvm::to_vector(llvm::seq(resultType.getRank()));
  auto newAllocOp =
      memref::AllocOp::create(rewriter, hintOp.getLoc(), newResultType);
  auto newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create(
      rewriter, hintOp.getLoc(), newAllocOp.getResult(), hintOp.getSwizzle());
  auto expandShape = memref::ExpandShapeOp::create(rewriter, hintOp.getLoc(),
                                                   resultType.getShape(),
                                                   newSwizzleHintOp, {reassoc});
  rewriter.replaceOp(hintOp, expandShape);
  // The old alloc's only use was the hint we just replaced; erase it so the
  // dead multi-dimensional allocation does not linger in the IR.
  rewriter.eraseOp(allocOp);
}

void FlattenSwizzleHintAllocsPass::runOnOperation() {
FunctionOpInterface funcOp = getOperation();
// Collect all swizzle hint ops that operate on allocations.
// Flatten all allocs of rank > 1.
SmallVector<IREE::Codegen::SwizzleHintOp> hintOps;
funcOp.walk(
[&](IREE::Codegen::SwizzleHintOp hint) { hintOps.push_back(hint); });

IRRewriter rewriter(funcOp->getContext());
for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) {
flattenSwizzleHintAllocs(rewriter, hintOp);
}
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_lit_test_suite(
[
"amdgpu_lower_coalesced_dma_to_gather_lds.mlir",
"decompose_horizontally_fused_gemms.mlir",
"flatten_swizzle_hint_allocs.mlir",
"gpu_alloc_private_memory_for_dps_ops.mlir",
"gpu_apply_derived_thread_config.mlir",
"gpu_apply_padding_online_attention.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_lit_test_suite(
SRCS
"amdgpu_lower_coalesced_dma_to_gather_lds.mlir"
"decompose_horizontally_fused_gemms.mlir"
"flatten_swizzle_hint_allocs.mlir"
"gpu_alloc_private_memory_for_dps_ops.mlir"
"gpu_apply_derived_thread_config.mlir"
"gpu_apply_padding_online_attention.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: iree-opt --allow-unregistered-dialect --pass-pipeline="builtin.module(func.func(iree-codegen-flatten-swizzle-hint-allocs))" \
// RUN: --mlir-print-local-scope %s | FileCheck %s

// Test: 1D alloc should NOT be flattened (already 1D).
// The pass only rewrites rank > 1 allocs, so alloc and hint stay intact and
// no memref.expand_shape is introduced.
func.func @skip_1d_alloc() {
  %alloc = memref.alloc() : memref<2048xf32, #gpu.address_space<workgroup>>
  %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space<workgroup>>
  "test.use"(%0) : (memref<2048xf32, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @skip_1d_alloc
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space<workgroup>>
// CHECK-NOT: memref.expand_shape
// CHECK: "test.use"(%[[HINT]])

// Test: 2D alloc with swizzle hint should be flattened to 1D.
// 32x64 collapses into a single 2048-element alloc; a memref.expand_shape
// restores the original 2D view for the downstream user.
func.func @flatten_2d_alloc() {
  %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>>
  %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>>
  "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @flatten_2d_alloc
// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space<workgroup>>
// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : memref<2048xf32, #gpu.address_space<workgroup>> into memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK: "test.use"(%[[EXPAND]])
// CHECK-NOT: memref.alloc() : memref<32x64xf32
// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<32x64xf32
// CHECK: return

// Test: 3D alloc with swizzle hint should be flattened to 1D.
// 4x8x16 collapses into a 512-element alloc with a single [0, 1, 2]
// reassociation group on the expand back to 3D.
func.func @flatten_3d_alloc() {
  %alloc = memref.alloc() : memref<4x8x16xf32, #gpu.address_space<workgroup>>
  %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<4x8x16xf32, #gpu.address_space<workgroup>>
  "test.use"(%0) : (memref<4x8x16xf32, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @flatten_3d_alloc
// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<512xf32, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<512xf32, #gpu.address_space<workgroup>>
// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1, 2{{\]\]}} output_shape [4, 8, 16] : memref<512xf32, #gpu.address_space<workgroup>> into memref<4x8x16xf32, #gpu.address_space<workgroup>>
// CHECK: "test.use"(%[[EXPAND]])
// CHECK-NOT: memref.alloc() : memref<4x8x16xf32
// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<4x8x16xf32
// CHECK: return

// Test: Non-alloc operand should NOT be affected.
// The operand is a function argument rather than a memref.alloc, so the pass
// has no alloc to flatten and leaves the hint as-is.
func.func @skip_non_alloc(%arg0: memref<32x64xf32, #gpu.address_space<workgroup>>) {
  %0 = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>>
  "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @skip_non_alloc
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK-NOT: memref.expand_shape
// CHECK: "test.use"(%[[HINT]])

// Test: Alloc with multiple uses should NOT be flattened.
// The alloc is used by both the hint and a direct "test.use"; the pass
// requires the hint to be the alloc's only user, so nothing changes.
func.func @skip_multi_use_alloc() {
  %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>>
  %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>>
  "test.use"(%alloc) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
  "test.use"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @skip_multi_use_alloc
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK-NOT: memref.expand_shape
// CHECK: "test.use"(%[[ALLOC]])
// CHECK: "test.use"(%[[HINT]])

// Test: XOR shuffle swizzle attribute.
// Same flattening as above but with a different swizzle attribute and i8
// element type: 16x128 collapses to 2048 elements.
func.func @flatten_xor_shuffle() {
  %alloc = memref.alloc() : memref<16x128xi8, #gpu.address_space<workgroup>>
  %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<16x128xi8, #gpu.address_space<workgroup>>
  "test.use"(%0) : (memref<16x128xi8, #gpu.address_space<workgroup>>) -> ()
  return
}

// CHECK-LABEL: func @flatten_xor_shuffle
// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xi8, #gpu.address_space<workgroup>>
// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.xor_shuffle<128, 16>] : memref<2048xi8, #gpu.address_space<workgroup>>
// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [16, 128] : memref<2048xi8, #gpu.address_space<workgroup>> into memref<16x128xi8, #gpu.address_space<workgroup>>
// CHECK: "test.use"(%[[EXPAND]])
// CHECK-NOT: memref.alloc() : memref<16x128xi8
// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<16x128xi8
// CHECK: return
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ def FlattenMemRefSubspanPass : Pass<"iree-codegen-flatten-memref-subspan", "Modu
}];
}

// Flattens rank > 1 `memref.alloc`s feeding `iree_codegen.swizzle_hint` into
// rank-1 allocs plus a `memref.expand_shape` (implemented in
// FlattenSwizzleHintAllocs.cpp).
def FlattenSwizzleHintAllocsPass :
    InterfacePass<"iree-codegen-flatten-swizzle-hint-allocs", "mlir::FunctionOpInterface"> {
  let summary = "Flattens allocations associated with iree_codegen.swizzle_hint ops";
}

def FoldAffineMinInDistributedLoopsPass :
InterfacePass<"iree-codegen-fold-affinemin-in-distributed-loops", "mlir::FunctionOpInterface"> {
let summary = "Fold `affine.min` ops in distributed loops";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
Expand Down Expand Up @@ -157,6 +158,22 @@ static void swizzleGatherToLDS(RewriterBase &rewriter,
});
}

/// Verifies that the swizzle hint operates on a flat, contiguous memref.
/// Swizzle hints require flat (rank 1) memrefs with an identity layout;
/// dynamic rank-1 memrefs are accepted, while static ones must additionally
/// be contiguous row-major.
static LogicalResult
verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) {
  auto memrefType = cast<MemRefType>(hintOp.getOperand().getType());
  bool isFlatIdentity =
      memrefType.getRank() == 1 && memrefType.getLayout().isIdentity();
  bool isNonContiguousStatic =
      memrefType.hasStaticShape() &&
      !memref::isStaticShapeAndContiguousRowMajor(memrefType);
  if (isFlatIdentity && !isNonContiguousStatic) {
    return success();
  }
  hintOp.emitError()
      << "swizzle hint operand must be a contiguous flat memref, got "
      << hintOp.getOperand().getType();
  return failure();
}

/// Resolves all hints. Walks all direct users and splits them into loads and
/// stores. If any user is not a swizzle-able load or store, bail out and
/// silently drop the optimization hint.
Expand Down Expand Up @@ -189,7 +206,7 @@ static void resolveHintOp(RewriterBase &rewriter,
}
if (auto gatherToLDSOp = dyn_cast<amdgpu::GatherToLDSOp>(user)) {
// Ignore swizzleHint on Dst Operand. Gather_to_lds writes elements of a
// subgroup contiguously in order of lane ID
// subgroup contiguously in order of lane ID.
if (gatherToLDSOp.getDst() == hintOp) {
continue;
}
Expand Down Expand Up @@ -242,6 +259,9 @@ void ResolveSwizzleHintsPass::runOnOperation() {
// silently pass through for that hint.
IRRewriter rewriter(funcOp->getContext());
for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) {
if (failed(verifyFlatContiguousSwizzleHintOp(hintOp))) {
return signalPassFailure();
}
resolveHintOp(rewriter, hintOp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,24 @@ func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #a
// CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index
// CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]

// -----

// Verify that swizzle_hint fails on non-flat (rank > 1) memrefs.
// The resolve-swizzle-hints verification rejects any operand that is not a
// rank-1 identity-layout memref.
func.func @swizzle_hint_non_flat_memref_error(%src: memref<32x64xf32>) -> vector<4xf32> {
  // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32>'}}
  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32>
  %offset = arith.constant 0 : index
  %1 = vector.load %0[%offset, %offset] : memref<32x64xf32>, vector<4xf32>
  return %1: vector<4xf32>
}

// Verify that swizzle_hint fails on non-contiguous memrefs.
// A strided layout is not an identity layout, so verification rejects it even
// though the shape is static. Note the diagnostic prints the type with the
// zero offset elided ('strided<[2, 1]>').
func.func @swizzle_hint_non_contiguous_memref_error() -> vector<4xf32> {
  %src = memref.alloc() : memref<32x64xf32, strided<[2, 1], offset: 0>>
  // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32, strided<[2, 1]>>'}}
  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, strided<[2, 1], offset: 0>>
  %offset = arith.constant 0 : index
  %1 = vector.load %0[%offset, %offset] : memref<32x64xf32, strided<[2, 1], offset: 0>>, vector<4xf32>
  return %1: vector<4xf32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def IREECodegen_SwizzleHintOp : Op<IREECodegen_Dialect, "swizzle_hint", [
is otherwise perfectly legal.
}];

let arguments = (ins RankedTensorOrMemRefOf<[AnyType], [1]>:$operand,
let arguments = (ins AnyRankedTensorOrMemRef:$operand,
IREECodegen_AnySwizzleAttr:$swizzle);
let results = (outs RankedTensorOrMemRefOf<[AnyType], [1]>:$result);
let results = (outs AnyRankedTensorOrMemRef:$result);

let assemblyFormat = [{
$operand `[` $swizzle attr-dict `]` `:` type($result)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
Expand Down Expand Up @@ -579,6 +580,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

// Step 9. Remaining post-bufferization optimizations/lowerings.
funcPassManager.addPass(createFlattenSwizzleHintAllocsPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
Expand Down
Loading