From 6687498856e6e1a8050223803c62085360223415 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 04:46:36 -0800 Subject: [PATCH 01/71] Added pattern rewriter and lit tests for Torch::HigherOrderFlexAttentionOp -> LinalgExt::AttentionOp Signed-off-by: Keshav Vinayak Jha --- .../ConvertTorchUnstructuredToLinalgExt.cpp | 339 +++++++++++++++++- .../test/unstructured_linalg_ext.mlir | 164 ++++++++- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 18 +- 3 files changed, 509 insertions(+), 12 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 7f18e169f157..5dae736deff8 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -6,7 +6,15 @@ #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "llvm/ADT/APFloat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -163,21 +171,342 @@ struct FftRfftOpConversion } }; +// Utility to add a score modification region to the attention op. 
+void createScoreModificationRegion( + PatternRewriter &rewriter, Location loc, + IREE::LinalgExt::AttentionOp attentionOp, + std::optional scoreModSymbol, FloatType floatType, + const int kAttentionRank) { + OpBuilder::InsertionGuard g(rewriter); + Block *block = rewriter.createBlock(&attentionOp.getRegion()); + + block->addArgument(floatType, loc); + rewriter.setInsertionPointToStart(block); + + Value score = block->getArgument(0); + Value modifiedScore = score; + + if (scoreModSymbol) { + Type i32Type = rewriter.getI32Type(); + Type si32Type = + IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); + RankedTensorType scalarTensorType = RankedTensorType::get({}, floatType); + torch::Torch::ValueTensorType torchScalarType = + rewriter.getType(ArrayRef{}, + floatType); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + torch::Torch::ValueTensorType torchI32ScalarType = + rewriter.getType(ArrayRef{}, + si32Type); + + Value scoreTensor = tensor::FromElementsOp::create( + rewriter, loc, scalarTensorType, ValueRange{score}); + Value torchScore = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, torchScalarType, scoreTensor); + + SmallVector callArgs; + callArgs.push_back(torchScore); + + for (unsigned i = 0; i < kAttentionRank; ++i) { + Value idx = IREE::LinalgExt::IndexOp::create(rewriter, loc, i); + Value idxI32 = arith::IndexCastOp::create(rewriter, loc, i32Type, idx); + Value idxTensor = tensor::FromElementsOp::create( + rewriter, loc, i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, torchI32ScalarType, idxTensor); + callArgs.push_back(torchIdx); + } + + auto callOp = + func::CallOp::create(rewriter, loc, TypeRange{torchScalarType}, + scoreModSymbol.value(), ValueRange(callArgs)); + Value torchResult = callOp.getResult(0); + + Value resultTensor = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, 
scalarTensorType, torchResult); + + modifiedScore = + tensor::ExtractOp::create(rewriter, loc, resultTensor, ValueRange{}); + } + + IREE::LinalgExt::YieldOp::create(rewriter, loc, modifiedScore); +} + +// Utility to compute dynamic sizes for attention tensors. +void computeDynamicSizes(PatternRewriter &rewriter, Location loc, + const SmallVector &shape, + SmallVector &dynSizes, Value first, + Value second, const int kAttentionRank) { + for (int i = 0; i < kAttentionRank; ++i) { + if (shape[i] == torch::Torch::kUnknownSize) { + Value idx = arith::ConstantIndexOp::create(rewriter, loc, std::min(i, 2)); + Value dim = + tensor::DimOp::create(rewriter, loc, i < 3 ? first : second, idx); + dynSizes.push_back(dim); + } + } +} + +// Utility to create a modified mask tensor. +Value createModifiedMask(PatternRewriter &rewriter, Location loc, + MLIRContext *ctx, FlatSymbolRefAttr maskModRef, + int64_t batch, int64_t numHeads, int64_t seqLenQ, + int64_t seqLenKV, FloatType floatType, + Value builtinQuery, Value builtinKey, Value zero, + const int kAttentionRank) { + static const int kNumModificationIndices = 4; + // Create mask tensor [B, H, M, N] with values 0.0 (attend) or -inf + // (mask). 
+ RankedTensorType boolScalarTensorType = + RankedTensorType::get({}, rewriter.getI1Type()); + torch::Torch::ValueTensorType torchBoolScalarType = + rewriter.getType(ArrayRef{}, + rewriter.getI1Type()); + Type i32Type = rewriter.getI32Type(); + RankedTensorType i32ScalarTensorType = RankedTensorType::get({}, i32Type); + Type si32Type = + IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed); + torch::Torch::ValueTensorType torchI32ScalarType = + rewriter.getType(ArrayRef{}, + si32Type); + SmallVector maskShape = {batch, numHeads, seqLenQ, seqLenKV}; + SmallVector maskDynSizes; + + computeDynamicSizes(rewriter, loc, maskShape, maskDynSizes, builtinQuery, + builtinKey, kAttentionRank); + + Value maskTensor = tensor::EmptyOp::create(rewriter, loc, maskShape, + floatType, maskDynSizes); + // Create linalg.generic to materialize mask. + SmallVector maskMaps; + maskMaps.push_back(AffineMap::getMultiDimIdentityMap(kAttentionRank, ctx)); + + SmallVector iteratorTypes(kAttentionRank, + utils::IteratorType::parallel); + + Value negInf = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getInf(floatType.getFloatSemantics(), + /*Negative=*/true)); + + auto maskGeneric = linalg::GenericOp::create( + rewriter, loc, TypeRange{maskTensor.getType()}, ValueRange{}, + ValueRange{maskTensor}, maskMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // Get indices and convert to torch tensors. + SmallVector torchIndices; + for (unsigned i = 0; i < kNumModificationIndices; ++i) { + Value idx = linalg::IndexOp::create(b, loc, i); + Value idxI32 = + arith::IndexCastOp::create(b, loc, rewriter.getI32Type(), idx); + Value idxTensor = tensor::FromElementsOp::create( + b, loc, i32ScalarTensorType, ValueRange{idxI32}); + Value torchIdx = torch::TorchConversion::FromBuiltinTensorOp::create( + b, loc, torchI32ScalarType, idxTensor); + torchIndices.push_back(torchIdx); + } + + // Call mask_mod_fn(b, h, q_idx, kv_idx). 
+ auto callOp = + func::CallOp::create(b, loc, TypeRange{torchBoolScalarType}, + maskModRef, ValueRange(torchIndices)); + Value torchMaskResult = callOp.getResult(0); + + Value maskResult = torch::TorchConversion::ToBuiltinTensorOp::create( + b, loc, boolScalarTensorType, torchMaskResult); + + Value maskBool = + tensor::ExtractOp::create(b, loc, maskResult, ValueRange{}); + + Value maskValue = + arith::SelectOp::create(b, loc, maskBool, zero, negInf); + + linalg::YieldOp::create(b, loc, maskValue); + }); + + return maskGeneric.getResult(0); +} + +Value convertToBuiltinTensor(PatternRewriter &rewriter, Location loc, + Value torchTensor) { + auto torchType = cast(torchTensor.getType()); + return torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, torchType.toBuiltinTensor(), torchTensor); +} + +struct FlexAttentionOpConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Attention tensors are 4D: [batch, head, query_seq, key_seq]. + static const int kAttentionRank = 4; + + LogicalResult matchAndRewrite(torch::Torch::HigherOrderFlexAttentionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = getContext(); + Value query = op.getQuery(); + Value key = op.getKey(); + Value value = op.getValue(); + Value scaleValue = op.getScale(); + auto scoreModSymbol = op.getScoreModFn(); + auto maskModSymbol = op.getMaskModFn(); + + bool returnLseValue; + if (!matchPattern(op.getReturnLse(), + torch::Torch::m_TorchConstantBool(&returnLseValue))) { + return rewriter.notifyMatchFailure( + op, "expected return_lse to be a constant bool"); + } + + bool returnMaxScoresValue; + if (!matchPattern( + op.getReturnMaxScores(), + torch::Torch::m_TorchConstantBool(&returnMaxScoresValue))) { + return rewriter.notifyMatchFailure( + op, "expected return_max_scores to be a constant bool"); + } + + auto queryType = cast(query.getType()); + auto keyType = cast(key.getType()); + auto valueType = 
cast(value.getType()); + + ArrayRef queryShape = queryType.getSizes(); + ArrayRef valueShape = valueType.getSizes(); + + int64_t batch = queryShape[0]; + int64_t numHeads = queryShape[1]; + int64_t seqLenQ = queryShape[2]; + int64_t headDim = queryShape[3]; + int64_t seqLenKV = keyType.getSizes()[2]; + int64_t valueDim = valueShape[3]; + + auto floatType = dyn_cast(queryType.getOptionalDtype()); + // Default scale: 1.0 / sqrt(head_dim). + double scaleVal; + if (!matchPattern(scaleValue, + torch::Torch::m_TorchConstantFloat(&scaleVal))) { + scaleVal = 1.0 / std::sqrt(static_cast(headDim)); + } + + Value scale = arith::ConstantOp::create( + rewriter, loc, floatType, rewriter.getFloatAttr(floatType, scaleVal)); + + Value builtinQuery = convertToBuiltinTensor(rewriter, loc, query); + Value builtinKey = convertToBuiltinTensor(rewriter, loc, key); + Value builtinValue = convertToBuiltinTensor(rewriter, loc, value); + + // Declare common types for mask and score modification regions. + Value zero = arith::ConstantFloatOp::create( + rewriter, loc, floatType, + llvm::APFloat::getZero(floatType.getFloatSemantics())); + Value mask; + if (maskModSymbol) { + FlatSymbolRefAttr maskModRef = + FlatSymbolRefAttr::get(ctx, *maskModSymbol); + mask = createModifiedMask(rewriter, loc, ctx, maskModRef, batch, numHeads, + seqLenQ, seqLenKV, floatType, builtinQuery, + builtinKey, zero, kAttentionRank); + } + + // Create output tensor for attention. + SmallVector outputDynSizes; + SmallVector outputShape = {batch, numHeads, seqLenQ, valueDim}; + computeDynamicSizes(rewriter, loc, outputShape, outputDynSizes, + builtinQuery, builtinValue, kAttentionRank); + + // Initialize output tensor with identity value (0.0 for addition). 
+ Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, + floatType, rewriter, loc, + /*useOnlyFiniteValue=*/true); + Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit, + outputShape, outputDynSizes); + + // Build indexing maps for attention. + // Standard maps: Q, K, V, scale, [mask], output. + AffineExpr b, h, m, n, k1, k2; + bindDims(ctx, b, h, m, n, k1, k2); + + auto qMap = AffineMap::get(6, 0, {b, h, m, k1}, ctx); + auto kMap = AffineMap::get(6, 0, {b, h, n, k1}, ctx); + auto vMap = AffineMap::get(6, 0, {b, h, n, k2}, ctx); + auto sMap = AffineMap::get(6, 0, {}, ctx); + auto oMap = AffineMap::get(6, 0, {b, h, m, k2}, ctx); + + SmallVector indexingMaps = {qMap, kMap, vMap, sMap}; + if (mask) { + indexingMaps.push_back(AffineMap::get(6, 0, {b, h, m, n}, ctx)); + } + + indexingMaps.push_back(oMap); + + // Create attention op. + auto attentionOp = IREE::LinalgExt::AttentionOp::create( + rewriter, loc, outputTensor.getType(), builtinQuery, builtinKey, + builtinValue, scale, outputTensor, + rewriter.getAffineMapArrayAttr(indexingMaps), mask); + + createScoreModificationRegion(rewriter, loc, attentionOp, scoreModSymbol, + floatType, kAttentionRank); + + rewriter.setInsertionPointAfter(attentionOp); + + Value normalizedOutput = attentionOp.getResult(0); + + auto outputTorchType = + queryType.getWithSizesAndDtype(outputShape, floatType); + Value torchOutput = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, outputTorchType, normalizedOutput); + + // Handle logsumexp. + // Note: AttentionOp doesn't expose intermediate max/sum + // values needed for LSE calculation. Return a dummy tensor - logsumexp + // shape is output_shape[:-1] (remove last dim). + if (returnLseValue) { + op.emitWarning("FlexAttention: logsumexp output is a dummy (zeros), " + "actual values are not available from AttentionOp"); + } + // Same goes for max_scores computation from AttentionOp. 
+ if (returnMaxScoresValue) { + op.emitWarning("FlexAttention: max_scores output is a dummy (zeros), " + "actual values are not available from AttentionOp"); + } + SmallVector lseShape = outputShape; + lseShape.pop_back(); + + SmallVector lseDynSizes = outputDynSizes; + if (ShapedType::isDynamic(outputShape.back())) { + lseDynSizes.pop_back(); + } + + Value lseTensor = + tensor::SplatOp::create(rewriter, loc, zero, lseShape, lseDynSizes); + + auto lseTorchType = queryType.getWithSizesAndDtype(lseShape, floatType); + Value torchLogsumexp = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, lseTorchType, lseTensor); + + rewriter.replaceOp( + op, {torchOutput, torchLogsumexp, /*max_scores=*/torchLogsumexp}); + return success(); + } +}; + class ConvertTorchUnstructuredToLinalgExtPass final : public impl::ConvertTorchUnstructuredToLinalgExtPassBase< ConvertTorchUnstructuredToLinalgExtPass> { public: void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert< + IREE::LinalgExt::IREELinalgExtDialect, torch::Torch::TorchDialect, + tensor::TensorDialect, linalg::LinalgDialect, arith::ArithDialect, + func::FuncDialect, torch::TorchConversion::TorchConversionDialect>(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index a568966d906b..ac954ddeba23 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file 
--pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" %s | FileCheck %s +// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(func.func(torch-iree-torch-unstructured-to-linalg-ext))" %s | FileCheck %s // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -99,3 +99,165 @@ func.func @fft_rfft.last(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torch.vtensor< // CHECK: %[[VAR12:.*]] = torch.aten.cat %[[VAR11]], %[[INTM1]] : !torch.list>, !torch.int -> !torch.vtensor<[3,8,9,2],f32> // CHECK: %[[VAR13:.*]] = torch.aten.view_as_complex %[[VAR12]] : !torch.vtensor<[3,8,9,2],f32> -> !torch.vtensor<[3,8,9],complex> // CHECK: return %[[VAR13]] : !torch.vtensor<[3,8,9],complex> + +// ----- + +//===----------------------------------------------------------------------===// +// FlexAttention tests +//===----------------------------------------------------------------------===// + + +func.func @flex_attn_with_scoremod_and_maskmod(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : 
!torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_and_maskmod( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST_1:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[MASK_EMPTY:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> +// CHECK: %[[MASK:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MASK_EMPTY]] : tensor<4x8x1024x1024xf32>) +// CHECK: func.call @sdpa_mask0 +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_2]], %[[MASK]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: func.call @sdpa_score0 +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> + +func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} +// CHECK-LABEL: func.func private @sdpa_score0( +// 
CHECK: %{{.*}} = torch.aten.tanh %{{.*}} : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + +func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} +// CHECK-LABEL: func.func private @sdpa_mask0( +// CHECK: %{{.*}} = torch.aten.ge.Tensor %{{.*}}, %{{.*}} : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + +// ----- + +func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {score_mod_fn = @sdpa_score1} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_scoremod_only( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 
1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: func.call @sdpa_score1 +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> + +func.func private @sdpa_score1(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} +// CHECK-LABEL: func.func private @sdpa_score1( +// CHECK: %{{.*}} = torch.aten.tanh %{{.*}} : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + +// ----- + +func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask1} : 
!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_with_maskmod_only( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST_1:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[MASK_EMPTY:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> +// CHECK: %[[MASK:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MASK_EMPTY]] : tensor<4x8x1024x1024xf32>) +// CHECK: func.call @sdpa_mask1 +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_2]], %[[MASK]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> + +func.func private @sdpa_mask1(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> 
!torch.vtensor<[],i1> { + %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} +// CHECK-LABEL: func.func private @sdpa_mask1( +// CHECK: %{{.*}} = torch.aten.ge.Tensor %{{.*}}, %{{.*}} : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> + +// ----- + +func.func @flex_attn_without_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %false = torch.constant.bool false + %true = torch.constant.bool true + // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_without_mods( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = 
torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> + +// ----- + +func.func @flex_attn_without_mods_returnmaxscore(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %true = torch.constant.bool true + // expected-warning @+2 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} + // expected-warning @+1 {{FlexAttention: max_scores output is a dummy (zeros), actual values are not available from AttentionOp}} + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %true : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + return %output : !torch.vtensor<[4,8,1024,64],f32> +} +// CHECK-LABEL: func.func @flex_attn_without_mods_returnmaxscore( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[QUERY:.*]] = 
torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : +// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> \ No newline at end of file diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 7719c2522b7c..9e76efc8737b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -2808,14 +2808,20 @@ CustomOp::reifyResultShapes(OpBuilder &builder, //===---------------------------------------------------------------------===// LogicalResult IREE::LinalgExt::IndexOp::verify() { - auto customOp = dyn_cast(getOperation()->getParentOp()); - if (!customOp) { - return emitOpError("expected parent op to be `iree_linalg_ext.custom_op`"); + auto parentOp = getOperation()->getParentOp(); + auto customOp = dyn_cast(parentOp); + auto attentionOp = dyn_cast(parentOp); + if (!customOp && !attentionOp) { + return emitOpError( + "expected parent op to be one of `iree_linalg_ext.custom_op`, " + "`iree_linalg_ext.attention`"); } - if (customOp.getNumLoops() <= getDim()) { + int64_t numLoops = + customOp ? 
customOp.getNumLoops() : attentionOp.getNumLoops(); + if (numLoops <= getDim()) { return emitOpError("expected dim (") - << getDim() << ") to be lower than the number of loops (" - << customOp.getNumLoops() << ") of the enclosing CustomOp"; + << getDim() << ") to be lower than the number of loops (" << numLoops + << ") of the enclosing CustomOp/AttentionOp"; } return success(); } From d4d4074d41ede63a1e14390c629cb4360db27263 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 04:47:47 -0800 Subject: [PATCH 02/71] EOL added to lit Signed-off-by: Keshav Vinayak Jha --- .../Torch/InputConversion/test/unstructured_linalg_ext.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index ac954ddeba23..36f0cc9102fc 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -260,4 +260,4 @@ func.func @flex_attn_without_mods_returnmaxscore(%arg0: !torch.vtensor<[4,8,1024 // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : // CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> \ No newline at end of file +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> From b1feb97f0a3ff0171e060ba8a2b1bc4b7c7e741c Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 04:49:13 -0800 Subject: [PATCH 03/71] Redundant change (Another PR) Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git 
a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 9e76efc8737b..7719c2522b7c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -2808,20 +2808,14 @@ CustomOp::reifyResultShapes(OpBuilder &builder, //===---------------------------------------------------------------------===// LogicalResult IREE::LinalgExt::IndexOp::verify() { - auto parentOp = getOperation()->getParentOp(); - auto customOp = dyn_cast(parentOp); - auto attentionOp = dyn_cast(parentOp); - if (!customOp && !attentionOp) { - return emitOpError( - "expected parent op to be one of `iree_linalg_ext.custom_op`, " - "`iree_linalg_ext.attention`"); + auto customOp = dyn_cast(getOperation()->getParentOp()); + if (!customOp) { + return emitOpError("expected parent op to be `iree_linalg_ext.custom_op`"); } - int64_t numLoops = - customOp ? customOp.getNumLoops() : attentionOp.getNumLoops(); - if (numLoops <= getDim()) { + if (customOp.getNumLoops() <= getDim()) { return emitOpError("expected dim (") - << getDim() << ") to be lower than the number of loops (" << numLoops - << ") of the enclosing CustomOp/AttentionOp"; + << getDim() << ") to be lower than the number of loops (" + << customOp.getNumLoops() << ") of the enclosing CustomOp"; } return success(); } From 4c66033a3e2e17ffaa6a8f51b3de0732c735037d Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 20:58:12 -0800 Subject: [PATCH 04/71] Added Dynamic Head NYI check Signed-off-by: Keshav Vinayak Jha --- .../InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 5dae736deff8..5c6bb27cc117 100644 --- 
a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -381,6 +381,11 @@ struct FlexAttentionOpConversion int64_t seqLenKV = keyType.getSizes()[2]; int64_t valueDim = valueShape[3]; + // Dynamic head dim is not supported. + if (headDim == kUnknownSize) { + return emitError() << "NYI: dynamic head dimension"; + } + auto floatType = dyn_cast(queryType.getOptionalDtype()); // Default scale: 1.0 / sqrt(head_dim). double scaleVal; From 690d8b0785926054e34074cbd547804cc3852ace Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 27 Nov 2025 03:30:38 -0800 Subject: [PATCH 05/71] missing scope resiultion Signed-off-by: Keshav Vinayak Jha --- .../InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 5c6bb27cc117..83be4267ff23 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -382,8 +382,8 @@ struct FlexAttentionOpConversion int64_t valueDim = valueShape[3]; // Dynamic head dim is not supported. 
- if (headDim == kUnknownSize) { - return emitError() << "NYI: dynamic head dimension"; + if (headDim == torch::Torch::kUnknownSize) { + return rewriter.notifyMatchFailure(op, "NYI: dynamic head dimension"); } auto floatType = dyn_cast(queryType.getOptionalDtype()); From c6b9868da4d83759040ba4974da648253293828f Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 08:15:17 +0000 Subject: [PATCH 06/71] Added CHECK-NOT statements for no func references Signed-off-by: Keshav Vinayak Jha --- .../test/unstructured_linalg_ext.mlir | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index 36f0cc9102fc..7974f9e24bf6 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -149,14 +149,12 @@ func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vten // CHECK-LABEL: func.func private @sdpa_mask0( // CHECK: %{{.*}} = torch.aten.ge.Tensor %{{.*}}, %{{.*}} : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> -// ----- - func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false %true = torch.constant.bool true // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {score_mod_fn = @sdpa_score1} : !torch.vtensor<[4,8,1024,64],f32>, 
!torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } // CHECK-LABEL: func.func @flex_attn_with_scoremod_only( @@ -166,28 +164,20 @@ func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32> // CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK-NOT: func.call @sdpa_mask0 // CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : // CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) -// CHECK: func.call @sdpa_score1 +// CHECK: func.call @sdpa_score0 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> -func.func private @sdpa_score1(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> { - %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> - return %0 : 
!torch.vtensor<[],f32> -} -// CHECK-LABEL: func.func private @sdpa_score1( -// CHECK: %{{.*}} = torch.aten.tanh %{{.*}} : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> - -// ----- - func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false %true = torch.constant.bool true // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask1} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } // CHECK-LABEL: func.func @flex_attn_with_maskmod_only( @@ -202,20 +192,14 @@ func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, // CHECK: %[[MASK_EMPTY:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> // CHECK: %[[MASK:.*]] = linalg.generic // CHECK-SAME: outs(%[[MASK_EMPTY]] : tensor<4x8x1024x1024xf32>) -// CHECK: func.call @sdpa_mask1 +// CHECK: func.call @sdpa_mask0 // CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention // CHECK-SAME: 
ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_2]], %[[MASK]] : // CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK-NOT: func.call @sdpa_score0 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> -func.func private @sdpa_mask1(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> { - %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> - return %0 : !torch.vtensor<[],i1> -} -// CHECK-LABEL: func.func private @sdpa_mask1( -// CHECK: %{{.*}} = torch.aten.ge.Tensor %{{.*}}, %{{.*}} : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1> - // ----- func.func @flex_attn_without_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { @@ -233,9 +217,11 @@ func.func @flex_attn_without_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg // CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK-NOT: func.call @sdpa_mask0 // CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : // CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK-NOT: func.call @sdpa_score0 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> 
!torch.vtensor<[4,8,1024,64],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> @@ -256,8 +242,10 @@ func.func @flex_attn_without_mods_returnmaxscore(%arg0: !torch.vtensor<[4,8,1024 // CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> // CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK-NOT: func.call @sdpa_mask0 // CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : // CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK-NOT: func.call @sdpa_score0 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> From 3820c28f974a1d517ddd38cfe11b933b654c82b4 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 08:28:38 +0000 Subject: [PATCH 07/71] Simplified FileCheck statements Signed-off-by: Keshav Vinayak Jha --- .../test/unstructured_linalg_ext.mlir | 91 +++++-------------- 1 file changed, 23 insertions(+), 68 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir index 7974f9e24bf6..0432200fa3d4 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir @@ -110,9 +110,7 @@ func.func @fft_rfft.last(%arg0: !torch.vtensor<[3,8,16],f32>) -> !torch.vtensor< func.func @flex_attn_with_scoremod_and_maskmod(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: 
!torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false - %true = torch.constant.bool true - // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } // CHECK-LABEL: func.func @flex_attn_with_scoremod_and_maskmod( @@ -152,82 +150,47 @@ func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vten func.func @flex_attn_with_scoremod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false - %true = torch.constant.bool true - // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = 
torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {score_mod_fn = @sdpa_score0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } -// CHECK-LABEL: func.func @flex_attn_with_scoremod_only( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> +// CHECK-LABEL: func.func @flex_attn_with_scoremod_only +// CHECK-NOT: linalg.generic // CHECK-NOT: func.call @sdpa_mask0 -// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention -// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : -// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: iree_linalg_ext.attention // CHECK: func.call @sdpa_score0 -// CHECK: %[[RESULT:.*]] = 
torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return func.func @flex_attn_with_maskmod_only(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false - %true = torch.constant.bool true - // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } -// CHECK-LABEL: func.func @flex_attn_with_maskmod_only( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: %[[CST_1:.*]] = arith.constant 
0.000000e+00 : f32 -// CHECK-DAG: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[MASK_EMPTY:.*]] = tensor.empty() : tensor<4x8x1024x1024xf32> -// CHECK: %[[MASK:.*]] = linalg.generic -// CHECK-SAME: outs(%[[MASK_EMPTY]] : tensor<4x8x1024x1024xf32>) +// CHECK-LABEL: func.func @flex_attn_with_maskmod_only +// CHECK: linalg.generic // CHECK: func.call @sdpa_mask0 -// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention -// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_2]], %[[MASK]] : -// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK: iree_linalg_ext.attention // CHECK-NOT: func.call @sdpa_score0 -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return // ----- func.func @flex_attn_without_mods(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %false = torch.constant.bool false - %true = torch.constant.bool true - // expected-warning @+1 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} - %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %false : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, 
!torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> + %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %false, %false : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } -// CHECK-LABEL: func.func @flex_attn_without_mods( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK-NOT: func.call @sdpa_mask0 -// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention -// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : -// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK-LABEL: func.func @flex_attn_without_mods +// CHECK-NOT: linalg.generic +// CHECK-NOT: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention // CHECK-NOT: func.call @sdpa_score0 -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return // ----- -func.func 
@flex_attn_without_mods_returnmaxscore(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { +func.func @flex_attn_without_mods_return_maxscore_and_lse(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>) -> (!torch.vtensor<[4,8,1024,64],f32>) attributes {torch.assume_strict_symbolic_shapes} { %float1.000000e00 = torch.constant.float 1.000000e+00 %true = torch.constant.bool true // expected-warning @+2 {{FlexAttention: logsumexp output is a dummy (zeros), actual values are not available from AttentionOp}} @@ -235,17 +198,9 @@ func.func @flex_attn_without_mods_returnmaxscore(%arg0: !torch.vtensor<[4,8,1024 %output, %logsumexp, %maxscores = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.000000e00, %true, %true : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32> return %output : !torch.vtensor<[4,8,1024,64],f32> } -// CHECK-LABEL: func.func @flex_attn_without_mods_returnmaxscore( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,8,1024,64],f32>, %[[ARG2:.*]]: !torch.vtensor<[4,8,1024,64],f32>) -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<4x8x1024x64xf32> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[QUERY:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[KEY:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK: %[[VALUE:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : 
!torch.vtensor<[4,8,1024,64],f32> -> tensor<4x8x1024x64xf32> -// CHECK-NOT: func.call @sdpa_mask0 -// CHECK: %[[ATTENTION:.*]] = iree_linalg_ext.attention -// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[VALUE]], %[[CST_0]] : -// CHECK-SAME: outs(%[[CST]] : tensor<4x8x1024x64xf32>) +// CHECK-LABEL: func.func @flex_attn_without_mods_return_maxscore_and_lse +// CHECK-NOT: linalg.generic +// CHECK-NOT: func.call @sdpa_mask0 +// CHECK: iree_linalg_ext.attention // CHECK-NOT: func.call @sdpa_score0 -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ATTENTION]] : tensor<4x8x1024x64xf32> -> !torch.vtensor<[4,8,1024,64],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[4,8,1024,64],f32> +// CHECK: return From 6eba854f861289c5955630286018ef9d8a7d303e Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 08:55:26 +0000 Subject: [PATCH 08/71] Added verbose comments explaining computeDynamicSizes utility Signed-off-by: Keshav Vinayak Jha --- .../ConvertTorchUnstructuredToLinalgExt.cpp | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 83be4267ff23..5947639692b7 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -233,6 +233,30 @@ void createScoreModificationRegion( } // Utility to compute dynamic sizes for attention tensors. +// This helper is used in two places: +// +// For the mask tensor. Shape = (B, Hq, L, S). Any of these may be dynamic, so +// we extract B/Hq/L from the query tensor and S from the key tensor. The +// resulting dynamic sizes are passed to tensor.empty when materialising the +// mask. +// +// For the output tensor. Shape = (B, Hq, L, Ev). Since Ev is statically known, +// only B/Hq/L may be dynamic. 
The helper again generates the needed tensor.dim +// ops from the query/value tensors so that tensor.splat/tensor.empty gets the +// correct dynamic extents. Assuming the standard 4D layout: +// Query: (B, Hq, L, E) +// Key: (B, Hkv, S, E) +// Value: (B, Hkv, S, Ev) +// When constructing new tensors (mask/output), we need dynamic sizes for +// dimensions that come from the input shapes. +// +// For dims (B, H, L), the runtime sizes always come from the query tensor. +// For dim 3, the required runtime size depends on what we are building: +// For the mask (shape = B×H×L×S), the 3rd axis is S, which lives at +// index 2 of the Key tensor. +// For the output (shape = B×H×L×Ev), Ev is statically known, so we never need a +// dynamic dimension for i = 3. + void computeDynamicSizes(PatternRewriter &rewriter, Location loc, const SmallVector &shape, SmallVector &dynSizes, Value first, From d0c09b3eb9449c96fd891c074e357b55097f4fee Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 10:53:54 +0000 Subject: [PATCH 09/71] Replaced Splat with linalg::fill Signed-off-by: Keshav Vinayak Jha --- .../ConvertTorchUnstructuredToLinalgExt.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 5947639692b7..7700959d4eea 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -242,7 +242,7 @@ void createScoreModificationRegion( // // For the output tensor. Shape = (B, Hq, L, Ev). Since Ev is statically known, // only B/Hq/L may be dynamic. 
The helper again generates the needed tensor.dim -// ops from the query/value tensors so that tensor.splat/tensor.empty gets the +// ops from the query/value tensors so that linalg.fill/tensor.empty gets the // correct dynamic extents. Assuming the standard 4D layout: // Query: (B, Hq, L, E) // Key: (B, Hkv, S, E) @@ -448,8 +448,7 @@ struct FlexAttentionOpConversion Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, floatType, rewriter, loc, /*useOnlyFiniteValue=*/true); - Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit, - outputShape, outputDynSizes); + Value outputTensor = linalg::FillOp::create(rewriter, loc, outputInit, outputShape, outputDynSizes); // Build indexing maps for attention. // Standard maps: Q, K, V, scale, [mask], output. @@ -508,8 +507,7 @@ struct FlexAttentionOpConversion lseDynSizes.pop_back(); } - Value lseTensor = - tensor::SplatOp::create(rewriter, loc, zero, lseShape, lseDynSizes); + Value lseTensor = linalg::FillOp::create(rewriter, loc, zero, lseShape, lseDynSizes); auto lseTorchType = queryType.getWithSizesAndDtype(lseShape, floatType); Value torchLogsumexp = torch::TorchConversion::FromBuiltinTensorOp::create( From 115116f90b29b213fc522965e8fb622b4646ed4c Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 11:15:31 +0000 Subject: [PATCH 10/71] SplatOp handles dynamic shapes Signed-off-by: Keshav Vinayak Jha --- .../ConvertTorchUnstructuredToLinalgExt.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp index 7700959d4eea..5947639692b7 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp @@ -242,7 +242,7 @@ void createScoreModificationRegion( // // 
For the output tensor. Shape = (B, Hq, L, Ev). Since Ev is statically known, // only B/Hq/L may be dynamic. The helper again generates the needed tensor.dim -// ops from the query/value tensors so that linalg.fill/tensor.empty gets the +// ops from the query/value tensors so that tensor.splat/tensor.empty gets the // correct dynamic extents. Assuming the standard 4D layout: // Query: (B, Hq, L, E) // Key: (B, Hkv, S, E) @@ -448,7 +448,8 @@ struct FlexAttentionOpConversion Value outputInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, floatType, rewriter, loc, /*useOnlyFiniteValue=*/true); - Value outputTensor = linalg::FillOp::create(rewriter, loc, outputInit, outputShape, outputDynSizes); + Value outputTensor = tensor::SplatOp::create(rewriter, loc, outputInit, + outputShape, outputDynSizes); // Build indexing maps for attention. // Standard maps: Q, K, V, scale, [mask], output. @@ -507,7 +508,8 @@ struct FlexAttentionOpConversion lseDynSizes.pop_back(); } - Value lseTensor = linalg::FillOp::create(rewriter, loc, zero, lseShape, lseDynSizes); + Value lseTensor = + tensor::SplatOp::create(rewriter, loc, zero, lseShape, lseDynSizes); auto lseTorchType = queryType.getWithSizesAndDtype(lseShape, floatType); Value torchLogsumexp = torch::TorchConversion::FromBuiltinTensorOp::create( From 3c0187e6bc8b7b086f64366cd897a9ec006aaff8 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 21:57:12 -0800 Subject: [PATCH 11/71] Added toggle for using useexp2 for onlineAttention Decomposition Signed-off-by: Keshav Vinayak Jha --- .../IR/AggregatedOpInterfaceImpl.cpp | 46 ++++++---- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 2 + .../IR/test/decompose_aggregate_op.mlir | 92 +++++++++++++++++++ 3 files changed, 123 insertions(+), 17 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index eddece9acfa8..2e935204d518 100644 --- 
a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -260,10 +260,10 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, return genericOp.getResult(0); } -// Compute output = exp2(output - input) -static Value computeSubAndExp2(OpBuilder &builder, Location loc, - AffineMap inputMap, AffineMap outputMap, - Value input, Value output) { +// Compute output = exp2/exp(output - input) depending on useExp2 flag. +static Value computeSubAndExp(OpBuilder &builder, Location loc, + AffineMap inputMap, AffineMap outputMap, + Value input, Value output, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{inputMap, outputMap}); inputMap = compressedMaps[0]; @@ -279,8 +279,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc, Value in = convertScalarToDtype(b, loc, args[0], args[1].getType(), /*isUnsignedCast=*/false); Value diff = arith::SubFOp::create(b, loc, args[1], in); - Value weight = math::Exp2Op::create(b, loc, diff); - linalg::YieldOp::create(b, loc, weight); + Operation *weight = useExp2 ? math::Exp2Op::create(b, loc, diff) + : math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, weight->getResult(0)); }); return genericOp.getResult(0); } @@ -316,12 +317,18 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, std::optional maskMap, SmallVector iterationDomain, Type sElementType, Region &elementwiseRegion, - DictionaryAttr qkAttrs, bool lowPrecision) { + DictionaryAttr qkAttrs, bool lowPrecision, + bool useExp2) { MLIRContext *ctx = b.getContext(); - // Since we use exp2 for attention instead of the original exp, we have to + // If using exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). 
We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); + scale = arith::MulFOp::create(b, loc, scale, log2e); + } Value log2e = arith::ConstantOp::create( b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = arith::MulFOp::create(b, loc, scale, log2e); @@ -436,9 +443,9 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { Type f32Type = b.getF32Type(); // ---- QK Matmul + elementwise math ---- - Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, - kMap, sMap, getMaskMap(), sizes, f32Type, - getRegion(), qkAttrs, lowPrecision); + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, f32Type, getRegion(), qkAttrs, lowPrecision, /*useExp2=*/true); // ---- Softmax ---- @@ -480,7 +487,7 @@ FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { // P = exp2(S - max) AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, max, s, /*useExp2=*/true); // sum = rowSum(P) Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); @@ -530,9 +537,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { DictionaryAttr config = getDecompositionConfigAttr(); DictionaryAttr qkAttrs, pvAttrs; + bool useExp2 = true; if (config) { qkAttrs = config.getAs(getQKAttrStr()); pvAttrs = config.getAs(getPVAttrStr()); + if (auto useExp2Attr = config.getAs(getUseExp2AttrStr())) { + useExp2 = useExp2Attr.getValue(); + } } FailureOr maybeOpInfo = AttentionOpDetail::get( @@ -553,7 +564,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // ---- QK Matmul + elementwise math ---- Value s = computeQKAndElementwise( loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), - sizes, elementType, getRegion(), qkAttrs, 
lowPrecision); + sizes, elementType, getRegion(), qkAttrs, lowPrecision, useExp2); // TODO: This decomposition should be in a seperate op called // "online softmax". @@ -563,20 +574,21 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap maxMap = getMaxMap(); Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax); - // norm = exp2(oldMax - newMax) + // norm = exp2(oldMax - newMax) or exp(oldMax - newMax) depending on useExp2 // normMap = maxMap AffineMap normMap = getMaxMap(); - Value norm = computeSubAndExp2(b, loc, maxMap, normMap, newMax, oldMax); + Value norm = + computeSubAndExp(b, loc, maxMap, normMap, newMax, oldMax, useExp2); // normSum = norm * oldSum AffineMap sumMap = getSumMap(); Value normSum = elementwiseValueInPlace(b, loc, sumMap, normMap, oldSum, norm); - // P = exp2(S - newMax) + // P = exp2(S - newMax) or exp(S - newMax) depending on useExp2 // PMap = SMap AffineMap pMap = sMap; - Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s); + Value p = computeSubAndExp(b, loc, maxMap, sMap, newMax, s, useExp2); // newSum = normSum + rowSum(P) Value newSum = reduce(b, loc, pMap, sumMap, p, normSum); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index ee84c2abd433..aa4ff9de6a29 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1081,6 +1081,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", // Attributes to set on QK and PV matmul after decomposition. static StringRef getQKAttrStr() { return "qk_attrs"; } static StringRef getPVAttrStr() { return "pv_attrs"; } + // Flag to control whether to use exp2 (with log2(e) scaling) or exp. 
+ static StringRef getUseExp2AttrStr() { return "use_exp2"; } }]; let hasCanonicalizer = 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 7de65d86c8c7..c973621d57af 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -419,6 +419,98 @@ func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>, + %max: tensor<192x1024xf32>, + %sum: tensor<192x1024xf32>) + -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>) { + %scale = arith.constant 1.0 : f16 + + %out:3 = iree_linalg_ext.online_attention + {decomposition_config = {use_exp2=false}, indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output, %max, 
%sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32> + + return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> +} + +// We want to check that we're correctly using exp +// when specified so from the decomposition_config. +// CHECK-LABEL: @online_attention_f16_noexp2 +// Q = Q * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// newMax = max(oldMax, rowMax(S)) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.maximumf +// CHECK: linalg.yield +// norm = exp (oldMax - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// normSum = norm * oldSum +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// P = exp(S - newMax) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK-NOT: math.exp2 +// CHECK: linalg.yield +// newSum = normSum + rowSum(P) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.addf +// CHECK: linalg.yield +// newAcc = norm * oldAcc +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.mulf +// CHECK: linalg.yield +// newAcc = P @ V + newAcc +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + // Spec to decompose exp reduction op. 
module attributes { transform.with_named_sequence } { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { From f14f00b06f35095dfd8ecc4278909844221e0dd7 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Wed, 26 Nov 2025 22:24:09 -0800 Subject: [PATCH 12/71] Added useExp2 as pass option to DecomposeAttention Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/Transforms/DecomposeAttention.cpp | 8 ++++++++ .../iree/compiler/Dialect/LinalgExt/Transforms/Passes.td | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index ac7c42ab58ec..97a0d8740710 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -32,8 +32,16 @@ struct DecomposeAttentionPass final void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); + + SmallVector decompositionConfigAttrs; + decompositionConfigAttrs.push_back( + rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2))); + DictionaryAttr decompositionConfig = + rewriter.getDictionaryAttr(decompositionConfigAttrs); + getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); + onlineAtt.setDecompositionConfigAttr(decompositionConfig); FailureOr> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td index c1ce03397950..ea27c2e16fd7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td @@ -116,6 +116,11 @@ def DecomposeAttentionPass : 
InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> { let summary = "Decomposes attention op into a sequence of linalg ops"; + let options = [ + Option<"useExp2", "use-exp2", "bool", /*default=*/"true", + "Use exp2 for computations; Tunable to allow for accurate computations " + "in case of accuracy losses due to fp-reassociation.">, + ]; } def ConvertAttentionToOnlineAttentionPass : From d3845c498bc6178dfa1578deef769f892d9c1974 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Thu, 27 Nov 2025 02:29:39 -0800 Subject: [PATCH 13/71] Removed Typo Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 2e935204d518..6efe2a30f784 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -329,9 +329,6 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = arith::MulFOp::create(b, loc, scale, log2e); } - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(scale.getType(), M_LOG2E)); - scale = arith::MulFOp::create(b, loc, scale, log2e); auto qETy = getElementTypeOrSelf(query.getType()); From 9c369687abdc968caf73337d0693e22c6753ac9c Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 00:12:05 -0800 Subject: [PATCH 14/71] Simplified FileChecks; Added check of log2e vs 1.0 scaling Signed-off-by: Keshav Vinayak Jha --- .../IR/test/decompose_aggregate_op.mlir | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir 
b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index c973621d57af..21185db2d70c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -190,6 +190,7 @@ func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, // correct number of extf/truncfs are emitted. // CHECK-LABEL: @online_attention_f16 // Q = Q * scale +// CHECK: arith.constant 1.442380e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf // S = Q @ K @@ -460,35 +461,19 @@ func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, // when specified so from the decomposition_config. // CHECK-LABEL: @online_attention_f16_noexp2 // Q = Q * scale +// CHECK: arith.constant 1.000000e+00 : f16 // CHECK: linalg.generic // CHECK: arith.mulf -// S = Q @ K -// CHECK: linalg.generic -// CHECK: arith.extf -// CHECK: arith.extf -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield -// newMax = max(oldMax, rowMax(S)) -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.maximumf -// CHECK: linalg.yield // norm = exp (oldMax - newMax) // CHECK: linalg.generic -// CHECK-NOT: arith.extf // CHECK: arith.subf -// CHECK-NOT: math.exp2 -// CHECK: linalg.yield -// normSum = norm * oldSum -// CHECK: linalg.generic // CHECK-NOT: arith.extf -// CHECK: arith.mulf +// CHECK-NOT: math.exp2 // CHECK: linalg.yield // P = exp(S - newMax) // CHECK: linalg.generic -// CHECK-NOT: arith.extf // CHECK: arith.subf +// CHECK-NOT: arith.extf // CHECK-NOT: math.exp2 // CHECK: linalg.yield // newSum = normSum + rowSum(P) From 2b7083f3997d2ee140b2a45f38bec28e6e81db77 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 00:12:42 -0800 Subject: [PATCH 15/71] Newline at EOF Signed-off-by: Keshav Vinayak Jha --- .../IR/test/decompose_aggregate_op.mlir | 19 +------------------ 1 file changed, 1 insertion(+), 18 
deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 21185db2d70c..7eff2e014b79 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -476,23 +476,6 @@ func.func @online_attention_f16_noexp2(%query: tensor<192x1024x64xf16>, // CHECK-NOT: arith.extf // CHECK-NOT: math.exp2 // CHECK: linalg.yield -// newSum = normSum + rowSum(P) -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.addf -// CHECK: linalg.yield -// newAcc = norm * oldAcc -// CHECK: linalg.generic -// CHECK-NOT: arith.extf -// CHECK: arith.mulf -// CHECK: linalg.yield -// newAcc = P @ V + newAcc -// CHECK: linalg.generic -// CHECK: arith.extf -// CHECK: arith.extf -// CHECK: arith.mulf -// CHECK: arith.addf -// CHECK: linalg.yield // ----- @@ -572,4 +555,4 @@ func.func @exp_reduction( // CHECK-SAME: outs(%[[acc_norm]] // CHECK: arith.mulf // CHECK: arith.addf -// CHECK: return %[[M]], %[[SUM]], %[[PV]] +// CHECK: return %[[M]], %[[SUM]], %[[PV]] \ No newline at end of file From a4dfdfafd9e4a9f27f02ee5a604e2c6ff291805f Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 01:28:55 -0800 Subject: [PATCH 16/71] Mask scaling is conditional to useExp2 Signed-off-by: Keshav Vinayak Jha --- .../LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 6efe2a30f784..92a1a58227d5 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -209,7 +209,7 @@ static Value 
applyPostQKMatmulElementwise(OpBuilder &builder, Location loc, } static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, - AffineMap maskMap, Value qk, Value mask) { + AffineMap maskMap, Value qk, Value mask, bool useExp2) { SmallVector compressedMaps = compressUnusedDims(SmallVector{qkMap, maskMap}); @@ -245,9 +245,11 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(), /*isUnsignedCast=*/false); // Scaling to compensate for base-2 softmax - Value log2e = arith::ConstantOp::create( - b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); - maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + if (useExp2) { + Value log2e = arith::ConstantOp::create( + b, loc, b.getFloatAttr(qkVal.getType(), M_LOG2E)); + maskVal = arith::MulFOp::create(b, loc, maskVal, log2e); + } } // Finally, set the returned value to the qk element plus the mask // element (or 0/-infinity if bool mask). We opt for a AddFOp (instead @@ -396,7 +398,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value(), useExp2); } return s; From d80d115a49187d980ae7ff214e72192247210d79 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 03:25:39 -0800 Subject: [PATCH 17/71] Bug in code: Overwriting the existing DecompositionAttr, we want to add use_exp2 not overwrite Signed-off-by: Keshav Vinayak Jha --- .../LinalgExt/Transforms/DecomposeAttention.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index 97a0d8740710..1fb50d83d52a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ 
b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -33,15 +33,14 @@ void DecomposeAttentionPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); - SmallVector decompositionConfigAttrs; - decompositionConfigAttrs.push_back( - rewriter.getNamedAttr("use_exp2", rewriter.getBoolAttr(useExp2))); - DictionaryAttr decompositionConfig = - rewriter.getDictionaryAttr(decompositionConfigAttrs); - getOperation().walk([&](OnlineAttentionOp onlineAtt) { rewriter.setInsertionPoint(onlineAtt); - onlineAtt.setDecompositionConfigAttr(decompositionConfig); + + NamedAttrList decompositionConfig(onlineAtt.getDecompositionConfigAttr()); + decompositionConfig.set("use_exp2", rewriter.getBoolAttr(useExp2)); + onlineAtt.setDecompositionConfigAttr( + decompositionConfig.getDictionary(context)); + FailureOr> results = onlineAtt.decomposeOperation(rewriter); if (failed(results)) { From dabdd8673b35b77c3d40fb71805e3908749eb338 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 03:35:28 -0800 Subject: [PATCH 18/71] Added docs for Decomposition Configuration: Signed-off-by: Keshav Vinayak Jha --- .../iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index aa4ff9de6a29..ea1ba6490c12 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -997,6 +997,16 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", it over the entire softmax reduction dimension by: x, _, sum : results x = (1 / sum) * x + + Decomposition Configuration: + The `decomposition_config` attribute is a DictionaryAttr that controls how + this operation is decomposed into lower-level operations. 
It supports: + - "qk_attrs": DictionaryAttr - Attributes to attach to the Q@K matmul + operation after decomposition (e.g., lowering_config, attention markers) + - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul + operation after decomposition + - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead + of exp. (Might be better accuracy-wise on some hardware) }]; let arguments = (ins AnyShaped:$query, From 2dde237a8b6f1d3070d7a84ecc5399fd8919f1ca Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 16 Dec 2025 07:39:42 -0800 Subject: [PATCH 19/71] Nit comment Signed-off-by: Keshav Vinayak Jha --- compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index ea1ba6490c12..01510e5e7703 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -1006,7 +1006,7 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", - "pv_attrs": DictionaryAttr - Attributes to attach to the P@V matmul operation after decomposition - "use_exp2": BoolAttr - If true, uses exp2 with log2(e) scaling instead - of exp. (Might be better accuracy-wise on some hardware) + of exp. (Gives better perf on some hardware, but trades off accuracy) }]; let arguments = (ins AnyShaped:$query, From 4972d7aa9e61b4f0275dcf3211810fa838dd0fc8 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:27:40 +0530 Subject: [PATCH 20/71] Refactor computeSubAndExp2 to computeSubAndExp Updated computeSubAndExp2 calls to computeSubAndExp with useExp2 flag. 
--- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 92a1a58227d5..8607f93f6032 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -1220,11 +1220,11 @@ FailureOr> ExpReductionOp::decomposeOperation(OpBuilder &b) { Value currMax = reduce( rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} - Value ex = computeSubAndExp2(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get()); + Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) - Value norm = computeSubAndExp2(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get()); + Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, + prevMax->get(), /*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); From 162dcf4e34ea9a3232da3e6ebb202ee91d90d144 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Tue, 23 Dec 2025 08:11:26 +0000 Subject: [PATCH 21/71] Formatting Signed-off-by: Keshav Vinayak Jha --- .../Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 8607f93f6032..e8c3d90362c0 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -1221,10 +1221,10 @@ FailureOr> 
ExpReductionOp::decomposeOperation(OpBuilder &b) { rewriter, loc, normValMap, prevMaxMap, sValue->get(), prevMax->get()); // ex = e^{sValue - curr_max} Value ex = computeSubAndExp(rewriter, loc, prevMaxMap, normValMap, currMax, - sValue->get(), /*useExp2=*/true); + sValue->get(), /*useExp2=*/true); // norm = e^(prev_max - curr_max) Value norm = computeSubAndExp(rewriter, loc, prevMaxMap, prevMaxMap, currMax, - prevMax->get(), /*useExp2=*/true); + prevMax->get(), /*useExp2=*/true); SmallVector inputs = getDpsInputs(); SmallVector normOuts(getNumDpsInits()); From 354c71a8f67c8eb12f38397ebe3ccad57125dda2 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 13 Jan 2026 08:30:56 -0800 Subject: [PATCH 22/71] Integrate llvm-project@f9a8096067 [ours 31f0e3e644] (#23101) Existing local reverts carried forward: * Local revert of https://github.com/llvm/llvm-project/pull/169614 due to https://github.com/llvm/llvm-project/issues/172932. Reverts dropped: * https://github.com/llvm/llvm-project/pull/174084 and followups - it works now Local workarounds dropped: * Remove copying around some Python DLLs since upstream seems to put them in the right place now --- compiler/bindings/python/CMakeLists.txt | 23 ----------------------- third_party/llvm-project | 2 +- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index 5956c2d24cb8..5a148eddcdcf 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -416,29 +416,6 @@ add_custom_target(IREECompilerPythonDylibFiles add_dependencies(IREECompilerPythonModules IREECompilerPythonDylibFiles) -################################################################################ -# Windows DLL colocation fix -# On Windows, the nanobind-mlir.dll ends up in iree/build/_mlir_libs/ but we -# need to copy it to iree/compiler/_mlir_libs/ for the Python extensions to find -# it at runtime. 
-################################################################################ -if(WIN32) - set(_nanobind_src "${_PYTHON_BUILD_PREFIX}/iree/build/_mlir_libs/nanobind-mlir.dll") - set(_nanobind_dst "${_PYTHON_BUILD_PREFIX}/iree/compiler/_mlir_libs/nanobind-mlir.dll") - add_custom_command( - OUTPUT "${_nanobind_dst}" - DEPENDS "${_nanobind_src}" - COMMAND ${CMAKE_COMMAND} -E copy_if_different - "${_nanobind_src}" "${_nanobind_dst}" - COMMENT "Copying nanobind-mlir.dll to iree/compiler/_mlir_libs/ for Windows DLL loading" - ) - add_custom_target(IREECompilerPythonNanobindCopy - DEPENDS "${_nanobind_dst}" - ) - add_dependencies(IREECompilerPythonNanobindCopy IREECompilerBuildPythonModules) - add_dependencies(IREECompilerPythonModules IREECompilerPythonNanobindCopy) -endif() - ################################################################################ # Subdirectories ################################################################################ diff --git a/third_party/llvm-project b/third_party/llvm-project index fc66e8eaa7e8..31f0e3e64485 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit fc66e8eaa7e843305b917f749ba02775d3a3d5ac +Subproject commit 31f0e3e644857ed4886884b650530ef791680f95 From dd4c15953fd0a802789692da0c1e0a61be60d215 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 13 Jan 2026 08:56:58 -0800 Subject: [PATCH 23/71] [NFC] Add tablegen_compile_commands.yml to .gitignore (#23104) The tablegen LSP has a tablegen_compile_commands.yml analogous to the compile_commands.json for C++. Since people may want to link it into their source tree for similar reasons to compile_commands.json, add it to .gitignore as well. 
--- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 98a5c89a6f3f..b7da4ef62955 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ imgui.ini # Source indexing files compile_commands.json +tablegen_compile_commands.yml .cache/clangd # Language server configuration files From 4f866d4ea9725dfce0de09e6b17db81e8758b054 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Tue, 13 Jan 2026 09:14:42 -0800 Subject: [PATCH 24/71] [Stream] Handle all stream tensor ops in UnifyEncodingForGlobals (#23069) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend the pass to handle additional TensorPhaseOp ops that may appear in the encoded global use chain: - TensorCloneOp: Insert re-encode op before clone if encoding mismatch - TensorEncodeOp: Update source_encoding attribute directly - TensorUpdateOp: Insert re-encode op before update for target/update operands Added `maybeInsertReencode` helper for consistent re-encode op insertion. The re-encode ops convert from unified encoding back to the encoding expected by each op, maintaining correctness while allowing later passes to fold them. The rest of ops are not reachable, based on their op definition. Below is the analysis. | Op | Reachable? | Reason | |----|------------|----------| | TensorDispatchOp | ✅ Yes | Can take loaded global as operands. | | TensorCloneOp | ✅ Yes | Can clone loaded global. | | TensorEncodeOp | ✅ Yes | Can re-encode the loaded global. | | TensorUpdateOp | ✅ Yes | Can partially update the value with loaded global. | | TensorSizeOfOp | ❌ No | No tensor operand - just a TypeAttr. | | TensorEmptyOp | ❌ No | No tensor input - only produces a tensor. | | TensorConstantOp | ❌ No | No tensor input - only produces a tensor. | | TensorSplatOp | ❌ No | No tensor input - takes a scalar, produces a tensor. | | TensorFillOp | ❌ No | Encoded globals are immutable constants, not fill targets. 
| | TensorSliceOp | ❌ No | Slicing happens on resources, not on already-encoded globals. | | TensorLoadOp | ❌ No | Loading from resource, not tensor. | | TensorStoreOp | ❌ No | Storing value is scalar/vector, not tensor. | --------- Signed-off-by: hanhanW --- .../Transforms/UnifyEncodingForGlobals.cpp | 174 ++++++++++++++++-- .../test/unify_encoding_for_globals.mlir | 102 ++++++++++ 2 files changed, 259 insertions(+), 17 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp index 304c24d89f68..0c92aa38cd02 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" @@ -538,18 +539,151 @@ static void updateTensorDispatchOp(TensorDispatchOp dispatchOp, } } +// Inserts a re-encode op before the given op if the source encoding doesn't +// match the new (unified) encoding. Returns the re-encoded value, or the +// original source if no re-encoding is needed. +static Value maybeInsertReencode(IRRewriter &rewriter, Operation *op, + Value source, Type sourceEncodingType, + ValueRange sourceEncodingDims, + Value sourceSize, Attribute newEncoding, + AffinityAttr affinityAttr) { + auto expectedType = cast(sourceEncodingType); + Attribute expectedEncoding = expectedType.getEncoding(); + + // No re-encode needed if encodings match. + if (expectedEncoding == newEncoding) { + return source; + } + + LDBG() << " Inserting re-encode: " << newEncoding << " -> " + << expectedEncoding; + + // Build the source type (with unified encoding). 
+ RankedTensorType unifiedType = expectedType.cloneWithEncoding(newEncoding); + + // Compute sizes for unified encoding. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + Value unifiedSize = TensorSizeOfOp::create( + rewriter, op->getLoc(), rewriter.getIndexType(), + TypeAttr::get(unifiedType), sourceEncodingDims, affinityAttr); + + // Insert the re-encode op: unified -> expected. + auto reencodeOp = TensorEncodeOp::create( + rewriter, op->getLoc(), source.getType(), source, + TypeAttr::get(unifiedType), + /*source_encoding_dims=*/sourceEncodingDims, unifiedSize, + TypeAttr::get(expectedType), + /*result_encoding_dims=*/sourceEncodingDims, sourceSize, affinityAttr); + + LDBG() << " Created: " << reencodeOp; + return reencodeOp.getResult(); +} + +// Updates TensorCloneOp by inserting re-encode if needed. +static void updateTensorCloneOp(TensorCloneOp cloneOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + int operandNumber = cloneOp.getSourceMutable().getOperandNumber(); + if (!operandUpdates.contains(operandNumber)) { + return; + } + Attribute newEncoding = operandUpdates.lookup(operandNumber); + Value reencoded = maybeInsertReencode( + rewriter, cloneOp, cloneOp.getSource(), cloneOp.getSourceEncoding(), + cloneOp.getSourceEncodingDims(), cloneOp.getSourceSize(), newEncoding, + cloneOp.getAffinityAttr()); + if (reencoded != cloneOp.getSource()) { + rewriter.modifyOpInPlace( + cloneOp, [&] { cloneOp.getSourceMutable().set(reencoded); }); + } +} + +// Updates TensorEncodeOp by updating the source_encoding attribute. 
+static void updateTensorEncodeOp(TensorEncodeOp encodeOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + int operandNumber = encodeOp.getSourceMutable().getOperandNumber(); + if (!operandUpdates.contains(operandNumber)) { + return; + } + Attribute newEncoding = operandUpdates.lookup(operandNumber); + auto oldSourceType = cast(encodeOp.getSourceEncoding()); + RankedTensorType newSourceType = oldSourceType.cloneWithEncoding(newEncoding); + rewriter.modifyOpInPlace(encodeOp, [&] { + encodeOp.setSourceEncodingAttr(TypeAttr::get(newSourceType)); + }); + LDBG() << " Updated TensorEncodeOp source encoding to " << newEncoding; +} + +// Updates TensorUpdateOp by inserting re-encode if needed. +static void updateTensorUpdateOp(TensorUpdateOp updateOp, + const OperandEncodingUpdates &operandUpdates, + IRRewriter &rewriter) { + // Handle target operand. + int targetOperandNum = updateOp.getTargetMutable().getOperandNumber(); + if (operandUpdates.contains(targetOperandNum)) { + Attribute newEncoding = operandUpdates.lookup(targetOperandNum); + Value reencoded = maybeInsertReencode( + rewriter, updateOp, updateOp.getTarget(), updateOp.getTargetEncoding(), + updateOp.getTargetEncodingDims(), updateOp.getTargetSize(), newEncoding, + updateOp.getAffinityAttr()); + if (reencoded != updateOp.getTarget()) { + rewriter.modifyOpInPlace( + updateOp, [&] { updateOp.getTargetMutable().set(reencoded); }); + } + } + + // Handle update operand. 
+ unsigned updateOperandNum = updateOp.getUpdateMutable().getOperandNumber(); + if (operandUpdates.contains(updateOperandNum)) { + Attribute newEncoding = operandUpdates.lookup(updateOperandNum); + Value reencoded = maybeInsertReencode( + rewriter, updateOp, updateOp.getUpdate(), updateOp.getUpdateEncoding(), + updateOp.getUpdateEncodingDims(), updateOp.getUpdateSize(), newEncoding, + updateOp.getAffinityAttr()); + if (reencoded != updateOp.getUpdate()) { + rewriter.modifyOpInPlace( + updateOp, [&] { updateOp.getUpdateMutable().set(reencoded); }); + } + } +} + // Applies all cached encoding updates to tensor ops. static void applyTensorEncodingUpdates(TensorEncodingUpdates &updates) { for (auto &[op, operandUpdates] : updates) { + // Copy to local variable to allow capture in C++17 lambdas. + const OperandEncodingUpdates &opUpdates = operandUpdates; IRRewriter rewriter(op->getContext()); - // TODO: Handle other TensorPhaseOp ops (TensorFillOp, etc.) via TypeSwitch. - if (auto dispatchOp = dyn_cast(op)) { - updateTensorDispatchOp(dispatchOp, operandUpdates, rewriter); - } + TypeSwitch(op) + .Case([&](auto dispatchOp) { + updateTensorDispatchOp(dispatchOp, opUpdates, rewriter); + }) + .Case([&](auto cloneOp) { + updateTensorCloneOp(cloneOp, opUpdates, rewriter); + }) + .Case([&](auto encodeOp) { + updateTensorEncodeOp(encodeOp, opUpdates, rewriter); + }) + .Case([&](auto updateOp) { + updateTensorUpdateOp(updateOp, opUpdates, rewriter); + }) + .Case( + [&](auto) { + assert(false && "unexpected tensor op needing encoding update"); + }) + .Default([](Operation *op) { + LDBG() << " Unhandled op: " << op->getName() + << ", maybe it is a new tensor op?"; + assert(false); + }); } } -// Collects updates for stream tensor ops by walking from global loads. +// Collects updates for stream tensor ops by walking from global loads. Fixup +// should be applied to all stream tensor ops that use the encoded global's +// data. 
static void collectUpdatesForStreamTensorOps(Explorer &explorer, EncodedGlobalInfo &encodedInfo, Attribute newEncoding, @@ -582,18 +716,24 @@ static void collectUpdatesForStreamTensorOps(Explorer &explorer, return WalkResult::advance(); } - // TODO: Handle other tensor phase ops (TensorFillOp, etc.) - auto dispatchOp = dyn_cast(user); - if (!dispatchOp) { - return WalkResult::advance(); - } - - // The operand number is the index in the full operand list (including - // workload). We need the index in getMixedOperands() for encoding lookup. - unsigned mixedOperandIdx = - operand.getOperandNumber() - dispatchOp.getWorkload().size(); - LDBG() << " Found TensorDispatchOp operand " << mixedOperandIdx; - updates[user][mixedOperandIdx] = newEncoding; + // Do not continue walking past these ops because this is the end point. + // The fixup will be applied directly to these ops, so updates are not + // needed for their users. + TypeSwitch(user) + .Case([&](auto dispatchOp) { + // The operand number is the index in the full operand list + // (including workload). We need the index in getMixedOperands() for + // encoding lookup. 
+ unsigned mixedOperandIdx = + operand.getOperandNumber() - dispatchOp.getWorkload().size(); + LDBG() << " Found TensorDispatchOp operand " + << mixedOperandIdx; + updates[user][mixedOperandIdx] = newEncoding; + }) + .Case([&](auto op) { + updates[user][operand.getOperandNumber()] = newEncoding; + }) + .Default([](Operation *op) {}); return WalkResult::advance(); }); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir index 1e2640f19af7..296bb1cb5a75 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/unify_encoding_for_globals.mlir @@ -859,3 +859,105 @@ util.initializer { util.return } + +// ----- + +// Test: TensorCloneOp, TensorEncodeOp, and TensorUpdateOp in dispatch site. + +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_encoding.specialization_resolver<123>}> +#device_target_local = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +// CHECK-DAG: #[[$ENC:.+]] = #iree_encoding.testing]> +// CHECK-DAG: #[[$ENC2:.+]] = #iree_encoding.testing]> +#encoding1 = #iree_encoding.testing]> +#encoding2 = #iree_encoding.testing]> + +// CHECK: util.global private @[[$DEVICE_A:.+]] = +util.global private @device_a = #device_target_local +util.global private @weight : !stream.resource +util.global private @weight_size : index +util.global private @encoded_v1 : !stream.resource +util.global private @encoded_v1_size : index +util.global private @encoded_v2 : !stream.resource +util.global private @encoded_v2_size : index + +// CHECK: util.initializer +util.initializer { + %cst = stream.tensor.constant on(#hal.device.affinity<@device_a>) : tensor<4096x4096xf32> in !stream.resource = 
#stream.parameter.named<"model"::"weight"> : tensor<4096x4096xf32> + %0 = stream.resource.size %cst : !stream.resource + util.global.store %cst, @weight : !stream.resource + util.global.store %0, @weight_size : index + // CHECK: %[[SOURCE:.+]] = util.global.load @weight + %source = util.global.load @weight : !stream.resource + %source_size = util.global.load @weight_size : index + + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>> + %size1 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding1> : index + %enc1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource{%source_size} -> tensor<4096x4096xf32, #encoding1> in !stream.resource{%size1} + util.global.store %enc1, @encoded_v1 : !stream.resource + util.global.store %size1, @encoded_v1_size : index + + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[SOURCE]] : {{.*}} -> tensor<4096x4096xf32, #iree_encoding.specialized<123>> + %size2 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4096x4096xf32, #encoding2> : index + %enc2 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %source : tensor<4096x4096xf32> in !stream.resource{%source_size} -> tensor<4096x4096xf32, #encoding2> in !stream.resource{%size2} + util.global.store %enc2, @encoded_v2 : !stream.resource + util.global.store %size2, @encoded_v2_size : index + + util.return +} + +// CHECK-LABEL: util.func public @tensor_clone_reencode +util.func public @tensor_clone_reencode(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = 
util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + + // Re-encode should be inserted before the clone op. + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: %[[REENC:.+]] = stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC]]> + // CHECK: stream.tensor.clone on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[REENC]] + // CHECK-SAME: tensor<4096x4096xf32, #[[$ENC]]> + %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>) %loaded_v1 + : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding1> in !stream.resource<*>{%loaded_v1_size} + + util.return +} + +// CHECK-LABEL: util.func public @tensor_encode_update_source +util.func public @tensor_encode_update_source(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + + // The encode op's source_encoding should be updated to the unified encoding. 
+ // CHECK: stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC2]]> + %1 = stream.tensor.encode on(#hal.device.affinity<@device_a>) %loaded_v1 + : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding2> in !stream.resource<*>{%loaded_v1_size} + + util.return +} + +// CHECK-LABEL: util.func public @tensor_update_reencode +util.func public @tensor_update_reencode(%arg0: !stream.resource<*>, %arg1: !stream.resource<*>, %arg2: index) { + %loaded_v1 = util.global.load @encoded_v1 : !stream.resource + %loaded_v1_size = util.global.load @encoded_v1_size : index + %c0 = arith.constant 0 : index + + // Re-encode should be inserted before the update op. + // CHECK: stream.tensor.sizeof on(#hal.device.affinity<@[[$DEVICE_A]]>) tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK: %[[REENC:.+]] = stream.tensor.encode on(#hal.device.affinity<@[[$DEVICE_A]]>) + // CHECK-SAME: tensor<4096x4096xf32, #iree_encoding.specialized<123>> + // CHECK-SAME: -> tensor<4096x4096xf32, #[[$ENC]]> + // CHECK: stream.tensor.update on(#hal.device.affinity<@[[$DEVICE_A]]>) %[[REENC]] + // CHECK-SAME: tensor<4096x4096xf32, #[[$ENC]]> + %2 = stream.tensor.update on(#hal.device.affinity<@device_a>) + %loaded_v1, %arg0[%c0, %c0] : tensor<4096x4096xf32, #encoding1> in !stream.resource{%loaded_v1_size} + -> tensor<4096x4096xf32, #encoding1> in %arg0 as !stream.resource<*>{%arg2} + + util.return +} From 91c344f17e0512646f4663f38ed2d47cd1ce0888 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 13 Jan 2026 09:59:53 -0800 Subject: [PATCH 25/71] Unifying VM testing infrastructure to allow better scaling. (#23099) As we add VMFB->C(/JS/WASM/etc) export and JIT we'll now have a consistent way to get the whole test suite running without drift. 
EmitC has been converted but is mostly going to remain static as it is deprecated - this only maintains the existing tests that were running with it. --- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 92 +- .../VMToEmitC/test/control_flow_ops.mlir | 6 +- runtime/src/iree/vm/bytecode/BUILD.bazel | 9 +- runtime/src/iree/vm/bytecode/CMakeLists.txt | 5 - .../iree/vm/bytecode/dispatch_async_test.cc | 834 ------------------ runtime/src/iree/vm/bytecode/dispatch_test.cc | 146 --- runtime/src/iree/vm/list.c | 8 +- runtime/src/iree/vm/list_test.cc | 55 ++ runtime/src/iree/vm/native_module_packing.h | 173 +++- runtime/src/iree/vm/test/BUILD.bazel | 20 +- runtime/src/iree/vm/test/CMakeLists.txt | 25 +- runtime/src/iree/vm/test/async_ops.mlir | 250 ++++-- runtime/src/iree/vm/test/bytecode/BUILD.bazel | 37 + .../src/iree/vm/test/bytecode/CMakeLists.txt | 33 + .../vm/test/bytecode/bytecode_module_test.cc | 94 ++ runtime/src/iree/vm/test/emitc/BUILD.bazel | 6 +- runtime/src/iree/vm/test/emitc/CMakeLists.txt | 16 +- .../iree/vm/test/emitc/emitc_module_test.cc | 122 +++ runtime/src/iree/vm/test/emitc/module_test.cc | 186 ---- runtime/src/iree/vm/testing/BUILD.bazel | 35 + runtime/src/iree/vm/testing/CMakeLists.txt | 40 + runtime/src/iree/vm/testing/test_runner.cc | 20 + runtime/src/iree/vm/testing/test_runner.h | 225 +++++ .../yieldable_test_module.h} | 81 +- 24 files changed, 1149 insertions(+), 1369 deletions(-) delete mode 100644 runtime/src/iree/vm/bytecode/dispatch_async_test.cc delete mode 100644 runtime/src/iree/vm/bytecode/dispatch_test.cc create mode 100644 runtime/src/iree/vm/test/bytecode/BUILD.bazel create mode 100644 runtime/src/iree/vm/test/bytecode/CMakeLists.txt create mode 100644 runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc create mode 100644 runtime/src/iree/vm/test/emitc/emitc_module_test.cc delete mode 100644 runtime/src/iree/vm/test/emitc/module_test.cc create mode 100644 runtime/src/iree/vm/testing/BUILD.bazel create mode 100644 
runtime/src/iree/vm/testing/CMakeLists.txt create mode 100644 runtime/src/iree/vm/testing/test_runner.cc create mode 100644 runtime/src/iree/vm/testing/test_runner.h rename runtime/src/iree/vm/{test/async_ops_test_module.h => testing/yieldable_test_module.h} (70%) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 913b73e6edb5..4a1c0b6f3dea 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -2202,6 +2202,15 @@ class ImportOpConverter { return importOp.emitError() << "failed to create call"; } + // Release refs in argument buffer after call returns. Refs that were + // taken by the callee (via assign_ref+memset) will be null and release + // will be a no-op. + if (failed(releaseArgumentBuffer( + flattenInputTypes(importOp, segmentSizes, builder), call.value(), + builder, loc))) { + return importOp.emitError() << "failed to release argument buffer"; + } + if (failed(unpackResultBuffer(importOp.getResultTypes(), newFuncOp, call.value(), builder, loc))) { return importOp.emitError() << "failed to unpack result struct"; @@ -2451,10 +2460,14 @@ class ImportOpConverter { /*operand=*/uint8Ptr) .getResult(); + // Retain the ref into args_storage. The callee may take ownership via + // assign_ref+memset(0), so we must retain (not just assign/borrow). + // After the call returns, releaseArgumentBuffer will release any refs + // that weren't taken by the callee. 
emitc::CallOpaqueOp::create(builder, /*location=*/loc, /*type=*/TypeRange{}, - /*callee=*/"iree_vm_ref_assign", + /*callee=*/"iree_vm_ref_retain", /*operands=*/ArrayRef{arg, refPtr}); } else { auto argLValue = emitc_builders::asLValue(builder, loc, arg); @@ -2482,6 +2495,83 @@ class ImportOpConverter { return success(); } + // Releases refs in the argument buffer after an import call returns. + // This mirrors packArgumentBuffer but releases instead of retaining. + // Refs that were taken by the callee (via assign_ref+memset) will be null + // and release will be a no-op. + LogicalResult releaseArgumentBuffer(ArrayRef inputTypes, + TypedValue call, + OpBuilder &builder, Location loc) const { + // Find the last ref type index. We only need to iterate up to and including + // that index to release all refs. This avoids generating unused pointer + // arithmetic for trailing non-ref types. + std::optional lastRefIndex; + for (size_t i = 0; i < inputTypes.size(); i++) { + if (isa(inputTypes[i])) { + lastRefIndex = i; + } + } + if (!lastRefIndex) { + return success(); + } + + auto ctx = builder.getContext(); + + auto arguments = + emitc::MemberOp::create(builder, loc, + /*type=*/ + emitc::LValueType::get(emitc::OpaqueType::get( + ctx, "iree_byte_span_t")), + /*memberName=*/"arguments", + /*operand=*/call) + .getResult(); + + Type bytePtrType = + emitc::PointerType::get(builder.getIntegerType(8, false)); + auto uint8Ptr = emitc_builders::structMember(builder, loc, + /*type=*/bytePtrType, + /*memberName=*/"data", + /*operand=*/arguments); + + // Only iterate up to and including the last ref type. + for (size_t i = 0; i <= *lastRefIndex; i++) { + Type inputType = inputTypes[i]; + + // Get the value type and compute alignment (must match packArgumentBuffer + // exactly to ensure we're releasing the correct locations). 
+ Type valueType = typeConverter.convertTypeAsNonPointer(inputType); + size_t alignment = getTypeAlignment(valueType); + if (alignment > 4) { + uint8Ptr = emitc_builders::alignPtr(builder, loc, uint8Ptr, alignment); + } + + // Release refs. If the callee took ownership and zeroed the ref, + // iree_vm_ref_release on a null ref is a no-op. + if (isa(inputType)) { + Type refPtrType = emitc::PointerType::get( + emitc::OpaqueType::get(ctx, "iree_vm_ref_t")); + Value refPtr = emitc::CastOp::create(builder, + /*location=*/loc, + /*type=*/refPtrType, + /*operand=*/uint8Ptr) + .getResult(); + emitc_builders::ireeVmRefRelease(builder, loc, refPtr); + } + + // Advance pointer to next element (only if not at the last ref). + if (i < *lastRefIndex) { + Value size = + emitc_builders::sizeOf(builder, loc, TypeAttr::get(valueType)); + uint8Ptr = + emitc::AddOp::create(builder, + /*location=*/loc, /*type=*/bytePtrType, + /*operands=*/ArrayRef{uint8Ptr, size}) + .getResult(); + } + } + return success(); + } + LogicalResult unpackResultBuffer(ArrayRef resultTypes, mlir::emitc::FuncOp &funcOp, TypedValue call, diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir index 3929db967f10..7ee30a829acb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir @@ -725,7 +725,8 @@ vm.module @my_module { // CHECK: %[[ARGSDATA:.+]] = load %[[ARGSDATA_LVAL]] : > // CHECK: call_opaque "iree_host_align" // CHECK: %[[ARG:.+]] = cast %{{.+}} : !emitc.ptr to !emitc.ptr> - // CHECK: call_opaque "iree_vm_ref_assign"(%arg2, %[[ARG]]) + // Retain the ref into args_storage (not just assign/borrow). + // CHECK: call_opaque "iree_vm_ref_retain"(%arg2, %[[ARG]]) // Create the call to the imported function. 
// CHECK: %[[MODULE_LVAL:.+]] = "emitc.member_of_ptr"(%[[FUNC_LVAL]]) <{member = "module"}> : (!emitc.lvalue>>) -> !emitc.lvalue>> @@ -735,6 +736,9 @@ vm.module @my_module { // CHECK-NEXT: %[[ARGSTRUCT_RVAL:.+]] = load %[[ARGSTRUCT]] : > // CHECK-NEXT: %{{.+}} = call_opaque "EMITC_CALL_INDIRECT"(%[[BEGIN_CALL]], %[[MODULE]], %arg0, %[[ARGSTRUCT_RVAL]]) + // Release refs in argument buffer after call returns. + // CHECK: call_opaque "iree_vm_ref_release" + // Unpack the function results (with pointer alignment). // CHECK: %[[RES_MEMBER:.+]] = "emitc.member"(%[[ARGSTRUCT]]) <{member = "results"}> : (!emitc.lvalue>) -> !emitc.lvalue> // CHECK: %[[RESPTR_MEMBER:.+]] = "emitc.member"(%[[RES_MEMBER]]) <{member = "data"}> : (!emitc.lvalue>) -> !emitc.lvalue> diff --git a/runtime/src/iree/vm/bytecode/BUILD.bazel b/runtime/src/iree/vm/bytecode/BUILD.bazel index 177ecd1fcf3a..7eeddb1ccb25 100644 --- a/runtime/src/iree/vm/bytecode/BUILD.bazel +++ b/runtime/src/iree/vm/bytecode/BUILD.bazel @@ -55,11 +55,7 @@ if(IREE_BUILD_COMPILER) iree_runtime_cc_test( name = "module_test", - srcs = [ - "dispatch_async_test.cc", - "dispatch_test.cc", - "module_test.cc", - ], + srcs = ["module_test.cc"], deps = [ ":module", ":module_test_module_c", @@ -67,9 +63,6 @@ iree_runtime_cc_test( "//runtime/src/iree/testing:gtest", "//runtime/src/iree/testing:gtest_main", "//runtime/src/iree/vm", - "//runtime/src/iree/vm/test:all_bytecode_modules_c", - "//runtime/src/iree/vm/test:async_bytecode_modules_c", - "//runtime/src/iree/vm/test:async_ops_test_module", ], ) diff --git a/runtime/src/iree/vm/bytecode/CMakeLists.txt b/runtime/src/iree/vm/bytecode/CMakeLists.txt index 8d0250024cbc..2293aadf4e58 100644 --- a/runtime/src/iree/vm/bytecode/CMakeLists.txt +++ b/runtime/src/iree/vm/bytecode/CMakeLists.txt @@ -41,8 +41,6 @@ iree_cc_test( NAME module_test SRCS - "dispatch_async_test.cc" - "dispatch_test.cc" "module_test.cc" DEPS ::module @@ -51,9 +49,6 @@ iree_cc_test( iree::testing::gtest 
iree::testing::gtest_main iree::vm - iree::vm::test::all_bytecode_modules_c - iree::vm::test::async_bytecode_modules_c - iree::vm::test::async_ops_test_module ) iree_bytecode_module( diff --git a/runtime/src/iree/vm/bytecode/dispatch_async_test.cc b/runtime/src/iree/vm/bytecode/dispatch_async_test.cc deleted file mode 100644 index c117791c737e..000000000000 --- a/runtime/src/iree/vm/bytecode/dispatch_async_test.cc +++ /dev/null @@ -1,834 +0,0 @@ -// Copyright 2022 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Tests covering the dispatch logic for individual ops. -// -// iree/vm/test/async_ops.mlir contains the functions used here for testing. We -// avoid defining the IR inline here so that we can run this test on platforms -// that we can't run the full MLIR compiler stack on. - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" -#include "iree/vm/api.h" -#include "iree/vm/bytecode/module.h" - -// Compiled module embedded here to avoid file IO: -#include "iree/vm/test/async_bytecode_modules.h" - -// Native test module for yieldable imports. -#include "iree/vm/test/async_ops_test_module.h" - -namespace iree { -namespace { - -using iree::testing::status::StatusIs; - -class VMBytecodeDispatchAsyncTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module (required by async_ops imports). 
- IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Native module first for import resolution. - std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests a simple straight-line yield sequence that requires 3 resumes. 
-// See iree/vm/test/async_ops.mlir > @yield_sequence -TEST_F(VMBytecodeDispatchAsyncTest, YieldSequence) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("yield_sequence"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 97; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 0/3 - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // 1/3 - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // 2/3 - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // 3/3 - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - ASSERT_EQ(ret_value, arg_value + 3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a yield with data-dependent control, ensuring that we run the -// alternating branches and pass along branch args on resume. -// See iree/vm/test/async_ops.mlir > @yield_divergent -TEST_F(VMBytecodeDispatchAsyncTest, YieldDivergent) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("yield_divergent"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - // result = %arg0 ? 
%arg1 : %arg2 - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = { - 0, - 100, - 200, - }; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // arg0=0: result = %arg0 ? %arg1 : %arg2 => %arg2 - arg_values.arg0 = 0; - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - ASSERT_EQ(ret_value, arg_values.arg2); - - // arg0=1: result = %arg0 ? %arg1 : %arg2 => %arg1 - arg_values.arg0 = 1; - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - ASSERT_EQ(ret_value, arg_values.arg1); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallYieldable tests -//===----------------------------------------------------------------------===// - -class VMBytecodeDispatchCallYieldableTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module (required by async_ops imports). 
- IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Native module first for import resolution. - std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests calling an internal function that yields 4 times via vm.call.yieldable. -// See iree/vm/test/call_yieldable_ops.mlir > @call_yieldable_internal -TEST_F(VMBytecodeDispatchCallYieldableTest, CallYieldableInternal) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_internal"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(nullptr, 0); // No arguments - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The callee yields 3 times, so we need 3 resumes. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be 4 (0 + 4 increments across 4 basic blocks) - ASSERT_EQ(ret_value, 4u); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling an internal yieldable function with an argument. -// See iree/vm/test/call_yieldable_ops.mlir > @call_yieldable_with_arg -TEST_F(VMBytecodeDispatchCallYieldableTest, CallYieldableWithArg) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_with_arg"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 42; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The callee yields 1 time. 
- // 0/1 - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // 1/1 - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg_value + 1 - ASSERT_EQ(ret_value, arg_value + 1); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallYieldable to Imports tests -//===----------------------------------------------------------------------===// -// Tests vm.call.yieldable calling native module functions that yield. - -class VMBytecodeDispatchCallYieldableImportTest : public ::testing::Test { - protected: - void SetUp() override { - IREE_TRACE_SCOPE(); - const iree_file_toc_t* file = async_bytecode_modules_c_create(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - // Create native yieldable_test module. - IREE_CHECK_OK(yieldable_test_module_create( - instance_, iree_allocator_system(), &native_module_)); - - // Create bytecode module that imports from native module. - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{reinterpret_cast(file->data), - static_cast(file->size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - // Create context with both modules (native first for import resolution). 
- std::vector modules = {native_module_, bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - void TearDown() override { - IREE_TRACE_SCOPE(); - iree_vm_module_release(bytecode_module_); - iree_vm_module_release(native_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* native_module_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -// Tests calling a yieldable import that yields 3 times. -// This exercises Bug 1 fix: PC must be saved at instruction start, not after -// decode. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportYields3) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_yields_3"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 100; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The import yields 3 times. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg + 3 - ASSERT_EQ(ret_value, arg_value + 3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a yieldable import that yields 0 times (synchronous). -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportYields0) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_yields_0"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 42; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // No yields, should complete immediately. - IREE_ASSERT_OK( - function.module->begin_call(function.module->self, stack, call)); - - // Result should be arg + 0 - ASSERT_EQ(ret_value, arg_value); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a yieldable import after an internal function call. -// This exercises Bug 2 fix: return_registers must be cleared after internal -// call. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableAfterInternal) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_after_internal"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 5; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The function: - // 1. Calls internal_add_10(arg) -> arg + 10 - // 2. Calls yieldable import yield_n(arg+10, 2) which yields 2 times - // Expected: 2 yields, result = (arg + 10) + 2 - - // begin -> internal call completes, import yields 1st time -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg + 10) + 2 = arg + 12 - ASSERT_EQ(ret_value, arg_value + 12); - - iree_vm_stack_deinitialize(stack); -} - -// Tests two sequential yieldable import calls in the same function. -// This catches bugs where the second call sees stale state from the first. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportSequential) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_sequential"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 10; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // First import yields 2 times, second import yields 3 times = 5 total yields. - // begin -> 1st import, 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st import, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st import done, 2nd import, 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import, 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd import done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg + 2) + 3 = arg + 5 - ASSERT_EQ(ret_value, arg_value + 5); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a 
yieldable import nested inside an internal yieldable function. -// This is the most complex frame stack scenario. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportNested) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_nested_yieldable"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 50; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // Sequence: 1 yield (internal) + 2 yields (import) + 1 yield (internal) = 4 - // begin -> internal 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> import 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> import 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> import done, internal 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> internal done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be ((arg + 1) + 2) + 1 = arg + 4 - ASSERT_EQ(ret_value, arg_value + 4); - - iree_vm_stack_deinitialize(stack); -} - -// Tests a yieldable import with many yields to catch state accumulation bugs. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, YieldableImportStress) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_yieldable_import_stress"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 1000; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 10 yields total. - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - for (int i = 1; i < 10; ++i) { - ASSERT_THAT(function.module->resume_call(function.module->self, stack, - call.results), - StatusIs(StatusCode::kDeferred)) - << "Expected DEFERRED at resume " << i; - } - - // Final resume should complete. - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg + 10 - ASSERT_EQ(ret_value, arg_value + 10); - - iree_vm_stack_deinitialize(stack); -} - -//===----------------------------------------------------------------------===// -// CallVariadicYieldable to Imports tests -//===----------------------------------------------------------------------===// -// Tests vm.call.variadic.yieldable calling native module functions that yield. - -// Tests calling a variadic yieldable import with 2 args and 3 yields. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable2Args) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_2args"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - } arg_values = {10, 20}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // The import sums the variadic args and yields 3 times. - // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 3rd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be (arg0 + arg1) + 3 = 10 + 20 + 3 = 33 - ASSERT_EQ(ret_value, arg_values.arg0 + arg_values.arg1 + 3); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with 0 yields (synchronous). 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable0Yields) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_0yields"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = {5, 10, 15}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // No yields, should complete immediately. - IREE_ASSERT_OK( - function.module->begin_call(function.module->self, stack, call)); - - // Result should be arg0 + arg1 + arg2 = 5 + 10 + 15 = 30 - ASSERT_EQ(ret_value, arg_values.arg0 + arg_values.arg1 + arg_values.arg2); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with 1 arg. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldable1Arg) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_1arg"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t arg_value = 100; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 2 yields. 
- // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be arg0 + 2 = 100 + 2 = 102 - ASSERT_EQ(ret_value, arg_value + 2); - - iree_vm_stack_deinitialize(stack); -} - -// Tests calling a variadic yieldable import with empty variadic list. -TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldableEmpty) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_empty"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(nullptr, 0); // No arguments - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // 1 yield. - // begin -> 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be 0 + 1 = 1 - ASSERT_EQ(ret_value, 1u); - - iree_vm_stack_deinitialize(stack); -} - -// Tests two sequential variadic yieldable calls. 
-TEST_F(VMBytecodeDispatchCallYieldableImportTest, VariadicYieldableSequential) { - IREE_TRACE_SCOPE(); - - iree_vm_function_t function; - IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - IREE_SV("call_variadic_yieldable_sequential"), &function)); - IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, - iree_vm_context_state_resolver(context_), - iree_allocator_system()); - - struct { - uint32_t arg0; - uint32_t arg1; - uint32_t arg2; - } arg_values = {10, 20, 5}; - uint32_t ret_value = 0; - - iree_vm_function_call_t call; - memset(&call, 0, sizeof(call)); - call.function = function; - call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values)); - call.results = iree_make_byte_span(&ret_value, sizeof(ret_value)); - - // First variadic: 2 yields, second variadic: 1 yield = 3 yields total. - // begin -> 1st call, 1st yield -> DEFERRED - ASSERT_THAT(function.module->begin_call(function.module->self, stack, call), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st call, 2nd yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 1st call done, 2nd call, 1st yield -> DEFERRED - ASSERT_THAT( - function.module->resume_call(function.module->self, stack, call.results), - StatusIs(StatusCode::kDeferred)); - - // resume -> 2nd call done, return -> OK - IREE_ASSERT_OK( - function.module->resume_call(function.module->self, stack, call.results)); - - // Result should be: - // First call: sum(arg0, arg1) + 2 yields = (10 + 20) + 2 = 32 - // Second call: sum(32, arg2) + 1 yield = (32 + 5) + 1 = 38 - ASSERT_EQ(ret_value, 38u); - - iree_vm_stack_deinitialize(stack); -} - -} // namespace -} // namespace iree diff --git a/runtime/src/iree/vm/bytecode/dispatch_test.cc b/runtime/src/iree/vm/bytecode/dispatch_test.cc deleted file mode 100644 index 21361bfe2a0a..000000000000 --- 
a/runtime/src/iree/vm/bytecode/dispatch_test.cc +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Tests covering the dispatch logic for individual ops. -// -// iree/vm/test/*.mlir contains the functions used here for testing. We -// avoid defining the IR inline here so that we can run this test on platforms -// that we can't run the full MLIR compiler stack on. - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/vm/api.h" -#include "iree/vm/bytecode/module.h" - -// Compiled module embedded here to avoid file IO: -#include "iree/vm/test/all_bytecode_modules.h" - -namespace { - -struct TestParams { - const struct iree_file_toc_t& module_file; - std::string function_name; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - std::string name{params.module_file.name}; - auto name_sv = iree_make_string_view(name.data(), name.size()); - iree_string_view_replace_char(name_sv, ':', '_'); - iree_string_view_replace_char(name_sv, '.', '_'); - return os << name << "_" << params.function_name; -} - -std::vector GetModuleTestParams() { - std::vector test_params; - - iree_vm_instance_t* instance = NULL; - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance)); - - const struct iree_file_toc_t* module_file_toc = - all_bytecode_modules_c_create(); - for (size_t i = 0; i < all_bytecode_modules_c_size(); ++i) { - const auto& module_file = module_file_toc[i]; - iree_vm_module_t* module = nullptr; - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{ - reinterpret_cast(module_file.data), - static_cast(module_file.size)}, - iree_allocator_null(), iree_allocator_system(), &module)); - 
iree_vm_module_signature_t signature = iree_vm_module_signature(module); - test_params.reserve(test_params.size() + signature.export_function_count); - for (int i = 0; i < signature.export_function_count; ++i) { - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal( - module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); - iree_string_view_t function_name = iree_vm_function_name(&function); - test_params.push_back( - {module_file, std::string(function_name.data, function_name.size)}); - } - iree_vm_module_release(module); - } - - iree_vm_instance_release(instance); - - return test_params; -} - -class VMBytecodeDispatchTest - : public ::testing::Test, - public ::testing::WithParamInterface { - protected: - virtual void SetUp() { - const auto& test_params = GetParam(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - IREE_CHECK_OK(iree_vm_bytecode_module_create( - instance_, IREE_VM_BYTECODE_MODULE_FLAG_NONE, - iree_const_byte_span_t{ - reinterpret_cast(test_params.module_file.data), - static_cast(test_params.module_file.size)}, - iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); - - std::vector modules = {bytecode_module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - } - - virtual void TearDown() { - iree_vm_module_release(bytecode_module_); - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_status_t RunFunction(const char* function_name) { - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_module_lookup_function_by_name( - bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, - iree_make_cstring_view(function_name), &function)); - - iree_vm_invocation_flags_t flags = IREE_VM_INVOCATION_FLAG_NONE; - // NOTE: adding this bit makes it easy to debug issues on stdout: - // flags |= 
IREE_VM_INVOCATION_FLAG_TRACE_EXECUTION; - return iree_vm_invoke(context_, function, flags, - /*policy=*/nullptr, /*inputs=*/nullptr, - /*outputs=*/nullptr, iree_allocator_system()); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; - iree_vm_module_t* bytecode_module_ = nullptr; -}; - -TEST_P(VMBytecodeDispatchTest, Check) { - const auto& test_params = GetParam(); - bool expect_failure = test_params.function_name.find("fail_") == 0; - - iree_status_t status = RunFunction(test_params.function_name.c_str()); - if (iree_status_is_ok(status)) { - if (expect_failure) { - GTEST_FAIL() << "Function expected failure but succeeded"; - } else { - GTEST_SUCCEED(); - } - } else { - if (expect_failure) { - iree_status_ignore(status); - GTEST_SUCCEED(); - } else { - GTEST_FAIL() << "Function expected success but failed with error: " - << iree::Status(std::move(status)).ToString(); - } - } -} - -INSTANTIATE_TEST_SUITE_P(VMIRFunctions, VMBytecodeDispatchTest, - ::testing::ValuesIn(GetModuleTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace diff --git a/runtime/src/iree/vm/list.c b/runtime/src/iree/vm/list.c index d78bc852fe80..89f0fca5ee7c 100644 --- a/runtime/src/iree/vm/list.c +++ b/runtime/src/iree/vm/list.c @@ -860,6 +860,9 @@ static iree_status_t iree_vm_list_get_ref(const iree_vm_list_t* list, return iree_make_status(IREE_STATUS_FAILED_PRECONDITION); } iree_vm_list_ref_op(ref_mode, &variant->ref, out_value); + if (ref_mode == IREE_VM_LIST_REF_MOVE) { + variant->type = iree_vm_make_undefined_type_def(); + } break; } default: @@ -978,7 +981,7 @@ static iree_status_t iree_vm_list_get_variant(const iree_vm_list_t* list, "index %" PRIhsz " out of bounds (%" PRIhsz ")", i, list->count); } - iree_vm_variant_reset(out_variant); + *out_variant = iree_vm_variant_empty(); uintptr_t element_ptr = (uintptr_t)list->storage + i * list->element_size; switch (list->storage_mode) { case IREE_VM_LIST_STORAGE_MODE_VALUE: { @@ -998,6 
+1001,9 @@ static iree_status_t iree_vm_list_get_variant(const iree_vm_list_t* list, out_variant->type = variant->type; if (iree_vm_type_def_is_ref(variant->type)) { iree_vm_list_ref_op(ref_mode, &variant->ref, &out_variant->ref); + if (ref_mode == IREE_VM_LIST_REF_MOVE) { + variant->type = iree_vm_make_undefined_type_def(); + } } else { memcpy(out_variant->value_storage, variant->value_storage, sizeof(variant->value_storage)); diff --git a/runtime/src/iree/vm/list_test.cc b/runtime/src/iree/vm/list_test.cc index d75c02f18eb0..b95375f9a864 100644 --- a/runtime/src/iree/vm/list_test.cc +++ b/runtime/src/iree/vm/list_test.cc @@ -273,6 +273,61 @@ TEST_F(VMListTest, GetRefRetainOrMove) { iree_vm_list_release(list); } +// Tests that moving a ref from a variant list properly marks the slot as empty. +TEST_F(VMListTest, VariantListRefMoveMarksSlotEmpty) { + // Create a variant list (stores any type). + iree_vm_type_def_t element_type = iree_vm_make_undefined_type_def(); + iree_vm_list_t* list = nullptr; + IREE_ASSERT_OK(iree_vm_list_create(element_type, /*initial_capacity=*/1, + iree_allocator_system(), &list)); + IREE_ASSERT_OK(iree_vm_list_resize(list, 1)); + + // Set a ref into the variant slot. + iree_vm_ref_t ref_a = MakeRef(1.0f); + IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, 0, &ref_a)); + + // Verify the slot contains a ref. + { + iree_vm_variant_t variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &variant)); + EXPECT_TRUE(iree_vm_variant_is_ref(variant)); + EXPECT_FALSE(iree_vm_variant_is_empty(variant)); + } + + // Move the ref out of the variant list. + iree_vm_ref_t moved{0}; + IREE_ASSERT_OK( + iree_vm_list_get_ref_retain_or_move(list, 0, /*is_move=*/true, &moved)); + EXPECT_TRUE(test_a_isa(moved)); + iree_vm_ref_release(&moved); + + // Verify the slot is now empty (type should be undefined/variant). 
+ { + iree_vm_variant_t variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &variant)); + EXPECT_TRUE(iree_vm_variant_is_empty(variant)) + << "After move, variant slot should be empty"; + } + + // Also test get_variant_move marks the slot empty. + { + iree_vm_ref_t ref_b = MakeRef(2.0f); + IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, 0, &ref_b)); + + iree_vm_variant_t moved_variant; + IREE_ASSERT_OK(iree_vm_list_get_variant_move(list, 0, &moved_variant)); + EXPECT_TRUE(iree_vm_variant_is_ref(moved_variant)); + iree_vm_ref_release(&moved_variant.ref); + + iree_vm_variant_t after_move; + IREE_ASSERT_OK(iree_vm_list_get_variant_assign(list, 0, &after_move)); + EXPECT_TRUE(iree_vm_variant_is_empty(after_move)) + << "After get_variant_move, slot should be empty"; + } + + iree_vm_list_release(list); +} + // Tests simple variant list usage, mainly just for demonstration. // Stores any heterogeneous element type, equivalent to `!vm.list`. TEST_F(VMListTest, UsageVariant) { diff --git a/runtime/src/iree/vm/native_module_packing.h b/runtime/src/iree/vm/native_module_packing.h index 49860f16d22e..794bb2668471 100644 --- a/runtime/src/iree/vm/native_module_packing.h +++ b/runtime/src/iree/vm/native_module_packing.h @@ -329,6 +329,41 @@ static inline params_ptr_t align_ptr(params_ptr_t ptr) { return ptr; } +// Computes the effective alignment for a parameter type. +// Only 8-byte types (i64, f64, ref) require special alignment; everything +// else uses the minimum 4-byte alignment. The primary template works for all +// types since alignof() is valid for any complete type. +template +struct ParamAlignment { + static constexpr iree_host_size_t value = alignof(T) > sizeof(int32_t) + ? alignof(T) + : sizeof(int32_t); +}; + +// Computes the maximum alignment across a parameter pack. 
+template +struct MaxParamAlignment; + +template <> +struct MaxParamAlignment<> { + static constexpr iree_host_size_t value = sizeof(int32_t); +}; + +template +struct MaxParamAlignment { + static constexpr iree_host_size_t value = + ParamAlignment::type>::value; +}; + +template +struct MaxParamAlignment { + static constexpr iree_host_size_t value = + (ParamAlignment::type>::value > + MaxParamAlignment::value) + ? ParamAlignment::type>::value + : MaxParamAlignment::value; +}; + template struct ParamUnpack; template <> @@ -364,31 +399,36 @@ struct Unpacker { typename impl::remove_cvref::type>::storage_type()...); Status status; params_ptr_t ptr = storage.data; - ApplyLoad(status, ptr, params, + params_ptr_t limit = storage.data + storage.data_length; + ApplyLoad(status, ptr, limit, params, std::make_index_sequence()); IREE_RETURN_IF_ERROR(std::move(status)); - // Note: we check > instead of != because alignment padding can leave - // unused bytes at the end of the buffer. - params_ptr_t limit = storage.data + storage.data_length; - if (IREE_UNLIKELY(ptr > limit)) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "argument buffer unpacking failure; consumed %" PRIhsz - " bytes beyond %" PRIhsz " byte buffer", - (reinterpret_cast(ptr) - reinterpret_cast(limit)), - storage.data_length); + // Verify remaining bytes are valid trailing alignment padding. + // Buffer sizes are computed with trailing padding to max_alignment, so + // unconsumed bytes must be less than max_alignment. This catches cases + // where the caller provided more data than expected (trailing garbage). 
+ constexpr iree_host_size_t max_alignment = + impl::MaxParamAlignment::value; + iree_host_size_t remaining = static_cast(limit - ptr); + if (IREE_UNLIKELY(ptr > limit || remaining >= max_alignment)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer unpacking failure; %" PRIhsz + " bytes remaining in %" PRIhsz + " byte buffer (max valid padding: %" PRIhsz ")", + remaining, storage.data_length, + max_alignment - 1); } return std::move(params); } private: template - static void ApplyLoad(Status& status, params_ptr_t& ptr, T&& params, - std::index_sequence) { + static void ApplyLoad(Status& status, params_ptr_t& ptr, params_ptr_t limit, + T&& params, std::index_sequence) { impl::order_sequence{ (impl::ParamUnpack>::type>::type>:: - Load(status, ptr, std::get(params)), + Load(status, ptr, limit, std::get(params)), 0)...}; } }; @@ -397,8 +437,15 @@ struct Unpacker { template struct ParamUnpack> { using storage_type = T; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(T) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading primitive"); + return; + } out_param = *reinterpret_cast(ptr); ptr += sizeof(T); } @@ -408,8 +455,15 @@ struct ParamUnpack> { template <> struct ParamUnpack { using storage_type = opaque_ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } iree_vm_ref_retain(reinterpret_cast(ptr), &out_param); ptr += 
sizeof(iree_vm_ref_t); } @@ -421,8 +475,15 @@ struct ParamUnpack { template struct ParamUnpack> { using storage_type = ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -447,8 +508,15 @@ struct ParamUnpack> { template struct ParamUnpack> { using storage_type = ref; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -474,8 +542,15 @@ struct ParamUnpack> { template struct ParamUnpack::value>> { using storage_type = T*; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -500,8 +575,15 @@ struct ParamUnpack::value>> { template <> struct 
ParamUnpack { using storage_type = iree_string_view_t; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -530,8 +612,15 @@ struct ParamUnpack { template <> struct ParamUnpack { using storage_type = std::string_view; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(iree_vm_ref_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading ref"); + return; + } auto* reg_ptr = reinterpret_cast(ptr); ptr += sizeof(iree_vm_ref_t); if (reg_ptr->type == ref_type_descriptor::type()) { @@ -563,9 +652,10 @@ template struct ParamUnpack> { using element_type = typename impl::remove_cvref::type; using storage_type = std::array; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { for (size_t i = 0; i < S; ++i) { - ParamUnpack::Load(status, ptr, out_param[i]); + ParamUnpack::Load(status, ptr, limit, out_param[i]); } } }; @@ -574,16 +664,17 @@ struct ParamUnpack> { template struct ParamUnpack> { using storage_type = std::tuple::type...>; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { - UnpackTuple(status, ptr, out_param, + static void Load(Status& status, params_ptr_t& ptr, 
params_ptr_t limit, + storage_type& out_param) { + UnpackTuple(status, ptr, limit, out_param, std::make_index_sequence()); } template - static void UnpackTuple(Status& status, params_ptr_t& ptr, + static void UnpackTuple(Status& status, params_ptr_t& ptr, params_ptr_t limit, storage_type& params, std::index_sequence) { impl::order_sequence{ (ParamUnpack>::type>:: - Load(status, ptr, std::get(params)), + Load(status, ptr, limit, std::get(params)), 0)...}; } }; @@ -596,12 +687,20 @@ template struct ParamUnpack, enable_if_not_primitive> { using element_type = typename impl::remove_cvref::type; using storage_type = std::vector; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; + if (IREE_UNLIKELY(ptr + sizeof(int32_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span count"); + return; + } iree_host_size_t count = *reinterpret_cast(ptr); ptr += sizeof(int32_t); out_param.resize(count); for (iree_host_size_t i = 0; i < count; ++i) { - ParamUnpack::Load(status, ptr, out_param[i]); + ParamUnpack::Load(status, ptr, limit, out_param[i]); + if (!status.ok()) return; } } }; @@ -612,10 +711,24 @@ template struct ParamUnpack, enable_if_primitive> { using element_type = U; using storage_type = iree::span; - static void Load(Status& status, params_ptr_t& ptr, storage_type& out_param) { + static void Load(Status& status, params_ptr_t& ptr, params_ptr_t limit, + storage_type& out_param) { + if (!status.ok()) return; + if (IREE_UNLIKELY(ptr + sizeof(int32_t) > limit)) { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span count"); + return; + } iree_host_size_t count = *reinterpret_cast(ptr); ptr += sizeof(int32_t); ptr = align_ptr(ptr); + if (IREE_UNLIKELY(ptr + sizeof(element_type) * count > limit)) { + status = 
iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer overflow reading span elements" + " (count=%" PRIhsz ")", + count); + return; + } out_param = iree::span(reinterpret_cast(ptr), count); ptr += sizeof(element_type) * count; diff --git a/runtime/src/iree/vm/test/BUILD.bazel b/runtime/src/iree/vm/test/BUILD.bazel index 824492e31626..a81b2d09f6f4 100644 --- a/runtime/src/iree/vm/test/BUILD.bazel +++ b/runtime/src/iree/vm/test/BUILD.bazel @@ -34,6 +34,7 @@ iree_c_embed_data( ":assignment_ops_f32.vmfb", ":assignment_ops_f64.vmfb", ":assignment_ops_i64.vmfb", + ":async_ops.vmfb", ":buffer_ops.vmfb", ":call_ops.vmfb", ":comparison_ops.vmfb", @@ -61,25 +62,6 @@ iree_c_embed_data( h_file_output = "all_bytecode_modules.h", ) -iree_c_embed_data( - name = "async_bytecode_modules_c", - srcs = [ - ":async_ops.vmfb", - ], - c_file_output = "async_bytecode_modules.c", - flatten = True, - h_file_output = "async_bytecode_modules.h", -) - -iree_runtime_cc_library( - name = "async_ops_test_module", - hdrs = ["async_ops_test_module.h"], - deps = [ - "//runtime/src/iree/base", - "//runtime/src/iree/vm", - ], -) - iree_bytecode_module( name = "arithmetic_ops", src = "arithmetic_ops.mlir", diff --git a/runtime/src/iree/vm/test/CMakeLists.txt b/runtime/src/iree/vm/test/CMakeLists.txt index bdc7c787a8e1..8f00924e395f 100644 --- a/runtime/src/iree/vm/test/CMakeLists.txt +++ b/runtime/src/iree/vm/test/CMakeLists.txt @@ -26,6 +26,7 @@ iree_c_embed_data( "assignment_ops_f32.vmfb" "assignment_ops_f64.vmfb" "assignment_ops_i64.vmfb" + "async_ops.vmfb" "buffer_ops.vmfb" "call_ops.vmfb" "comparison_ops.vmfb" @@ -55,30 +56,6 @@ iree_c_embed_data( PUBLIC ) -iree_c_embed_data( - NAME - async_bytecode_modules_c - SRCS - "async_ops.vmfb" - C_FILE_OUTPUT - "async_bytecode_modules.c" - H_FILE_OUTPUT - "async_bytecode_modules.h" - FLATTEN - PUBLIC -) - -iree_cc_library( - NAME - async_ops_test_module - HDRS - "async_ops_test_module.h" - DEPS - iree::base - iree::vm - PUBLIC -) - 
iree_bytecode_module( NAME arithmetic_ops diff --git a/runtime/src/iree/vm/test/async_ops.mlir b/runtime/src/iree/vm/test/async_ops.mlir index 75772158c5e7..7a21e63a3d0e 100644 --- a/runtime/src/iree/vm/test/async_ops.mlir +++ b/runtime/src/iree/vm/test/async_ops.mlir @@ -1,17 +1,16 @@ -// Tested by iree/vm/bytecode/dispatch_async_test.cc. - vm.module @async_ops { //===--------------------------------------------------------------------===// // vm.yield //===--------------------------------------------------------------------===// // Tests a simple straight-line yield sequence that requires 3 resumes. - // - // Expects a result of %arg0 + 3. - vm.export @yield_sequence - vm.func @yield_sequence(%arg0: i32) -> i32 { + // Starts with 100, adds 1 three times across yields, expects 103. + vm.export @test_yield_sequence + vm.func @test_yield_sequence() { %c1 = vm.const.i32 1 - %y0 = vm.add.i32 %arg0, %c1 : i32 + %c100 = vm.const.i32 100 + %c100_dno = util.optimization_barrier %c100 : i32 + %y0 = vm.add.i32 %c100_dno, %c1 : i32 %y0_dno = util.optimization_barrier %y0 : i32 vm.yield ^bb1 ^bb1: @@ -23,25 +22,47 @@ vm.module @async_ops { %y2_dno = util.optimization_barrier %y2 : i32 vm.yield ^bb3 ^bb3: - vm.return %y2_dno : i32 + %c103 = vm.const.i32 103 + vm.check.eq %y2_dno, %c103, "100+1+1+1=103" : i32 + vm.return + } + + // Tests a yield with data-dependent control flow (true branch). 
+ vm.export @test_yield_divergent_true + vm.func @test_yield_divergent_true() { + %c1 = vm.const.i32 1 + %c100 = vm.const.i32 100 + %c200 = vm.const.i32 200 + %cond = vm.cmp.nz.i32 %c1 : i32 + vm.cond_br %cond, ^true, ^false + ^true: + %v_true = util.optimization_barrier %c100 : i32 + vm.yield ^check(%v_true : i32) + ^false: + %v_false = util.optimization_barrier %c200 : i32 + vm.yield ^check(%v_false : i32) + ^check(%result : i32): + vm.check.eq %result, %c100, "cond=1 selects true branch" : i32 + vm.return } - // Tests a yield with data-dependent control, ensuring that we run the - // alternating branches and pass along branch args on resume. - // - // Expects a result of %arg0 ? %arg1 : %arg2. - vm.export @yield_divergent - vm.func @yield_divergent(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { - %cond = vm.cmp.nz.i32 %arg0 : i32 + // Tests a yield with data-dependent control flow (false branch). + vm.export @test_yield_divergent_false + vm.func @test_yield_divergent_false() { + %c0 = vm.const.i32 0 + %c100 = vm.const.i32 100 + %c200 = vm.const.i32 200 + %cond = vm.cmp.nz.i32 %c0 : i32 vm.cond_br %cond, ^true, ^false ^true: - %arg1_dno = util.optimization_barrier %arg1 : i32 - vm.yield ^bb3(%arg1_dno : i32) + %v_true = util.optimization_barrier %c100 : i32 + vm.yield ^check(%v_true : i32) ^false: - %arg2_dno = util.optimization_barrier %arg2 : i32 - vm.yield ^bb3(%arg2_dno: i32) - ^bb3(%result : i32): - vm.return %result : i32 + %v_false = util.optimization_barrier %c200 : i32 + vm.yield ^check(%v_false : i32) + ^check(%result : i32): + vm.check.eq %result, %c200, "cond=0 selects false branch" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -69,14 +90,15 @@ vm.module @async_ops { } // Tests calling an internal yieldable function. - // The callee yields 4 times, so we need 4 resumes. - // Expects result of 0 + 4 = 4. 
- vm.export @call_yieldable_internal attributes {emitc.exclude} - vm.func @call_yieldable_internal() -> i32 { + // The callee yields 4 times. Expects result of 0 + 4 = 4. + vm.export @test_call_yieldable_internal attributes {emitc.exclude} + vm.func @test_call_yieldable_internal() { %c0 = vm.const.i32 0 vm.call.yieldable @yield_counter(%c0) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c4 = vm.const.i32 4 + vm.check.eq %result, %c4, "0+4=4" : i32 + vm.return } // Internal function that takes an input and yields once, returning input + 1. @@ -90,12 +112,15 @@ vm.module @async_ops { } // Tests calling an internal yieldable function with an argument. - // Expects result of %arg0 + 1. - vm.export @call_yieldable_with_arg attributes {emitc.exclude} - vm.func @call_yieldable_with_arg(%arg0: i32) -> i32 { - vm.call.yieldable @yield_add_one(%arg0) : (i32) -> ^resume(i32) + // Expects result of 42 + 1 = 43. + vm.export @test_call_yieldable_with_arg attributes {emitc.exclude} + vm.func @test_call_yieldable_with_arg() { + %c42 = vm.const.i32 42 + vm.call.yieldable @yield_add_one(%c42) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c43 = vm.const.i32 43 + vm.check.eq %result, %c43, "42+1=43" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -112,42 +137,51 @@ vm.module @async_ops { vm.import private @yieldable_test.yield_variadic_sum(%args : i32 ..., %yield_count : i32) -> i32 attributes {vm.yield} // Test: call yieldable import with 3 yields. - // Expected: 3 DEFERRED returns, then OK with result = arg + 3 - vm.export @call_yieldable_import_yields_3 attributes {emitc.exclude} - vm.func @call_yieldable_import_yields_3(%arg0 : i32) -> i32 { + // Expects 100 + 3 = 103. 
+ vm.export @test_call_yieldable_import_yields_3 attributes {emitc.exclude} + vm.func @test_call_yieldable_import_yields_3() { + %c100 = vm.const.i32 100 %c3 = vm.const.i32 3 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c3) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c100, %c3) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c103 = vm.const.i32 103 + vm.check.eq %result, %c103, "100+3=103" : i32 + vm.return } // Test: call yieldable import with 0 yields (synchronous). - // Expected: immediate OK with result = arg - vm.export @call_yieldable_import_yields_0 attributes {emitc.exclude} - vm.func @call_yieldable_import_yields_0(%arg0 : i32) -> i32 { + // Expects immediate return with 42. + vm.export @test_call_yieldable_import_yields_0 attributes {emitc.exclude} + vm.func @test_call_yieldable_import_yields_0() { + %c42 = vm.const.i32 42 %c0 = vm.const.i32 0 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c0) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c42, %c0) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + vm.check.eq %result, %c42, "42+0=42" : i32 + vm.return } // Test: call yieldable import after internal function call. - // This exercises Bug 2 fix: return_registers must be cleared after internal call. + // This exercises return_registers clearing after internal call. vm.func private @internal_add_10(%x : i32) -> i32 { %c10 = vm.const.i32 10 %r = vm.add.i32 %x, %c10 : i32 vm.return %r : i32 } - vm.export @call_yieldable_after_internal attributes {emitc.exclude} - vm.func @call_yieldable_after_internal(%arg0 : i32) -> i32 { + // Expects (5 + 10) + 2 = 17. + vm.export @test_call_yieldable_after_internal attributes {emitc.exclude} + vm.func @test_call_yieldable_after_internal() { + %c5 = vm.const.i32 5 // First call an internal function (sets return_registers). 
- %v1 = vm.call @internal_add_10(%arg0) : (i32) -> i32 - // Then call yieldable import (should see return_registers == NULL for begin). + %v1 = vm.call @internal_add_10(%c5) : (i32) -> i32 + // Then call yieldable import. %c2 = vm.const.i32 2 vm.call.yieldable @yieldable_test.yield_n(%v1, %c2) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c17 = vm.const.i32 17 + vm.check.eq %result, %c17, "(5+10)+2=17" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -155,25 +189,25 @@ vm.module @async_ops { //===--------------------------------------------------------------------===// // Test: two sequential yieldable import calls in the same function. - // This catches bugs where the second call sees stale state from the first. - // Expected: 2 yields from first call + 3 yields from second call = 5 total - // Result: (arg + 2) + 3 = arg + 5 - vm.export @call_yieldable_import_sequential attributes {emitc.exclude} - vm.func @call_yieldable_import_sequential(%arg0 : i32) -> i32 { + // Expects (10 + 2) + 3 = 15. 
+ vm.export @test_call_yieldable_import_sequential attributes {emitc.exclude} + vm.func @test_call_yieldable_import_sequential() { + %c10 = vm.const.i32 10 %c2 = vm.const.i32 2 %c3 = vm.const.i32 3 - // First yieldable import: yields 2 times, returns arg + 2 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c2) : (i32, i32) -> ^after_first(i32) + // First yieldable import: yields 2 times, returns 10 + 2 = 12 + vm.call.yieldable @yieldable_test.yield_n(%c10, %c2) : (i32, i32) -> ^after_first(i32) ^after_first(%v1 : i32): - // Second yieldable import: yields 3 times, returns v1 + 3 = arg + 5 + // Second yieldable import: yields 3 times, returns 12 + 3 = 15 vm.call.yieldable @yieldable_test.yield_n(%v1, %c3) : (i32, i32) -> ^done(i32) ^done(%result : i32): - vm.return %result : i32 + %c15 = vm.const.i32 15 + vm.check.eq %result, %c15, "(10+2)+3=15" : i32 + vm.return } // Test: yieldable import nested inside an internal yieldable function. // The internal function yields before and after calling the import. - // This creates the most complex frame stack scenario. vm.func private @yield_then_import_then_yield(%arg0 : i32) -> i32 { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 @@ -193,25 +227,28 @@ vm.module @async_ops { vm.return %v2_dno : i32 } - // Export that calls the nested yieldable function. - // Expected sequence: 1 yield (internal) + 2 yields (import) + 1 yield (internal) = 4 yields - // Result: ((arg + 1) + 2) + 1 = arg + 4 - vm.export @call_nested_yieldable attributes {emitc.exclude} - vm.func @call_nested_yieldable(%arg0 : i32) -> i32 { - vm.call.yieldable @yield_then_import_then_yield(%arg0) : (i32) -> ^resume(i32) + // Expects ((50 + 1) + 2) + 1 = 54. 
+ vm.export @test_call_nested_yieldable attributes {emitc.exclude} + vm.func @test_call_nested_yieldable() { + %c50 = vm.const.i32 50 + vm.call.yieldable @yield_then_import_then_yield(%c50) : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c54 = vm.const.i32 54 + vm.check.eq %result, %c54, "((50+1)+2)+1=54" : i32 + vm.return } // Test: stress test with many yields to catch state accumulation bugs. - // Calls yieldable import with high yield count. - // Expected: 10 yields, result = arg + 10 - vm.export @call_yieldable_import_stress attributes {emitc.exclude} - vm.func @call_yieldable_import_stress(%arg0 : i32) -> i32 { + // Expects 1000 + 10 = 1010. + vm.export @test_call_yieldable_import_stress attributes {emitc.exclude} + vm.func @test_call_yieldable_import_stress() { + %c1000 = vm.const.i32 1000 %c10 = vm.const.i32 10 - vm.call.yieldable @yieldable_test.yield_n(%arg0, %c10) : (i32, i32) -> ^resume(i32) + vm.call.yieldable @yieldable_test.yield_n(%c1000, %c10) : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c1010 = vm.const.i32 1010 + vm.check.eq %result, %c1010, "1000+10=1010" : i32 + vm.return } //===--------------------------------------------------------------------===// @@ -219,58 +256,75 @@ vm.module @async_ops { //===--------------------------------------------------------------------===// // Test: call variadic yieldable import with 2 args and 3 yields. - // Expected: 3 DEFERRED returns, then OK with result = (arg0 + arg1) + 3 - vm.export @call_variadic_yieldable_2args attributes {emitc.exclude} - vm.func @call_variadic_yieldable_2args(%arg0 : i32, %arg1 : i32) -> i32 { + // Expects (10 + 20) + 3 = 33. 
+ vm.export @test_call_variadic_yieldable_2args attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_2args() { + %c10 = vm.const.i32 10 + %c20 = vm.const.i32 20 %c3 = vm.const.i32 3 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %c3) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c10, %c20, %c3) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c33 = vm.const.i32 33 + vm.check.eq %result, %c33, "(10+20)+3=33" : i32 + vm.return } // Test: call variadic yieldable import with 0 yields (synchronous). - // Expected: immediate OK with result = arg0 + arg1 + arg2 - vm.export @call_variadic_yieldable_0yields attributes {emitc.exclude} - vm.func @call_variadic_yieldable_0yields(%arg0 : i32, %arg1 : i32, %arg2 : i32) -> i32 { + // Expects 5 + 10 + 15 = 30. + vm.export @test_call_variadic_yieldable_0yields attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_0yields() { + %c5 = vm.const.i32 5 + %c10 = vm.const.i32 10 + %c15 = vm.const.i32 15 %c0 = vm.const.i32 0 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %arg2, %c0) {segment_sizes = dense<[3, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c5, %c10, %c15, %c0) {segment_sizes = dense<[3, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c30 = vm.const.i32 30 + vm.check.eq %result, %c30, "5+10+15+0=30" : i32 + vm.return } // Test: call variadic yieldable import with single arg. 
- // Expected: 2 yields, result = arg0 + 2 - vm.export @call_variadic_yieldable_1arg attributes {emitc.exclude} - vm.func @call_variadic_yieldable_1arg(%arg0 : i32) -> i32 { + // Expects 100 + 2 = 102. + vm.export @test_call_variadic_yieldable_1arg attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_1arg() { + %c100 = vm.const.i32 100 %c2 = vm.const.i32 2 - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %c2) {segment_sizes = dense<[1, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32) -> ^resume(i32) + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c100, %c2) {segment_sizes = dense<[1, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + %c102 = vm.const.i32 102 + vm.check.eq %result, %c102, "100+2=102" : i32 + vm.return } // Test: call variadic yieldable import with empty variadic list. - // Expected: 1 yield, result = 0 + 1 = 1 - vm.export @call_variadic_yieldable_empty attributes {emitc.exclude} - vm.func @call_variadic_yieldable_empty() -> i32 { + // Expects 0 + 1 = 1. + vm.export @test_call_variadic_yieldable_empty attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_empty() { %c1 = vm.const.i32 1 vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c1) {segment_sizes = dense<[0, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32) -> ^resume(i32) ^resume(%result : i32): - vm.return %result : i32 + vm.check.eq %result, %c1, "0+1=1" : i32 + vm.return } // Test: two sequential variadic yieldable calls. - // Expected: 2 yields from first + 1 yield from second = 3 yields total - // Result: ((arg0 + arg1) + 2) + (arg2) + 1 = arg0 + arg1 + arg2 + 3 - vm.export @call_variadic_yieldable_sequential attributes {emitc.exclude} - vm.func @call_variadic_yieldable_sequential(%arg0 : i32, %arg1 : i32, %arg2 : i32) -> i32 { + // Expects ((10 + 20) + 2) + (32 + 5) + 1 = 38. 
+ vm.export @test_call_variadic_yieldable_sequential attributes {emitc.exclude} + vm.func @test_call_variadic_yieldable_sequential() { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - // First variadic yieldable: sum(arg0, arg1) + 2 yields - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%arg0, %arg1, %c2) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^after_first(i32) + %c5 = vm.const.i32 5 + %c10 = vm.const.i32 10 + %c20 = vm.const.i32 20 + // First variadic yieldable: sum(10, 20) + 2 yields = 32 + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%c10, %c20, %c2) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^after_first(i32) ^after_first(%v1 : i32): - // Second variadic yieldable: sum(v1, arg2) + 1 yield - vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%v1, %arg2, %c1) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^done(i32) + // Second variadic yieldable: sum(32, 5) + 1 yield = 38 + vm.call.variadic.yieldable @yieldable_test.yield_variadic_sum(%v1, %c5, %c1) {segment_sizes = dense<[2, 1]> : vector<2xi16>, segment_types = [i32, i32]} : (i32, i32, i32) -> ^done(i32) ^done(%result : i32): - vm.return %result : i32 + %c38 = vm.const.i32 38 + vm.check.eq %result, %c38, "((10+20)+2)+((32+5)+1)=38" : i32 + vm.return } } diff --git a/runtime/src/iree/vm/test/bytecode/BUILD.bazel b/runtime/src/iree/vm/test/bytecode/BUILD.bazel new file mode 100644 index 000000000000..1d1500715a51 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/BUILD.bazel @@ -0,0 +1,37 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_test") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_cmake_extra_content( + content = """ +if(NOT IREE_BUILD_COMPILER OR NOT IREE_BUILD_TESTS) + return() +endif() +""", + inline = True, +) + +iree_runtime_cc_test( + name = "bytecode_module_test", + srcs = ["bytecode_module_test.cc"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", + "//runtime/src/iree/vm", + "//runtime/src/iree/vm/bytecode:module", + "//runtime/src/iree/vm/test:all_bytecode_modules_c", + "//runtime/src/iree/vm/testing:test_runner", + "//runtime/src/iree/vm/testing:yieldable_test_module", + ], +) diff --git a/runtime/src/iree/vm/test/bytecode/CMakeLists.txt b/runtime/src/iree/vm/test/bytecode/CMakeLists.txt new file mode 100644 index 000000000000..4ac6fb773f02 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/CMakeLists.txt @@ -0,0 +1,33 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# runtime/src/iree/vm/test/bytecode/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. 
# +################################################################################ + +iree_add_all_subdirs() + +if(NOT IREE_BUILD_COMPILER OR NOT IREE_BUILD_TESTS) + return() +endif() + +iree_cc_test( + NAME + bytecode_module_test + SRCS + "bytecode_module_test.cc" + DEPS + iree::base + iree::testing::gtest + iree::testing::gtest_main + iree::vm + iree::vm::bytecode::module + iree::vm::test::all_bytecode_modules_c + iree::vm::testing::test_runner + iree::vm::testing::yieldable_test_module +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc b/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc new file mode 100644 index 000000000000..c9be80b79ca8 --- /dev/null +++ b/runtime/src/iree/vm/test/bytecode/bytecode_module_test.cc @@ -0,0 +1,94 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "iree/vm/bytecode/module.h" +#include "iree/vm/test/all_bytecode_modules.h" +#include "iree/vm/testing/test_runner.h" +#include "iree/vm/testing/yieldable_test_module.h" + +IREE_VM_TEST_RUNNER_STATIC_STORAGE(); + +namespace iree::vm::testing { +namespace { + +std::vector GetBytecodeTestParams() { + std::vector test_params; + + // Prerequisite factory for modules that import from yieldable_test. 
+ auto yieldable_test_factory = [](iree_vm_instance_t* inst, + iree_vm_module_t** out_mod) { + return yieldable_test_module_create(inst, iree_allocator_system(), out_mod); + }; + + iree_vm_instance_t* instance = nullptr; + IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, + iree_allocator_system(), &instance)); + + const struct iree_file_toc_t* module_file_toc = + all_bytecode_modules_c_create(); + for (size_t i = 0; i < all_bytecode_modules_c_size(); ++i) { + const auto& module_file = module_file_toc[i]; + std::string module_name(module_file.name); + + iree_vm_module_t* module = nullptr; + IREE_CHECK_OK(iree_vm_bytecode_module_create( + instance, IREE_VM_BYTECODE_MODULE_FLAG_NONE, + iree_const_byte_span_t{ + reinterpret_cast(module_file.data), + static_cast(module_file.size)}, + iree_allocator_null(), iree_allocator_system(), &module)); + + iree_vm_module_signature_t signature = iree_vm_module_signature(module); + for (iree_host_size_t j = 0; j < signature.export_function_count; ++j) { + iree_vm_function_t function; + IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal( + module, IREE_VM_FUNCTION_LINKAGE_EXPORT, j, &function)); + iree_string_view_t function_name = iree_vm_function_name(&function); + std::string fn_name(function_name.data, function_name.size); + + // Capture module data for lambda. 
+ const void* data = module_file.data; + iree_host_size_t size = module_file.size; + + std::vector prereqs; + prereqs.push_back(yieldable_test_factory); + + test_params.push_back({ + module_name, + fn_name, + [data, size](iree_vm_instance_t* inst, iree_vm_module_t** out_mod) { + return iree_vm_bytecode_module_create( + inst, IREE_VM_BYTECODE_MODULE_FLAG_NONE, + iree_const_byte_span_t{reinterpret_cast(data), + static_cast(size)}, + iree_allocator_null(), iree_allocator_system(), out_mod); + }, + /*expects_failure=*/fn_name.find("fail_") == 0, + /*prerequisite_modules=*/prereqs, + }); + } + iree_vm_module_release(module); + } + + iree_vm_instance_release(instance); + return test_params; +} + +class VMBytecodeTest : public VMTestRunner<> {}; + +IREE_VM_TEST_F(VMBytecodeTest) + +INSTANTIATE_TEST_SUITE_P(bytecode, VMBytecodeTest, + ::testing::ValuesIn(GetBytecodeTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/test/emitc/BUILD.bazel b/runtime/src/iree/vm/test/emitc/BUILD.bazel index 1516cbf796a0..a19aebabe3d2 100644 --- a/runtime/src/iree/vm/test/emitc/BUILD.bazel +++ b/runtime/src/iree/vm/test/emitc/BUILD.bazel @@ -8,13 +8,14 @@ load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_test") load("//build_tools/bazel:iree_c_module.bzl", "iree_c_module") package( + default_visibility = ["//runtime/src/iree/vm:__subpackages__"], features = ["layering_check"], licenses = ["notice"], # Apache 2.0 ) iree_runtime_cc_test( - name = "module_test", - srcs = ["module_test.cc"], + name = "emitc_module_test", + srcs = ["emitc_module_test.cc"], deps = [ ":arithmetic_ops", ":arithmetic_ops_f32", @@ -46,6 +47,7 @@ iree_runtime_cc_test( "//runtime/src/iree/vm:ops", "//runtime/src/iree/vm:ops_emitc", "//runtime/src/iree/vm:shims_emitc", + "//runtime/src/iree/vm/testing:test_runner", ], ) diff --git a/runtime/src/iree/vm/test/emitc/CMakeLists.txt 
b/runtime/src/iree/vm/test/emitc/CMakeLists.txt index c80da9ee6f81..d174730f9c6f 100644 --- a/runtime/src/iree/vm/test/emitc/CMakeLists.txt +++ b/runtime/src/iree/vm/test/emitc/CMakeLists.txt @@ -10,14 +10,10 @@ if(IREE_OUTPUT_FORMAT_C) iree_cc_test( NAME - module_test + emitc_module_test SRCS - "module_test.cc" + "emitc_module_test.cc" DEPS - iree::base - iree::testing::gtest - iree::testing::gtest_main - iree::vm ::arithmetic_ops ::arithmetic_ops_f32 ::arithmetic_ops_i64 @@ -41,6 +37,14 @@ iree_cc_test( ::ref_ops ::shift_ops ::shift_ops_i64 + iree::base + iree::testing::gtest + iree::testing::gtest_main + iree::vm + iree::vm::ops + iree::vm::ops_emitc + iree::vm::shims_emitc + iree::vm::testing::test_runner ) iree_c_module( diff --git a/runtime/src/iree/vm/test/emitc/emitc_module_test.cc b/runtime/src/iree/vm/test/emitc/emitc_module_test.cc new file mode 100644 index 000000000000..1830a76d80fb --- /dev/null +++ b/runtime/src/iree/vm/test/emitc/emitc_module_test.cc @@ -0,0 +1,122 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +// We should not be including C implementation-only headers in a C++ +// module like this. In order to make this work for the moment across +// runtime libraries that are strict, do a global using of the std namespace. +// EmitC is deprecated and will not be gaining any additional test support so +// this is an "as long as it works it's fine" compromise. 
+using namespace std; + +#include "iree/base/api.h" +#include "iree/vm/api.h" +#include "iree/vm/testing/test_runner.h" + +#define EMITC_IMPLEMENTATION +#include "iree/vm/test/emitc/arithmetic_ops.h" +#include "iree/vm/test/emitc/arithmetic_ops_f32.h" +#include "iree/vm/test/emitc/arithmetic_ops_i64.h" +#include "iree/vm/test/emitc/assignment_ops.h" +#include "iree/vm/test/emitc/assignment_ops_f32.h" +#include "iree/vm/test/emitc/assignment_ops_i64.h" +#include "iree/vm/test/emitc/buffer_ops.h" +#include "iree/vm/test/emitc/call_ops.h" +#include "iree/vm/test/emitc/comparison_ops.h" +#include "iree/vm/test/emitc/comparison_ops_f32.h" +#include "iree/vm/test/emitc/comparison_ops_i64.h" +#include "iree/vm/test/emitc/control_flow_ops.h" +#include "iree/vm/test/emitc/conversion_ops.h" +#include "iree/vm/test/emitc/conversion_ops_f32.h" +#include "iree/vm/test/emitc/conversion_ops_i64.h" +#include "iree/vm/test/emitc/global_ops.h" +#include "iree/vm/test/emitc/global_ops_f32.h" +#include "iree/vm/test/emitc/global_ops_i64.h" +#include "iree/vm/test/emitc/list_ops.h" +#include "iree/vm/test/emitc/list_variant_ops.h" +#include "iree/vm/test/emitc/ref_ops.h" +#include "iree/vm/test/emitc/shift_ops.h" +#include "iree/vm/test/emitc/shift_ops_i64.h" + +IREE_VM_TEST_RUNNER_STATIC_STORAGE(); + +namespace iree::vm::testing { +namespace { + +typedef iree_status_t (*emitc_create_fn_t)(iree_vm_instance_t*, + iree_allocator_t, + iree_vm_module_t**); + +struct EmitcModuleInfo { + iree_vm_native_module_descriptor_t descriptor; + emitc_create_fn_t create_fn; +}; + +std::vector GetEmitcTestParams() { + std::vector test_params; + + std::vector modules = { + {arithmetic_ops_descriptor_, arithmetic_ops_create}, + {arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create}, + {arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create}, + {assignment_ops_descriptor_, assignment_ops_create}, + {assignment_ops_f32_descriptor_, assignment_ops_f32_create}, + {assignment_ops_i64_descriptor_, 
assignment_ops_i64_create}, + {buffer_ops_descriptor_, buffer_ops_create}, + {call_ops_descriptor_, call_ops_create}, + {comparison_ops_descriptor_, comparison_ops_create}, + {comparison_ops_f32_descriptor_, comparison_ops_f32_create}, + {comparison_ops_i64_descriptor_, comparison_ops_i64_create}, + {control_flow_ops_descriptor_, control_flow_ops_create}, + {conversion_ops_descriptor_, conversion_ops_create}, + {conversion_ops_f32_descriptor_, conversion_ops_f32_create}, + {conversion_ops_i64_descriptor_, conversion_ops_i64_create}, + {global_ops_descriptor_, global_ops_create}, + {global_ops_f32_descriptor_, global_ops_f32_create}, + {global_ops_i64_descriptor_, global_ops_i64_create}, + {list_ops_descriptor_, list_ops_create}, + {list_variant_ops_descriptor_, list_variant_ops_create}, + {ref_ops_descriptor_, ref_ops_create}, + {shift_ops_descriptor_, shift_ops_create}, + {shift_ops_i64_descriptor_, shift_ops_i64_create}, + }; + + for (const auto& mod : modules) { + std::string module_name(mod.descriptor.name.data, mod.descriptor.name.size); + emitc_create_fn_t create_fn = mod.create_fn; + + for (iree_host_size_t i = 0; i < mod.descriptor.export_count; ++i) { + const iree_vm_native_export_descriptor_t& export_desc = + mod.descriptor.exports[i]; + std::string fn_name(export_desc.local_name.data, + export_desc.local_name.size); + test_params.push_back({ + module_name, + fn_name, + [create_fn](iree_vm_instance_t* inst, iree_vm_module_t** out_mod) { + return create_fn(inst, iree_allocator_system(), out_mod); + }, + /*expects_failure=*/fn_name.find("fail_") == 0, + /*prerequisite_modules=*/{}, + }); + } + } + + return test_params; +} + +class VMEmitcTest : public VMTestRunner<> {}; + +IREE_VM_TEST_F(VMEmitcTest) + +INSTANTIATE_TEST_SUITE_P(emitc, VMEmitcTest, + ::testing::ValuesIn(GetEmitcTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/test/emitc/module_test.cc 
b/runtime/src/iree/vm/test/emitc/module_test.cc deleted file mode 100644 index fd73e3044c73..000000000000 --- a/runtime/src/iree/vm/test/emitc/module_test.cc +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// TODO: We should not be including C implementation-only headers in a C++ -// module like this. In order to make this work for the moment across -// runtime libraries that are strict, do a global using of the std namespace. -// See #7605 -#include -using namespace std; - -#include "iree/base/api.h" -#include "iree/testing/gtest.h" -#include "iree/vm/api.h" -#define EMITC_IMPLEMENTATION -#include "iree/vm/test/emitc/arithmetic_ops.h" -#include "iree/vm/test/emitc/arithmetic_ops_f32.h" -#include "iree/vm/test/emitc/arithmetic_ops_i64.h" -#include "iree/vm/test/emitc/assignment_ops.h" -#include "iree/vm/test/emitc/assignment_ops_f32.h" -#include "iree/vm/test/emitc/assignment_ops_i64.h" -#include "iree/vm/test/emitc/buffer_ops.h" -#include "iree/vm/test/emitc/call_ops.h" -#include "iree/vm/test/emitc/comparison_ops.h" -#include "iree/vm/test/emitc/comparison_ops_f32.h" -#include "iree/vm/test/emitc/comparison_ops_i64.h" -#include "iree/vm/test/emitc/control_flow_ops.h" -#include "iree/vm/test/emitc/conversion_ops.h" -#include "iree/vm/test/emitc/conversion_ops_f32.h" -#include "iree/vm/test/emitc/conversion_ops_i64.h" -#include "iree/vm/test/emitc/global_ops.h" -#include "iree/vm/test/emitc/global_ops_f32.h" -#include "iree/vm/test/emitc/global_ops_i64.h" -#include "iree/vm/test/emitc/list_ops.h" -#include "iree/vm/test/emitc/list_variant_ops.h" -#include "iree/vm/test/emitc/ref_ops.h" -#include "iree/vm/test/emitc/shift_ops.h" -#include "iree/vm/test/emitc/shift_ops_i64.h" - -namespace { - -typedef iree_status_t (*create_function_t)(iree_vm_instance_t*, 
- iree_allocator_t, - iree_vm_module_t**); - -struct TestParams { - std::string module_name; - std::string local_name; - create_function_t create_function; -}; - -struct ModuleDescription { - iree_vm_native_module_descriptor_t descriptor; - create_function_t create_function; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - std::string qualified_name = params.module_name + "." + params.local_name; - auto name_sv = - iree_make_string_view(qualified_name.data(), qualified_name.size()); - iree_string_view_replace_char(name_sv, ':', '_'); - iree_string_view_replace_char(name_sv, '.', '_'); - return os << qualified_name; -} - -std::vector GetModuleTestParams() { - std::vector test_params; - - // TODO(simon-camp): get these automatically - std::vector modules = { - {arithmetic_ops_descriptor_, arithmetic_ops_create}, - {arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create}, - {arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create}, - {assignment_ops_descriptor_, assignment_ops_create}, - {assignment_ops_f32_descriptor_, assignment_ops_f32_create}, - {assignment_ops_i64_descriptor_, assignment_ops_i64_create}, - {buffer_ops_descriptor_, buffer_ops_create}, - {call_ops_descriptor_, call_ops_create}, - {comparison_ops_descriptor_, comparison_ops_create}, - {comparison_ops_f32_descriptor_, comparison_ops_f32_create}, - {comparison_ops_i64_descriptor_, comparison_ops_i64_create}, - {control_flow_ops_descriptor_, control_flow_ops_create}, - {conversion_ops_descriptor_, conversion_ops_create}, - {conversion_ops_f32_descriptor_, conversion_ops_f32_create}, - {conversion_ops_i64_descriptor_, conversion_ops_i64_create}, - {global_ops_descriptor_, global_ops_create}, - {global_ops_f32_descriptor_, global_ops_f32_create}, - {global_ops_i64_descriptor_, global_ops_i64_create}, - {list_ops_descriptor_, list_ops_create}, - {list_variant_ops_descriptor_, list_variant_ops_create}, - {ref_ops_descriptor_, ref_ops_create}, - {shift_ops_descriptor_, 
shift_ops_create}, - {shift_ops_i64_descriptor_, shift_ops_i64_create}}; - - for (size_t i = 0; i < modules.size(); i++) { - iree_vm_native_module_descriptor_t descriptor = modules[i].descriptor; - create_function_t function = modules[i].create_function; - - std::string module_name = - std::string(descriptor.name.data, descriptor.name.size); - - for (iree_host_size_t i = 0; i < descriptor.export_count; i++) { - iree_vm_native_export_descriptor_t export_descriptor = - descriptor.exports[i]; - std::string local_name = std::string(export_descriptor.local_name.data, - export_descriptor.local_name.size); - test_params.push_back({module_name, local_name, function}); - } - } - - return test_params; -} - -class VMCModuleTest : public ::testing::Test, - public ::testing::WithParamInterface { - protected: - virtual void SetUp() { - const auto& test_params = GetParam(); - - IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, - iree_allocator_system(), &instance_)); - - iree_vm_module_t* module_ = nullptr; - IREE_CHECK_OK(test_params.create_function( - instance_, iree_allocator_system(), &module_)); - - std::vector modules = {module_}; - IREE_CHECK_OK(iree_vm_context_create_with_modules( - instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), - iree_allocator_system(), &context_)); - - iree_vm_module_release(module_); - } - - virtual void TearDown() { - iree_vm_context_release(context_); - iree_vm_instance_release(instance_); - } - - iree_status_t RunFunction(std::string module_name, std::string local_name) { - std::string qualified_name = module_name + "." 
+ local_name; - iree_vm_function_t function; - IREE_CHECK_OK(iree_vm_context_resolve_function( - context_, - iree_string_view_t{qualified_name.data(), static_cast( - qualified_name.size())}, - &function)); - - return iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/nullptr, /*inputs=*/nullptr, - /*outputs=*/nullptr, iree_allocator_system()); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_vm_context_t* context_ = nullptr; -}; - -TEST_P(VMCModuleTest, Check) { - const auto& test_params = GetParam(); - bool expect_failure = test_params.local_name.find("fail_") == 0; - - iree::Status result = - RunFunction(test_params.module_name, test_params.local_name); - if (result.ok()) { - if (expect_failure) { - GTEST_FAIL() << "Function expected failure but succeeded"; - } else { - GTEST_SUCCEED(); - } - } else { - if (expect_failure) { - GTEST_SUCCEED(); - } else { - GTEST_FAIL() << "Function expected success but failed with error: " - << result.ToString(); - } - } -} - -INSTANTIATE_TEST_SUITE_P(VMIRFunctions, VMCModuleTest, - ::testing::ValuesIn(GetModuleTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace diff --git a/runtime/src/iree/vm/testing/BUILD.bazel b/runtime/src/iree/vm/testing/BUILD.bazel new file mode 100644 index 000000000000..f2623cac8524 --- /dev/null +++ b/runtime/src/iree/vm/testing/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library") + +package( + default_visibility = ["//runtime/src/iree/vm:__subpackages__"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_runtime_cc_library( + name = "test_runner", + testonly = True, + srcs = ["test_runner.cc"], + hdrs = ["test_runner.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/vm", + ], +) + +iree_runtime_cc_library( + name = "yieldable_test_module", + testonly = True, + hdrs = ["yieldable_test_module.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/vm", + ], +) diff --git a/runtime/src/iree/vm/testing/CMakeLists.txt b/runtime/src/iree/vm/testing/CMakeLists.txt new file mode 100644 index 000000000000..e967460df364 --- /dev/null +++ b/runtime/src/iree/vm/testing/CMakeLists.txt @@ -0,0 +1,40 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# runtime/src/iree/vm/testing/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. 
# +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + test_runner + HDRS + "test_runner.h" + SRCS + "test_runner.cc" + DEPS + iree::base + iree::testing::gtest + iree::vm + TESTONLY + PUBLIC +) + +iree_cc_library( + NAME + yieldable_test_module + HDRS + "yieldable_test_module.h" + DEPS + iree::base + iree::vm + TESTONLY + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/vm/testing/test_runner.cc b/runtime/src/iree/vm/testing/test_runner.cc new file mode 100644 index 000000000000..4696039fb4c3 --- /dev/null +++ b/runtime/src/iree/vm/testing/test_runner.cc @@ -0,0 +1,20 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/vm/testing/test_runner.h" + +namespace iree::vm::testing { + +std::ostream& operator<<(std::ostream& os, const VMTestParams& params) { + std::string name = params.module_name + "_" + params.function_name; + // Replace special characters for valid test names. + for (char& c : name) { + if (c == ':' || c == '.') c = '_'; + } + return os << name; +} + +} // namespace iree::vm::testing diff --git a/runtime/src/iree/vm/testing/test_runner.h b/runtime/src/iree/vm/testing/test_runner.h new file mode 100644 index 000000000000..0ee371f19fb8 --- /dev/null +++ b/runtime/src/iree/vm/testing/test_runner.h @@ -0,0 +1,225 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Shared test framework for VM module testing. 
+// +// This framework provides a common test runner that works across different +// VM module implementations (bytecode interpreter, EmitC, JIT, etc.). +// Tests are defined in MLIR files under iree/vm/test/ and compiled to +// different formats per backend. + +#ifndef IREE_VM_TESTING_TEST_RUNNER_H_ +#define IREE_VM_TESTING_TEST_RUNNER_H_ + +#include +#include +#include + +#include "iree/base/api.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "iree/vm/api.h" + +namespace iree::vm::testing { + +//===----------------------------------------------------------------------===// +// VMTestParams +//===----------------------------------------------------------------------===// +// Parameters for a single VM test function. + +// Module creation function type. +// Different backends implement this differently: +// - Bytecode: loads from embedded binary data +// - EmitC: calls static _create() function +// - JIT: compiles and loads at runtime +using VMModuleCreateFn = + std::function; + +// Parameters describing a single test to run. +struct VMTestParams { + // Module name (e.g., "arithmetic_ops"). + std::string module_name; + // Function name within the module (e.g., "test_add_i32"). + std::string function_name; + // Factory function to create the module under test. + VMModuleCreateFn create_module; + // Whether this function is expected to fail (fail_ prefix). + bool expects_failure = false; + // Factory functions for prerequisite modules that must be loaded before the + // module under test. These are loaded in order and added to the context + // first. Examples: native yieldable test module, HAL module, custom import + // modules. + std::vector prerequisite_modules; +}; + +// Allows test names to be printed nicely in gtest output. 
+std::ostream& operator<<(std::ostream& os, const VMTestParams& params); + +//===----------------------------------------------------------------------===// +// VMTestResources +//===----------------------------------------------------------------------===// +// Static resources shared across all tests in a suite. + +class VMTestResources { + public: + static iree_vm_instance_t* instance_; +}; + +//===----------------------------------------------------------------------===// +// VMTestRunner +//===----------------------------------------------------------------------===// +// Base test fixture for VM module testing. +// +// Usage: +// 1. Backend-specific test files include this header +// 2. Backend implements GetTestParams() returning vector +// 3. INSTANTIATE_TEST_SUITE_P with the params +// +// The runner automatically: +// - Creates VM instance/context +// - Loads the module under test +// - Optionally loads the native yieldable test module +// - Executes functions and checks results +// - Handles async/yieldable functions transparently + +template +class VMTestRunner : public BaseType, + public ::testing::WithParamInterface, + public VMTestResources { + public: + static void SetUpTestSuite() { + IREE_ASSERT_OK(iree_vm_instance_create( + IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance_)); + } + + static void TearDownTestSuite() { + if (instance_) { + iree_vm_instance_release(instance_); + instance_ = nullptr; + } + } + + void SetUp() override { + const auto& params = this->GetParam(); + + // Build module list for context. + std::vector modules; + + // Create and add prerequisite modules first (in order). + for (const auto& create_fn : params.prerequisite_modules) { + iree_vm_module_t* prereq_module = nullptr; + IREE_ASSERT_OK(create_fn(instance_, &prereq_module)); + prerequisite_modules_.push_back(prereq_module); + modules.push_back(prereq_module); + } + + // Create the module under test and add last. 
+ IREE_ASSERT_OK(params.create_module(instance_, &test_module_)); + modules.push_back(test_module_); + + IREE_ASSERT_OK(iree_vm_context_create_with_modules( + instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), + iree_allocator_system(), &context_)); + } + + void TearDown() override { + if (context_) { + iree_vm_context_release(context_); + context_ = nullptr; + } + if (test_module_) { + iree_vm_module_release(test_module_); + test_module_ = nullptr; + } + for (auto* module : prerequisite_modules_) { + iree_vm_module_release(module); + } + prerequisite_modules_.clear(); + } + + // Runs a function by name. + // Handles DEFERRED status by resuming until completion. + // NOTE: Only supports void-returning functions; test functions should perform + // internal assertions via vm.check.* ops rather than returning values. + iree_status_t RunFunction(const char* function_name) { + iree_vm_function_t function; + IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name( + test_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, + iree_make_cstring_view(function_name), &function)); + + IREE_VM_INLINE_STACK_INITIALIZE(stack, IREE_VM_CONTEXT_FLAG_NONE, + iree_vm_context_state_resolver(context_), + iree_allocator_system()); + + iree_vm_function_call_t call; + memset(&call, 0, sizeof(call)); + call.function = function; + + iree_status_t status = + function.module->begin_call(function.module->self, stack, call); + + // Resume until completion. + // Limit iterations to catch infinite yield loops in tests. 
+ constexpr int kMaxResumeCount = 10000; + int resume_count = 0; + while (iree_status_code(status) == IREE_STATUS_DEFERRED) { + iree_status_ignore(status); + if (++resume_count > kMaxResumeCount) { + iree_vm_stack_deinitialize(stack); + return iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "resume limit (%d) exceeded for function '%s'; possible infinite " + "yield loop", + kMaxResumeCount, function_name); + } + status = function.module->resume_call(function.module->self, stack, + call.results); + } + + iree_vm_stack_deinitialize(stack); + return status; + } + + protected: + iree_vm_context_t* context_ = nullptr; + iree_vm_module_t* test_module_ = nullptr; + std::vector prerequisite_modules_; +}; + +// Storage for static members. +// Note: This must only be included in one translation unit per test binary. +// The generated test template will include this. +#define IREE_VM_TEST_RUNNER_STATIC_STORAGE() \ + namespace iree::vm::testing { \ + /*static*/ iree_vm_instance_t* VMTestResources::instance_ = nullptr; \ + } + +//===----------------------------------------------------------------------===// +// Standard Test Macros +//===----------------------------------------------------------------------===// +// The parameterized test that runs each function. 
+ +#define IREE_VM_TEST_F(test_class) \ + TEST_P(test_class, Check) { \ + const auto& params = GetParam(); \ + iree_status_t status = RunFunction(params.function_name.c_str()); \ + if (iree_status_is_ok(status)) { \ + if (params.expects_failure) { \ + GTEST_FAIL() << "Function expected failure but succeeded"; \ + } \ + } else { \ + if (params.expects_failure) { \ + iree_status_ignore(status); \ + } else { \ + GTEST_FAIL() << "Function expected success but failed with error: " \ + << iree::Status(std::move(status)).ToString(); \ + } \ + } \ + } + +} // namespace iree::vm::testing + +#endif // IREE_VM_TESTING_TEST_RUNNER_H_ diff --git a/runtime/src/iree/vm/test/async_ops_test_module.h b/runtime/src/iree/vm/testing/yieldable_test_module.h similarity index 70% rename from runtime/src/iree/vm/test/async_ops_test_module.h rename to runtime/src/iree/vm/testing/yieldable_test_module.h index 2bab57e3f227..7d841f2239d3 100644 --- a/runtime/src/iree/vm/test/async_ops_test_module.h +++ b/runtime/src/iree/vm/testing/yieldable_test_module.h @@ -4,16 +4,32 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// A simple native module for testing vm.call.yieldable to imports. +// A native module for testing vm.call.yieldable to imports. // Exports a single function that yields N times before returning. +// +// This module provides a controlled way to test async/yield behavior: +// yield_n(arg: i32, yield_count: i32) -> i32 +// Returns arg + yield_count after yielding yield_count times. +// +// NOTE: This module stores coroutine state (yield_count, accumulator) in module +// state, which means it is not reentrant. Concurrent or interleaved calls on +// the same module instance through the same VM context are not supported. This +// is consistent with IREE's threading model where modules are thread-compatible +// (safe for sequential use) but not thread-safe (no concurrent access). 
+ +#ifndef IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ +#define IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ #include "iree/base/api.h" #include "iree/vm/native_module.h" +#ifdef __cplusplus +extern "C" { +#endif + //===----------------------------------------------------------------------===// // yieldable_test_module //===----------------------------------------------------------------------===// -// Native module with a single yieldable function for testing. typedef struct yieldable_test_module_state_t { iree_allocator_t allocator; @@ -37,18 +53,50 @@ static iree_status_t yieldable_test_module_yield_variadic_sum_shim( // Parse variadic arguments. // Layout: [segment_count: i32] [values: i32 * segment_count] [yield_count: // i32] + + // Validate minimum size for segment_count field. + if (args_storage.data_length < sizeof(int32_t)) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small for segment_count; have %" PRIhsz + " bytes, need at least %" PRIhsz, + args_storage.data_length, sizeof(int32_t)); + } + const uint8_t* p = args_storage.data; - int32_t segment_count = *(const int32_t*)p; + int32_t segment_count; + memcpy(&segment_count, p, sizeof(int32_t)); p += sizeof(int32_t); + // Validate segment_count is non-negative and buffer has sufficient space. + // Required size: segment_count (1) + values (segment_count) + yield_count + // (1). + if (segment_count < 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "segment_count must be non-negative; got %d", + segment_count); + } + iree_host_size_t required_size = + (iree_host_size_t)(segment_count + 2) * sizeof(int32_t); + if (args_storage.data_length < required_size) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small for %d variadic args; have %" PRIhsz + " bytes, need %" PRIhsz, + segment_count, args_storage.data_length, required_size); + } + // Sum all variadic values. 
int32_t sum = 0; for (int32_t i = 0; i < segment_count; ++i) { - sum += *(const int32_t*)p; + int32_t value; + memcpy(&value, p, sizeof(int32_t)); + sum += value; p += sizeof(int32_t); } - int32_t yield_count = *(const int32_t*)p; + int32_t yield_count; + memcpy(&yield_count, p, sizeof(int32_t)); // Initialize state. state->yield_count = yield_count; @@ -99,11 +147,22 @@ static iree_status_t yieldable_test_module_yield_n_shim( int32_t arg; int32_t yield_count; } args_t; - const args_t* args = (const args_t*)args_storage.data; + + // Validate buffer size. + if (args_storage.data_length < sizeof(args_t)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "argument buffer too small; have %" PRIhsz + " bytes, need %" PRIhsz, + args_storage.data_length, sizeof(args_t)); + } + + // Use memcpy for alignment-safe access. + args_t args; + memcpy(&args, args_storage.data, sizeof(args_t)); // Initialize state for coroutine. - state->yield_count = args->yield_count; - state->accumulator = args->arg; + state->yield_count = args.yield_count; + state->accumulator = args.arg; if (state->yield_count > 0) { state->accumulator += 1; @@ -200,3 +259,9 @@ static iree_status_t yieldable_test_module_create( &yieldable_test_module_descriptor_, instance, allocator, out_module); } + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_VM_TESTING_YIELDABLE_TEST_MODULE_H_ From 47b551118450e140a37d600f73d010ff7d29e80e Mon Sep 17 00:00:00 2001 From: Vivian Zhang Date: Tue, 13 Jan 2026 10:31:27 -0800 Subject: [PATCH 26/71] [Preprocessing] Refine checks in ConvertConvFilterToChannelsLast (#23103) The main purpose of the check in this pass is to exclude the weight backward convs from doing this conversion. Normally the layout out for weight backwards are `CHWN-CHWF`. Modify the condition to also avoid converting filter layout when it's `CNHW-CFHW`. 
Signed-off-by: yzhang93 --- .../Common/ConvertConvFilterToChannelsLast.cpp | 12 +++++++++--- .../test/conv_filter_to_channels_last.mlir | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp index 8b03a4bee45e..fec1c95531b2 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp @@ -122,6 +122,13 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { return failure(); } + // Require non-empty filter, input and output channel dimensions. + if (convolutionDims->outputChannel.empty() || + convolutionDims->inputChannel.empty() || + convolutionDims->filterLoop.empty()) { + return failure(); + } + OpOperand *input = linalgOp.getDpsInputOperand(0); OpOperand *filter = linalgOp.getDpsInputOperand(1); OpOperand *output = linalgOp.getDpsInitOperand(0); @@ -161,11 +168,10 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { return positions; }; - // Don't transpose when the input is in batch-last layout (e.g., CHWN). + // Don't transpose when the input is in not batch-first layout (e.g., CHWN). 
SmallVector batchInputPos = getDimPositions(convolutionDims->batch, inputMap); - if (!batchInputPos.empty() && - batchInputPos.back() == inputShape.size() - 1) { + if (!batchInputPos.empty() && batchInputPos.front() != 0) { return failure(); } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir index 16fc3177e99f..36e118dc314c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_filter_to_channels_last.mlir @@ -163,6 +163,24 @@ util.func public @conv_2d_chwn_chwf_no_transpose(%arg0: tensor<16x26x18x288xf32> // ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +util.func public @conv_2d_cnhw_cfhw_no_transpose(%arg0: tensor<16x288x26x18xf32>, %arg1: tensor<16x288x24x16xf32>, %arg2: tensor<288x288x3x3xf32>) -> tensor<288x288x3x3xf32> { + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x288x26x18xf32>, tensor<16x288x24x16xf32>) outs(%arg2 : tensor<288x288x3x3xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %12 = arith.mulf %in, %in_3 : f32 + %13 = arith.addf %out, %12 : f32 + linalg.yield %13 : f32 + } -> tensor<288x288x3x3xf32> + util.return %0 : tensor<288x288x3x3xf32> +} + +// CHECK-FHWC-LABEL: @conv_2d_cnhw_cfhw_no_transpose +// CHECK-FHWC-NOT: linalg.transpose + +// ----- + #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> From 
de9fddb3cb308c4058a4a771cc8da06a46a08464 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Tue, 13 Jan 2026 14:47:30 -0500 Subject: [PATCH 27/71] Update iree-test-suites (#23112) * update iree-test-suite ref across the repo * add golden times for square gemm benchmarks This update will allow downloading from hugging face and not from github lfs for onnx models. --- .github/workflows/pkgci_test_onnx.yml | 4 +-- .github/workflows/pkgci_test_sharktank.yml | 4 +-- .github/workflows/pkgci_test_torch.yml | 4 +-- .../torch_ops/torch_ops_cpu_llvm_sync.json | 17 ++++++++++-- .../torch_ops_gpu_hip_gfx1100_O3.json | 15 ++++++++--- .../torch_ops_gpu_hip_gfx942_O3.json | 13 +++++++-- .../torch_ops/torch_ops_gpu_vulkan_O3.json | 27 ++++++++++++------- 7 files changed, 62 insertions(+), 22 deletions(-) diff --git a/.github/workflows/pkgci_test_onnx.yml b/.github/workflows/pkgci_test_onnx.yml index 08293275f575..db1e1ddb1571 100644 --- a/.github/workflows/pkgci_test_onnx.yml +++ b/.github/workflows/pkgci_test_onnx.yml @@ -103,7 +103,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites - name: Install ONNX ops test suite requirements run: | @@ -189,7 +189,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites - name: Install ONNX models test suite requirements run: | diff --git a/.github/workflows/pkgci_test_sharktank.yml b/.github/workflows/pkgci_test_sharktank.yml index f905e894e6f6..4f048bcb52d1 100644 --- a/.github/workflows/pkgci_test_sharktank.yml +++ b/.github/workflows/pkgci_test_sharktank.yml @@ -88,7 +88,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 
with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites lfs: true @@ -197,7 +197,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites lfs: true diff --git a/.github/workflows/pkgci_test_torch.yml b/.github/workflows/pkgci_test_torch.yml index a98d7b924c80..6780ef009ceb 100644 --- a/.github/workflows/pkgci_test_torch.yml +++ b/.github/workflows/pkgci_test_torch.yml @@ -74,7 +74,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 132f91e49d629c35f98492a9f619017b83782aba + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites - name: Install Torch ops test suite requirements run: | @@ -138,7 +138,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: 17a391dc3882f136e567bf2687806ef6af46ad64 + ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 path: iree-test-suites # Don't need lfs for torch models yet. 
lfs: false diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json index 9f064617587a..78bb259473b5 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json @@ -7,7 +7,20 @@ ], "iree_run_module_flags": [], "skip_compile_tests": [], - "skip_run_tests": [], + "skip_run_tests": [ + "AB/8192x8192xf32_bench", + "AB/4096x4096xf32_bench", + "AB/2048x2048xf32_bench" + ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 5587.488262355328, + "AB/1024x1024xf32_bench": 1.1874544098876767, + "AB/256x256xf32_bench": 0.044473891122477315, + "AB/128x128xf32_bench": 0.03577919309132721, + "AB/2048x2048xf32_bench": 10.41092509722771, + "AB/4096x4096xf32_bench": 131.7884701769799, + "AB/512x512xf32_bench": 0.12528392536236224 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json index 414cefe81cb0..79637b8c210c 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json @@ -10,9 +10,18 @@ ], "skip_compile_tests": [], "skip_run_tests": [ - "generated/test_a_b_plus_c_float16", - "generated/test_a_t_b_float16" + "ABPlusC/64x64xf16", + "ATB/64x64xf16" ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 211.36675303181013, + "AB/1024x1024xf32_bench": 0.473261977629155, + "AB/256x256xf32_bench": 0.13523232175050667, + "AB/128x128xf32_bench": 0.10226182854380102, + "AB/2048x2048xf32_bench": 3.35047472617589, + "AB/4096x4096xf32_bench": 26.499415814344378, + 
"AB/512x512xf32_bench": 0.16945170770798412 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json index 281b9e432a51..22d007362704 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json @@ -10,8 +10,17 @@ ], "skip_compile_tests": [], "skip_run_tests": [ - "generated/test_a_t_b_float16" + "ATB/64x64xf16" ], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 10.345919956763586, + "AB/1024x1024xf32_bench": 0.12306747736492638, + "AB/256x256xf32_bench": 0.06101156551354003, + "AB/128x128xf32_bench": 0.052202587451140196, + "AB/2048x2048xf32_bench": 0.2345012331137894, + "AB/4096x4096xf32_bench": 1.423236482683913, + "AB/512x512xf32_bench": 0.07336693902980182 + } } diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json index a2d406c149cd..87e32aba227f 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json @@ -8,15 +8,24 @@ "--device=vulkan" ], "skip_compile_tests": [ - "generated/test_a_t_b_float16", - "generated/test_a_b_plus_c_float16", - "generated/test_a_b_t_float16", - "generated/test_relu_a_b_plus_c_float16", - "generated/test_gelu_a_b_plus_c_float16" - ], - "skip_run_tests": [ - "generated/test_a_b_float16" + "ATB/64x64xf16", + "ABPlusC/64x64xf16", + "ABT/64x64xf16", + "ReluABPlusC/64x64xf16", + "GeluABPlusC/64x64xf16", + "AB/64x64xf16", + "AB/Nx64xf16_64xNxf16" ], + "skip_run_tests": [], "expected_compile_failures": [], - "expected_run_failures": [] + "expected_run_failures": [], + "golden_times_ms": { + "AB/8192x8192xf32_bench": 
107.60851647438749, + "AB/1024x1024xf32_bench": 0.4509026762196051, + "AB/256x256xf32_bench": 0.1743873563575457, + "AB/128x128xf32_bench": 0.148048073505022, + "AB/2048x2048xf32_bench": 1.5943956199949283, + "AB/4096x4096xf32_bench": 10.252960922347533, + "AB/512x512xf32_bench": 0.23526767679662958 + } } From 1e45fd44af06655139efb91e0c3923b22dbdd4f1 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 13 Jan 2026 12:12:19 -0800 Subject: [PATCH 28/71] [LLVMGPU] Handle mixed-precision matmuls by returning nullopt (#23116) Previously, `getMmaScheduleFromProblemAndTarget` would assert that both matmul operands have the same type (aType == bType). This assertion fails for mixed-precision operations like f32 x bf16, causing crashes in debug builds. This change replaces the assertion with an explicit check that returns `std::nullopt` when operand types don't match, making the behavior consistent between debug and release builds. Mixed-precision operations will fall back to non-MMA tile-and-fuse configurations. This adds a test case in `config_tile_and_fuse.mlir` to verify that mixed-precision matmuls compile without crashing and use the non-MMA configuration. 
Workaround for https://github.com/iree-org/iree/issues/23040 Signed-off-by: Ian Wood --- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 5 ++-- .../test/ROCDL/config_tile_and_fuse.mlir | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 7a65414761bb..18ea5f946f75 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -379,8 +379,9 @@ static std::optional getMmaScheduleFromProblemAndTarget( return std::nullopt; } - assert(problem.aType == problem.bType && - "expected the same aType and bType."); + if (problem.aType != problem.bType) { + return std::nullopt; + } GemmCutoff gemmCutoffs = computeGemmCutoffsForAI(target, problem.aType, scaled); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 179e5bb577c9..94f4164a6d92 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -1112,3 +1112,31 @@ func.func @aligned_matmul_biasadd(%lhs : tensor<512x512xf16>, %rhs : tensor<512x // CHECK-LABEL: func.func @aligned_matmul_biasadd( // CHECK: promote_operands = [0, 1] + +// ----- + +// Currently falls back to non-MMA path since MMA intrinsics require matching +// operand types. 
+func.func @mixed_precision_matmul_f32xbf16(%lhs: tensor<16x64xf32>, %rhs: tensor<64x32xbf16>) -> tensor<16x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<16x32xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<16x32xf32>) -> tensor<16x32xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor<16x64xf32>, tensor<64x32xbf16>) + outs(%fill : tensor<16x32xf32>) { + ^bb0(%in: f32, %in_0: bf16, %out: f32): + %0 = arith.extf %in_0 : bf16 to f32 + %1 = arith.mulf %in, %0 : f32 + %2 = arith.addf %out, %1 : f32 + linalg.yield %2 : f32 + } -> tensor<16x32xf32> + return %result : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @mixed_precision_matmul_f32xbf16( +// CHECK-SAME: #iree_codegen.translation_info +// CHECK-NOT: mma_kind From 4a82fc2c890ea9b001342d9e2f4b1cc8013eccf8 Mon Sep 17 00:00:00 2001 From: Vivian Zhang Date: Tue, 13 Jan 2026 17:08:50 -0800 Subject: [PATCH 29/71] Move SinkTransposeThroughPad preprocessing pass to PropagateLinalgTranspose (#23080) Moved the `SinkTransposeThroughPad` pattern from the preprocessing `SinkTransposeThroughPadPass` to the global optimization `PropagateLinalgTransposePass`. This enables the pattern to work with the existing `SinkTransposeThroughExpandShape` pattern, allowing transposes to sink through `transpose->expand_shape->pad` sequences.' The pattern is enabled behind a flag to avoid other models runtime regression. The flag `iree-global-opt-enable-sink-transpose-through-pad` is added to pipeline options. 
--------- Signed-off-by: yzhang93 --- .../compiler/GlobalOptimization/Passes.cpp | 2 + .../iree/compiler/GlobalOptimization/Passes.h | 6 ++ .../compiler/GlobalOptimization/Passes.td | 2 + .../PropagateLinalgTranspose.cpp | 49 ++++++++++ .../test/propagate_linalg_transpose.mlir | 57 +++++++++++ .../src/iree/compiler/Pipelines/Options.cpp | 5 + .../src/iree/compiler/Pipelines/Options.h | 3 + .../src/iree/compiler/Pipelines/Pipelines.cpp | 2 + .../compiler/Preprocessing/Common/BUILD.bazel | 1 - .../Preprocessing/Common/CMakeLists.txt | 1 - .../compiler/Preprocessing/Common/Passes.td | 9 -- .../Common/SinkTransposeThroughPad.cpp | 98 ------------------- .../Preprocessing/Common/test/BUILD.bazel | 1 - .../Preprocessing/Common/test/CMakeLists.txt | 1 - .../test/sink_transpose_through_pad.mlir | 17 ---- 15 files changed, 126 insertions(+), 128 deletions(-) delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp delete mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index b756902b9b10..4f38d12c88a9 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -180,6 +180,8 @@ void buildGlobalOptimizationPassPipeline( transformOptions.aggressiveTransposePropagation; options.enableConvolutionPropagation = transformOptions.propagateTransposesThroughConv; + options.enableSinkTransposeThroughPad = + transformOptions.sinkTransposeThroughPad; options.enableAttentionVTranspose = clEnableAttentionVTranspose; options.enableEdgeReshapePropagation = diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index cf6a89135990..317e2615abc2 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ 
b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -71,6 +71,12 @@ struct TransformOptions : public PassPipelineOptions { "Enables propagation of transpose ops through convolutions"), llvm::cl::init(false), }; + Option sinkTransposeThroughPad{ + *this, + "sink-transpose-through-pad", + llvm::cl::desc("Enables sinking transpose through pad operations"), + llvm::cl::init(false), + }; Option outerDimConcat{ *this, "outer-dim-concat", diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index 4d89a5ed2273..1bf312fee2c7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -121,6 +121,8 @@ def PropagateLinalgTransposePass : /*default=*/"false", "enable propagation through convolutions">, Option<"enableEdgeReshapePropagation", "enable-edge-reshape-propagation", "bool", /*default=*/"false", "Enable propagation of reshapes on the edges of the program">, + Option<"enableSinkTransposeThroughPad", "enable-sink-transpose-through-pad", "bool", + /*default=*/"false", "Enable sinking transpose through pad operations">, ]; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index 2e73a25774b4..a1223d8bb83b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -589,6 +589,52 @@ class SinkTransposeThroughExpandShape bool enableEdgeReshapePropagation = true; }; +// Sinks a transpose through a tensor.pad. 
+class SinkTransposeThroughPad : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const override { + if (!IREE::Flow::isNonNullAndOutsideDispatch(padOp)) { + return failure(); + } + Value source = padOp.getSource(); + auto transposeOp = source.getDefiningOp(); + if (!transposeOp) { + return failure(); + } + + Block &block = padOp.getRegion().front(); + if (llvm::any_of(block.getArguments(), [](BlockArgument blockArg) { + return blockArg.getNumUses(); + })) { + return failure(); + } + + auto invPerm = invertPermutationVector(transposeOp.getPermutation()); + SmallVector lowSizes = padOp.getMixedLowPad(); + SmallVector highSizes = padOp.getMixedHighPad(); + applyPermutationToVector(lowSizes, invPerm); + applyPermutationToVector(highSizes, invPerm); + + RankedTensorType oldPaddedType = cast(padOp.getType()); + RankedTensorType newPaddedType = oldPaddedType.clone( + applyPermutation(oldPaddedType.getShape(), invPerm)); + + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), newPaddedType, transposeOp.getInput(), + lowSizes, highSizes, padOp.getNofold()); + rewriter.cloneRegionBefore(padOp.getRegion(), newPadOp.getRegion(), + newPadOp.getRegion().begin()); + + Value newTransposeOp = + createTranspose(rewriter, newPadOp, transposeOp.getPermutation()); + rewriter.replaceOp(padOp, newTransposeOp); + return success(); + } +}; + // Fuses a transpose with the input of a linalg.generic op or contraction op. // Contraction ops are generalized and then treated as a generic. 
For example, // @@ -1292,6 +1338,9 @@ void PropagateLinalgTransposePass::runOnOperation() { sinkingPatterns.insert(context); sinkingPatterns.insert( context, enableEdgeReshapePropagation); + if (enableSinkTransposeThroughPad) { + sinkingPatterns.insert(context); + } sinkingPatterns.insert( context, enableAggressivePropagation, enableConvolutionPropagation); sinkingPatterns.insert(context); diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir index d9be97a6ba8e..49def08acc02 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir @@ -4,6 +4,7 @@ // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-bubbling-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=BUBBLE // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-aggressive-propagation-through-conv=true}))" --split-input-file %s | FileCheck %s --check-prefix=CONV // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-edge-reshape-propagation=true}))" %s -o - --split-input-file | FileCheck %s --check-prefix=ENABLE-EDGE-PROP +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-sink-transpose-through-pad=true}))" --split-input-file %s | FileCheck %s --check-prefix=SINK-PAD util.func public @specialize_transpose_op(%arg0 : tensor<1x2x3xf32>, %empty : tensor<3x2x1xf32>) -> tensor<3x2x1xf32> { @@ -1040,6 +1041,7 @@ util.func public @bubble_transpose_through_truncf_and_fuse_with_conv( // BUBBLE: linalg.generic // BUBBLE: } -> tensor<16x2x2x4xbf16> // BUBBLE-NOT: linalg.transpose +// BUBBLE: util.return // With 
enable-aggressive-propagation-through-conv, transpose is fully fused with conv. // CONV-LABEL: util.func public @bubble_transpose_through_truncf_and_fuse_with_conv @@ -1050,3 +1052,58 @@ util.func public @bubble_transpose_through_truncf_and_fuse_with_conv( // CONV: } -> tensor<16x2x2x4xbf16> // CONV-NOT: linalg.transpose // CONV: util.return %[[TRUNCF]] + +// ----- + +util.func public @sink_transpose_through_pad(%arg0: tensor<16x64x64x128xf16>) -> tensor<16x128x66x66xf16> { + %cst = arith.constant 0.000000e+00 : f16 + %empty = tensor.empty() : tensor<16x128x64x64xf16> + %transposed = linalg.transpose ins(%arg0 : tensor<16x64x64x128xf16>) outs(%empty : tensor<16x128x64x64xf16>) permutation = [0, 3, 1, 2] + %padded = tensor.pad %transposed low[0, 0, 1, 1] high[0, 0, 1, 1] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f16 + } : tensor<16x128x64x64xf16> to tensor<16x128x66x66xf16> + util.return %padded : tensor<16x128x66x66xf16> +} +// With enable-sink-transpose-through-pad=true, transpose sinks through pad. +// SINK-PAD-LABEL: util.func public @sink_transpose_through_pad +// SINK-PAD: %[[PAD:.+]] = tensor.pad +// SINK-PAD: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-PAD-SAME: ins(%[[PAD]] +// SINK-PAD: util.return %[[TRANSPOSE]] + +// Without the flag, transpose does not sink through pad. 
+// SINK-LABEL: util.func public @sink_transpose_through_pad +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK: %[[PAD:.+]] = tensor.pad %[[TRANSPOSE]] +// SINK: util.return %[[PAD]] + +// ----- + +util.func public @sink_transpose_through_expand_shape_and_pad(%arg0: tensor<16x2x48x32x288xbf16>) -> tensor<16x3x96x4x48x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %empty = tensor.empty() : tensor<16x288x2x48x32xbf16> + %transposed = linalg.transpose ins(%arg0 : tensor<16x2x48x32x288xbf16>) outs(%empty : tensor<16x288x2x48x32xbf16>) permutation = [0, 4, 1, 2, 3] + %expanded = tensor.expand_shape %transposed [[0], [1, 2], [3], [4], [5]] output_shape [16, 3, 96, 2, 48, 32] : tensor<16x288x2x48x32xbf16> into tensor<16x3x96x2x48x32xbf16> + %padded = tensor.pad %expanded low[0, 0, 0, 1, 0, 0] high[0, 0, 0, 1, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : bf16 + } : tensor<16x3x96x2x48x32xbf16> to tensor<16x3x96x4x48x32xbf16> + util.return %padded : tensor<16x3x96x4x48x32xbf16> +} +// With enable-sink-transpose-through-pad=true, transpose sinks through both +// expand_shape and pad. +// SINK-PAD-LABEL: util.func public @sink_transpose_through_expand_shape_and_pad +// SINK-PAD: %[[EXPAND:.+]] = tensor.expand_shape %arg0 +// SINK-PAD: %[[PAD:.+]] = tensor.pad %[[EXPAND]] +// SINK-PAD: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-PAD-SAME: ins(%[[PAD]] +// SINK-PAD: util.return %[[TRANSPOSE]] + +// Without the flag, transpose sinks through expand_shape but not pad. 
+// SINK-LABEL: util.func public @sink_transpose_through_expand_shape_and_pad +// SINK: %[[EXPAND:.+]] = tensor.expand_shape %arg0 +// SINK: %[[TRANSPOSE:.+]] = linalg.transpose +// SINK-SAME: ins(%[[EXPAND]] +// SINK: %[[PAD:.+]] = tensor.pad %[[TRANSPOSE]] +// SINK: util.return %[[PAD]] diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 89d72888b6f3..c44d7affc5e0 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -183,6 +183,11 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { llvm::cl::desc( "Enables propagation of transpose ops through convolutions."), llvm::cl::cat(category)); + binder.opt( + "iree-global-opt-enable-sink-transpose-through-pad", + sinkTransposeThroughPad, + llvm::cl::desc("Enables sinking transpose through pad operations."), + llvm::cl::cat(category)); binder.opt("iree-opt-outer-dim-concat", outerDimConcat, {init_at_opt(llvm::OptimizationLevel::O0, false), init_at_opt(llvm::OptimizationLevel::O1, true)}, diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index eee593f24feb..6a7a9803c349 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -135,6 +135,9 @@ struct GlobalOptimizationOptions { // Enables propagation of transpose ops through convolutions. bool propagateTransposesThroughConv = false; + // Enables sinking transpose through pad operations. + bool sinkTransposeThroughPad = false; + // Enables transposing all concatenations to the outer most dimension. 
bool outerDimConcat = false; diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 36a6dadd8eab..c5b471f0711a 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -191,6 +191,8 @@ void buildIREEPrecompileTransformPassPipeline( globalOptimizationOptions.aggressiveTransposePropagation; globalTransformOptions.propagateTransposesThroughConv = globalOptimizationOptions.propagateTransposesThroughConv; + globalTransformOptions.sinkTransposeThroughPad = + globalOptimizationOptions.sinkTransposeThroughPad; globalTransformOptions.outerDimConcat = globalOptimizationOptions.outerDimConcat; // The pipeline option has higher priority. diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index 99c945c87c23..4662d4b2b7f9 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -47,7 +47,6 @@ iree_compiler_cc_library( "PadLinalgOps.cpp", "PadToIntrinsics.cpp", "Passes.cpp", - "SinkTransposeThroughPad.cpp", "TransposeMatmul.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index 1d8165618518..3471f1db47a1 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -38,7 +38,6 @@ iree_cc_library( "PadLinalgOps.cpp" "PadToIntrinsics.cpp" "Passes.cpp" - "SinkTransposeThroughPad.cpp" "TransposeMatmul.cpp" DEPS ::PassesIncGen diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index 68be70b0d00d..defec3e3d13a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ 
b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -172,13 +172,4 @@ def GeneralizeLinalgMatMulPass : ]; } -def SinkTransposeThroughPadPass : - InterfacePass<"iree-preprocessing-sink-transpose-through-pad", "mlir::FunctionOpInterface"> { - let summary = "Sink linalg transpose ops through tensor pad ops"; - let dependentDialects = [ - "mlir::linalg::LinalgDialect", - "mlir::tensor::TensorDialect", - ]; -} - #endif // IREE_PREPROCESSING_COMMON_PASSES diff --git a/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp b/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp deleted file mode 100644 index 4d1937a66f8f..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/SinkTransposeThroughPad.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2025 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" -#include "iree/compiler/Preprocessing/Common/Passes.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::iree_compiler::Preprocessing { - -#define GEN_PASS_DEF_SINKTRANSPOSETHROUGHPADPASS -#include "iree/compiler/Preprocessing/Common/Passes.h.inc" - -static Value createTransposeInit(OpBuilder &builder, Value source, - ArrayRef perm) { - SmallVector mixedSizes = - tensor::getMixedSizes(builder, source.getLoc(), source); - applyPermutationToVector(mixedSizes, perm); - Type elemType = cast(source.getType()).getElementType(); - Value empty = - tensor::EmptyOp::create(builder, source.getLoc(), mixedSizes, elemType) - .getResult(); - return empty; -} - -static Value createTranspose(OpBuilder &builder, Value source, - ArrayRef perm) { - 
Value empty = createTransposeInit(builder, source, perm); - return linalg::TransposeOp::create(builder, source.getLoc(), source, empty, - perm) - ->getResult(0); -} - -// Sinks a transpose through a tensor.pad -class SinkTransposeThroughPadOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const override { - if (!IREE::Flow::isNonNullAndOutsideDispatch(padOp)) { - return failure(); - } - Value source = padOp.getSource(); - auto transposeOp = source.getDefiningOp(); - if (!transposeOp) { - return failure(); - } - - Block &block = padOp.getRegion().front(); - if (llvm::any_of(block.getArguments(), [](BlockArgument blockArg) { - return blockArg.getNumUses(); - })) { - return failure(); - } - - auto invPerm = invertPermutationVector(transposeOp.getPermutation()); - SmallVector lowSizes = padOp.getMixedLowPad(); - SmallVector highSizes = padOp.getMixedHighPad(); - applyPermutationToVector(lowSizes, invPerm); - applyPermutationToVector(highSizes, invPerm); - - RankedTensorType oldPaddedType = cast(padOp.getType()); - RankedTensorType newPaddedType = oldPaddedType.clone( - applyPermutation(oldPaddedType.getShape(), invPerm)); - auto newPadOp = tensor::PadOp::create( - rewriter, padOp.getLoc(), newPaddedType, transposeOp.getInput(), - lowSizes, highSizes, padOp.getNofold()); - rewriter.cloneRegionBefore(padOp.getRegion(), newPadOp.getRegion(), - newPadOp.getRegion().begin()); - Value newTransposeOp = - createTranspose(rewriter, newPadOp, transposeOp.getPermutation()); - rewriter.replaceOp(padOp, newTransposeOp); - return success(); - } -}; - -namespace { -struct SinkTransposeThroughPadPass - : public impl::SinkTransposeThroughPadPassBase< - SinkTransposeThroughPadPass> { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { 
- getOperation().emitError(getPassName()) << " failed to converge."; - return signalPassFailure(); - } - } -}; -} // namespace - -} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index d5ac29ee01cc..68e8d19d31c1 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -29,7 +29,6 @@ iree_lit_test_suite( "pdl_example.mlir", "preprocessing_match_ops.mlir", "transform_symbol_importing.mlir", - "sink_transpose_through_pad.mlir", "transpose_matmul.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index 598ddbcb7692..f0e2bad5da0a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -26,7 +26,6 @@ iree_lit_test_suite( "pad_to_intrinsics_wmma.mlir" "pdl_example.mlir" "preprocessing_match_ops.mlir" - "sink_transpose_through_pad.mlir" "transform_symbol_importing.mlir" "transpose_matmul.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir deleted file mode 100644 index 223100802681..000000000000 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/sink_transpose_through_pad.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-sink-transpose-through-pad))" --split-input-file %s | FileCheck %s - -util.func public @sink_pad_through_transpose(%arg0 : tensor<16x64x64x128xf16>) -> (tensor<16x128x66x66xf16>) { - %2 = tensor.empty() : tensor<16x128x64x64xf16> - %cst = arith.constant 0.000000e+00 : f16 - 
%transposed = linalg.transpose ins(%arg0 : tensor<16x64x64x128xf16>) outs(%2 : tensor<16x128x64x64xf16>) permutation = [0, 3, 1, 2] - %padded = tensor.pad %transposed low[0, 0, 1, 1] high[0, 0, 1, 1] { - ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index): - tensor.yield %cst : f16 - } : tensor<16x128x64x64xf16> to tensor<16x128x66x66xf16> - util.return %padded : tensor<16x128x66x66xf16> -} -// CHECK-LABEL: util.func public @sink_pad_through_transpose -// CHECK: %[[PAD:.+]] = tensor.pad -// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose -// CHECK-SAME: ins(%[[PAD]] -// CHECK: util.return %[[TRANSPOSE]] From cf770a7f1bd7d6cd461254de6d2a651f4f24602c Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Wed, 14 Jan 2026 07:49:27 +0100 Subject: [PATCH 30/71] [Encoding] Add dynamic encoding dims to (Un)SetEncodingOp (#22907) --- .../Common/MaterializeEncodingPatterns.cpp | 7 ++-- .../Dialect/Encoding/IR/EncodingOps.td | 22 ++++++++++--- .../Dialect/Encoding/IR/test/roundtrip.mlir | 33 +++++++++++++++++++ .../Transforms/MaterializeEncodings.cpp | 7 ++-- .../DispatchCreation/HoistEncodingOps.cpp | 2 +- .../compiler/DispatchCreation/SetEncoding.cpp | 31 ++++++++++------- .../EncodingExternalModels.cpp | 5 +-- 7 files changed, 81 insertions(+), 26 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp index 5b50c41becbd..191c78306dac 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp @@ -299,11 +299,12 @@ static Value generateEncodingTransferOps(RewriterBase &rewriter, Value src, Value value = src; if (srcType.getEncoding()) { value = IREE::Encoding::UnsetEncodingOp::create( - rewriter, src.getLoc(), srcType.dropEncoding(), value, dynamicDims); + rewriter, src.getLoc(), srcType.dropEncoding(), value, dynamicDims, + 
/*encodingDims=*/ValueRange{}); } if (destType.getEncoding()) { - value = IREE::Encoding::SetEncodingOp::create(rewriter, src.getLoc(), - destType, value); + value = IREE::Encoding::SetEncodingOp::create( + rewriter, src.getLoc(), destType, value, /*encodingDims=*/ValueRange{}); } return value; } diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td index 29e7de2b410e..af6649e91bfd 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.td @@ -23,13 +23,19 @@ def IREEEncoding_SetEncodingOp : IREEEncoding_PureOp<"set_encoding",[ Operation to assign an encoding to a tensor. The operation does not change the rank or extent of a tensor. Instead it adds a LayoutResolverAttr attribute to the tensor type to represent a change in layout. + + The optional `encoding_dims` operand carries dynamic values needed by the + encoding (e.g., M, N, K dimensions for matmul encodings). These values are + used for runtime layout selection based on problem size. }]; - let arguments = (ins AnyRankedTensor:$source); + let arguments = (ins + AnyRankedTensor:$source, + Variadic:$encoding_dims); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) + attr-dict $source (`encoding_dims` `{` $encoding_dims^ `}`)? 
`:` type($source) `->` type($result) }]; let hasVerifier = 1; @@ -49,21 +55,27 @@ def IREEEncoding_SetEncodingOp : IREEEncoding_PureOp<"set_encoding",[ //===----------------------------------------------------------------------===// def IREEEncoding_UnsetEncodingOp : IREEEncoding_PureOp<"unset_encoding", [ - DeclareOpInterfaceMethods, Pure + DeclareOpInterfaceMethods, + AttrSizedOperandSegments, Pure ]> { let summary = [{Perform unpack and extract operation on source.}]; let description = [{ Operation to convert a tensor with LayoutResolverAttr encoding that represents its data layout into a tensor with default layout (i.e. no encoding). For now in IREE the default layout is row-major. + + The optional `encoding_dims` operand carries dynamic values needed by the + encoding (e.g., M, N, K dimensions for matmul encodings). These values are + used for runtime layout selection based on problem size. }]; let arguments = (ins AnyRankedTensor:$source, - Variadic:$result_dims); + Variadic:$result_dims, + Variadic:$encoding_dims); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ - attr-dict $source `:` type($source) `->` type($result) (`` `{` $result_dims^ `}`)? + attr-dict $source (`encoding_dims` `{` $encoding_dims^ `}`)? `:` type($source) `->` type($result) (`` `{` $result_dims^ `}`)? 
}]; let hasVerifier = 1; diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir index 7cfdf9dcda2d..cc623587b571 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir @@ -259,3 +259,36 @@ func.func @identity_encoding(%arg0: tensor) -> tensor } // CHECK: func.func @identity_encoding(%[[ARG0:.+]]: tensor + +// ----- + +#encoding = #iree_encoding.testing<> +func.func @set_encoding_with_encoding_dims(%arg0: tensor, %m: index, %n: index, %k: index) -> tensor { + %0 = iree_encoding.set_encoding %arg0 encoding_dims{%m, %n, %k} : tensor -> tensor + return %0 : tensor +} +// CHECK: #[[ENCODING:.+]] = #iree_encoding.testing<> +// CHECK: func.func @set_encoding_with_encoding_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[M:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[K:[a-zA-Z0-9]+]]: index +// CHECK: iree_encoding.set_encoding %[[ARG0]] encoding_dims{%[[M]], %[[N]], %[[K]]} : tensor -> tensor + +// ----- + +#encoding = #iree_encoding.testing<> +func.func @unset_encoding_with_encoding_dims( + %arg0: tensor, %d0: index, %d1: index, %m: index, %n: index, %k: index) -> tensor { + %0 = iree_encoding.unset_encoding %arg0 encoding_dims{%m, %n, %k} : tensor -> tensor{%d0, %d1} + return %0 : tensor +} +// CHECK: #[[ENCODING:.+]] = #iree_encoding.testing<> +// CHECK: func.func @unset_encoding_with_encoding_dims +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[D0:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[D1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[M:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[N:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[K:[a-zA-Z0-9]+]]: index +// CHECK: iree_encoding.unset_encoding %[[ARG0]] encoding_dims{%[[M]], %[[N]], %[[K]]} : tensor -> tensor{%[[D0]], %[[D1]]} diff --git 
a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeEncodings.cpp index da65eac45ece..32170aa51a0a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeEncodings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeEncodings.cpp @@ -142,11 +142,12 @@ static func::FuncOp createWorkgroupFunc(IREE::Stream::TensorEncodeOp encodeOp, if (sourceType != destinationType) { if (sourceType.getEncoding()) { value = IREE::Encoding::UnsetEncodingOp::create( - builder, loc, sourceType.dropEncoding(), value, sourceDynamicDims); + builder, loc, sourceType.dropEncoding(), value, sourceDynamicDims, + /*encodingDims=*/ValueRange{}); } if (destinationType.getEncoding()) { - value = IREE::Encoding::SetEncodingOp::create(builder, loc, - destinationType, value); + value = IREE::Encoding::SetEncodingOp::create( + builder, loc, destinationType, value, /*encodingDims=*/ValueRange{}); } } diff --git a/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp index d7c2b6521f64..4561cc366c7d 100644 --- a/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/HoistEncodingOps.cpp @@ -110,7 +110,7 @@ bubbleUpSetEncodingThroughGenericOp(RewriterBase &rewriter, auto resType = RankedTensorType::get( operandType.getShape(), operandType.getElementType(), newEncoding); Value encodedInput = IREE::Encoding::SetEncodingOp::create( - rewriter, loc, resType, operand->get()); + rewriter, loc, resType, operand->get(), /*encodingDims=*/ValueRange{}); encodedOperands.push_back(encodedInput); } diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp index af7fe39078c6..4961db09ed1f 100644 --- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp +++ 
b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp @@ -36,15 +36,16 @@ using IREE::Encoding::EncodingAttr; //===---------------------------------------------------------------------===// static Value setEncoding(OpBuilder &builder, Location loc, Value source, - Attribute encodingAttr) { + Attribute encodingAttr, ValueRange encodingDims = {}) { auto resultType = cast(source.getType()).cloneWithEncoding(encodingAttr); - return IREE::Encoding::SetEncodingOp::create(builder, loc, resultType, - source); -}; + return IREE::Encoding::SetEncodingOp::create(builder, loc, resultType, source, + encodingDims); +} static Value unsetEncoding(OpBuilder &builder, Location loc, Value source, - SmallVector sizes) { + SmallVector sizes, + ValueRange encodingDims = {}) { SmallVector dynamicSizesVec; SmallVector staticSizesVec; dispatchIndexOpFoldResults(sizes, dynamicSizesVec, staticSizesVec); @@ -53,7 +54,8 @@ static Value unsetEncoding(OpBuilder &builder, Location loc, Value source, auto unsetEncodingReturnType = RankedTensorType::get(sourceType.getShape(), sourceType.getElementType()); return IREE::Encoding::UnsetEncodingOp::create( - builder, loc, unsetEncodingReturnType, source, dynamicSizesVec); + builder, loc, unsetEncodingReturnType, source, dynamicSizesVec, + encodingDims); } static SmallVector @@ -91,15 +93,18 @@ static LogicalResult setDataTilingEncodings(RewriterBase &rewriter, SmallVector encodedInputOperands; for (auto [idx, props] : llvm::enumerate(encProps.operands)) { Value src = linalgOp.getDpsInputs()[idx]; - Value encoded = setEncoding(rewriter, loc, src, props.encoding); + Value encoded = + setEncoding(rewriter, loc, src, props.encoding, props.dynamicValues); encodedInputOperands.push_back(encoded); } // Set encoding on init operand. // For now, we assume single init. 
assert(encProps.inits.size() == 1 && "Expected single init encoding"); - Value encodedInitOperand = setEncoding( - rewriter, loc, linalgOp.getDpsInits()[0], encProps.inits[0].encoding); + IREE::Encoding::EncodingProperties &initProps = encProps.inits[0]; + Value encodedInitOperand = + setEncoding(rewriter, loc, linalgOp.getDpsInits()[0], initProps.encoding, + initProps.dynamicValues); SmallVector encodedOperands(encodedInputOperands); encodedOperands.push_back(encodedInitOperand); @@ -110,7 +115,8 @@ static LogicalResult setDataTilingEncodings(RewriterBase &rewriter, // Sizes are computed by original output size. SmallVector outSizes = tensor::getMixedSizes(rewriter, loc, linalgOp.getDpsInits()[0]); - Value result = unsetEncoding(rewriter, loc, opTiled, outSizes); + Value result = + unsetEncoding(rewriter, loc, opTiled, outSizes, initProps.dynamicValues); rewriter.replaceOp(linalgOp, result); return success(); @@ -242,7 +248,8 @@ static std::optional padProducerOfValue(RewriterBase &rewriter, // Find the new value to yield. 
Value newYieldedVal = map.lookup(operand); auto encodingOp = IREE::Encoding::SetEncodingOp::create( - rewriter, returnOp->getLoc(), newResultType, newYieldedVal); + rewriter, returnOp->getLoc(), newResultType, newYieldedVal, + /*encodingDims=*/ValueRange{}); rewriter.modifyOpInPlace( returnOp, [&]() { returnOp.setOperand(resultNumber, encodingOp); }); @@ -276,7 +283,7 @@ static SmallVector padOperandsOfOp(RewriterBase &rewriter, Type operandType = operand.get().getType(); auto unsetEncodignOp = IREE::Encoding::UnsetEncodingOp::create( rewriter, op->getLoc(), operandType, paddedVal->paddedValue, - paddedVal->dynamicDims); + paddedVal->dynamicDims, /*encodingDims=*/ValueRange{}); op->setOperand(operandNum, unsetEncodignOp.getResult()); }); } diff --git a/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp index 999ff6b543ba..4b5f7998e384 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/EncodingExternalModels.cpp @@ -71,7 +71,8 @@ static IREE::Encoding::PropagationResult propagateThroughEncodingCastableOp( } // Otherwise, we need to create a new set_encoding op. 
auto setEncodingOp = IREE::Encoding::SetEncodingOp::create( - builder, op->getLoc(), encodedOperandType, operand); + builder, op->getLoc(), encodedOperandType, operand, + /*encodingDims=*/ValueRange{}); encodedOperands.push_back(setEncodingOp.getResult()); result.generatedEncodingOps.push_back(setEncodingOp); } @@ -100,7 +101,7 @@ static IREE::Encoding::PropagationResult propagateThroughEncodingCastableOp( std::tie(std::ignore, resultDynamicDims) = decomposeMixedValues(mixedSizes); auto unsetEncodingOp = IREE::Encoding::UnsetEncodingOp::create( builder, op->getLoc(), originalResult.getType(), encodedResult, - resultDynamicDims); + resultDynamicDims, /*encodingDims=*/ValueRange{}); result.generatedEncodingOps.push_back(unsetEncodingOp); result.replacements.push_back(unsetEncodingOp.getResult()); } From 7be5671d8c334e64f87b0c447756bda01a635884 Mon Sep 17 00:00:00 2001 From: Eric Feng <55723758+efric@users.noreply.github.com> Date: Tue, 13 Jan 2026 23:08:30 -0800 Subject: [PATCH 31/71] Reapply "[GPU][Codegen] Expand iteration space based on new `expand_dims` attribute" (#23076) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This relands expand dims. The issue from DispatchCreation (#22978) which expand dims unearthed was resolved upstream in [[mlir][Linalg] Drop unit extent dim in non-trivial expressions](https://github.com/llvm/llvm-project/pull/173873). In addition to the results in the original description, the impact of this change on llama 8b fp16 decode was evaluated. For the two most dominant dispatches where this change is relevant (in combination with chain FMA), register usage is reduced by up to 26%, with no meaningful change in either direction in run time. Original description: --- This patch introduces iteration space expansion for reductions in the VectorDistribute path. Specifically, we: 1. Add a new attribute, `expand_dims`, for reductions. 2. 
Introduce a new pass, `GPUExpandDimensions`, which uses `expand_dims` to expand the iteration space of relevant dimensions. 3. Refactor common functionality shared between `GPUExpandDimensions` and `BlockDynamicDimensions` into reusable utilities. 4. Refactor encoding helpers from `EncodingAttrs.cpp` into reusable utilities. This change also enables [chain FMA](https://github.com/iree-org/iree/pull/21855) in matvec codegen as we iterate along the K reduction dimension. --- **Performance Summary** **IREE benchmark module** * Only expansion: ~4% improvement * Expansion + chain FMA: ~11% improvement **rocprof** * Only expansion: ~13% worse * Expansion + chain FMA: ~9% better **Register usage** * 10% reduction (60 → 54 registers for matvec dispatches) **Instruction latency (post-reduction loop epilogue)** * 3.5% improvement (340 → 328 total mean latency) --- **Notes** * As a follow-up, we can explore applying iteration space expansion to the reduction in attention * Right now, we only expand one dimension into two although the implementation supports expansion to N dimensions. * Please note this PR changes the reduction order, so expect some minor changes to the numerics * This does not improve performance by itself/can cause regression without chain FMA https://github.com/iree-org/iree/pull/21855 Traces for matvec dispatches are attached for all variations (original, only expansion, and expansion + chain FMA). 
[115_expansion_and_chain.tar.gz](https://github.com/user-attachments/files/23268046/115_expansion_and_chain.tar.gz) [115_nothing.tar.gz](https://github.com/user-attachments/files/23268047/115_nothing.tar.gz) [115_only_expansion.tar.gz](https://github.com/user-attachments/files/23268048/115_only_expansion.tar.gz) Fixes: #22153 ci-extra: test_torch Signed-off-by: Eric Feng --- .../Codegen/Common/BlockDynamicDimensions.cpp | 17 +- .../compiler/Codegen/Common/GPU/BUILD.bazel | 2 + .../Codegen/Common/GPU/CMakeLists.txt | 2 + .../Common/GPU/GPUExpandDimensions.cpp | 290 ++++++++++++++++++ .../compiler/Codegen/Common/GPU/Passes.td | 8 + .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/CMakeLists.txt | 1 + .../GPU/test/gpu_expand_dimensions.mlir | 93 ++++++ .../compiler/Codegen/Common/Transforms.cpp | 20 ++ .../iree/compiler/Codegen/Common/Transforms.h | 14 + .../Dialect/GPU/IR/GPULoweringConfigUtils.cpp | 7 + .../Dialect/GPU/IR/GPULoweringConfigUtils.h | 3 + .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 84 +++++ .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 56 ++++ .../GPU/TargetUtils/ReductionConfigUtils.cpp | 63 +++- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 2 + ...ig_vector_distribute_reduction_gfx942.mlir | 57 ++-- .../Codegen/LLVMGPU/test/config_matvec.mlir | 77 ++--- .../LLVMGPU/test/reduction_pipeline_cuda.mlir | 77 +++-- .../LLVMGPU/test/reduction_pipeline_rocm.mlir | 44 +-- .../compiler/Dialect/Encoding/IR/BUILD.bazel | 1 + .../Dialect/Encoding/IR/CMakeLists.txt | 1 + .../Dialect/Encoding/IR/EncodingAttrs.cpp | 55 +--- compiler/src/iree/compiler/Utils/BUILD.bazel | 2 + .../src/iree/compiler/Utils/CMakeLists.txt | 2 + .../src/iree/compiler/Utils/EncodingUtils.cpp | 86 ++++++ .../src/iree/compiler/Utils/EncodingUtils.h | 39 +++ 27 files changed, 901 insertions(+), 203 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp create mode 100644 
compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir create mode 100644 compiler/src/iree/compiler/Utils/EncodingUtils.cpp create mode 100644 compiler/src/iree/compiler/Utils/EncodingUtils.h diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp index b033e618392b..e8b66770b0bc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -36,17 +36,6 @@ using TensorDivisibilityInfo = namespace { -struct RemoveOptimizationBarrier final - : public OpRewritePattern { - using Base::Base; - - LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, - PatternRewriter &rewriter) const override { - rewriter.replaceOp(barrierOp, barrierOp.getOperands()); - return success(); - } -}; - /// This pass is used to materialize information about dynamic dimensions of /// `tensor` operands of an operation in the IR. If a dynamic dimension is /// known to be a multiple of a compile-time constant value, this pass @@ -110,10 +99,6 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis, /// inverses of each other. The `util.optimization.barrier` avoid these from /// getting folded away during reshape propagation. Return the result of the /// `tensor.collapse_shape generated. -struct ReshapeOps { - tensor::ExpandShapeOp expandShapeOp; - tensor::CollapseShapeOp collapseShapeOp; -}; static std::optional blockDynamicDimensionsOfValue(RewriterBase &rewriter, const TensorDivisibilityInfo &divisibilityInfo, @@ -413,7 +398,7 @@ void BlockDynamicDimensionsPass::runOnOperation() { // Delete the optimization barrier and run some further cleanup. 
{ RewritePatternSet removeBarrierOpsPatterns(context); - removeBarrierOpsPatterns.insert(context); + populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns); tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, context); tensor::CollapseShapeOp::getCanonicalizationPatterns( diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel index 7e905bf9f5b2..1ae7e44ee08c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel @@ -74,6 +74,7 @@ iree_compiler_cc_library( "GPUDistributeScfFor.cpp", "GPUDistributeSharedMemoryCopy.cpp", "GPUDistributionPatterns.cpp", + "GPUExpandDimensions.cpp", "GPUFuseAndHoistParallelLoops.cpp", "GPUGeneralizeNamedOps.cpp", "GPUGreedilyDistributeToThreads.cpp", @@ -125,6 +126,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", + "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt index 8fa41efd5439..c09cba442a02 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt @@ -67,6 +67,7 @@ iree_cc_library( "GPUDistributeScfFor.cpp" "GPUDistributeSharedMemoryCopy.cpp" "GPUDistributionPatterns.cpp" + "GPUExpandDimensions.cpp" "GPUFuseAndHoistParallelLoops.cpp" "GPUGeneralizeNamedOps.cpp" "GPUGreedilyDistributeToThreads.cpp" @@ -159,6 +160,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils iree::compiler::Dialect::TensorExt::IR + 
iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp new file mode 100644 index 000000000000..71cb20c85580 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUExpandDimensions.cpp @@ -0,0 +1,290 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-gpu-expand-dimensions" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_GPUEXPANDDIMENSIONSPASS +#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" + +namespace { + +struct GPUExpandDimensionsPass final + : impl::GPUExpandDimensionsPassBase { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +// Compute the expanded shape for a reassociation group. 
Requires the original +// dimension to be static and evenly divisible by the product of static factors +// in the target shape. +static FailureOr> computeExpandedGroupShape( + RewriterBase &rewriter, Location loc, OpFoldResult origDimSize, + ArrayRef groupTargetShape, unsigned iteratorDim) { + if (groupTargetShape.size() == 1) { + return SmallVector{origDimSize}; + } + + std::optional staticOrigDim = getConstantIntValue(origDimSize); + if (!staticOrigDim) { + return rewriter.notifyMatchFailure( + loc, "dimension " + Twine(iteratorDim) + + " is dynamic, but expand_dims requires static dimensions"); + } + + int64_t staticFactor = llvm::product_of( + llvm::make_filter_range(groupTargetShape, ShapedType::isStatic)); + + if (staticFactor < 1) { + return rewriter.notifyMatchFailure( + loc, "invalid expansion factor " + Twine(staticFactor) + + " for iterator dimension " + Twine(iteratorDim)); + } + + if (staticOrigDim.value() % staticFactor != 0) { + return rewriter.notifyMatchFailure( + loc, "dimension " + Twine(iteratorDim) + + " (size=" + Twine(staticOrigDim.value()) + + ") not divisible by expansion factor " + Twine(staticFactor)); + } + + return llvm::map_to_vector( + groupTargetShape, [&](int64_t size) -> OpFoldResult { + if (ShapedType::isStatic(size)) { + return rewriter.getIndexAttr(size); + } + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + return affine::makeComposedFoldedAffineApply( + rewriter, loc, s0.floorDiv(staticFactor), {origDimSize}); + }); +} + +// For an operation annotated with the `expand_dims` attribute, replace relevant +// operands with tensor.expand_shape/tensor.collapse_shape pair to materialize +// dimension expansion according to the reassociation and output_shape defined +// in the attribute. +// +// Example: +// +// ```mlir +// %0 = (..., %0, ...) { +// lowering_config = #iree_gpu.lowering_config<{ +// expand_dims = #iree_gpu.expand_dims +// [[0], [1, 2]], output_shape = [?, ?, 8]> +// }> +// } : ... 
-> tensor<4x128xf32> +// ``` +// +// becomes: +// +// ```mlir +// %expanded = tensor.expand_shape %0 [[0], [1, 2]] +// : tensor<4x128xf32> into tensor<4x16x8xf32> +// %barrier = util.optimization_barrier %expanded +// %collapsed = tensor.collapse_shape %barrier [[0], [1, 2]] +// : tensor<4x16x8xf32> into tensor<4x128xf32> +// %1 = (..., %collapsed, ...) : ... -> tensor<4x128xf32> +// ``` +static std::optional +createDimensionExpansionOps(RewriterBase &rewriter, + IREE::GPU::DimensionExpansionAttr config, Value v, + AffineMap indexingMap, linalg::LinalgOp op) { + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return std::nullopt; + } + + Location loc = v.getLoc(); + MLIRContext *ctx = op.getContext(); + int64_t tensorRank = tensorType.getRank(); + ArrayRef outputShape = config.getOutputShape().asArrayRef(); + SmallVector origShape = tensor::getMixedSizes(rewriter, loc, v); + + // Map each tensor dimension to its expanded shape components. + SmallVector> expandedShapes(tensorRank); + for (auto [iterDim, reassocIndices] : + llvm::enumerate(config.getReassociationIndices())) { + std::optional tensorDim = + indexingMap.getResultPosition(getAffineDimExpr(iterDim, ctx)); + if (!tensorDim.has_value()) { + continue; + } + + auto groupOutputShape = llvm::map_to_vector( + reassocIndices, [&](int64_t i) { return outputShape[i]; }); + + FailureOr> groupShape = computeExpandedGroupShape( + rewriter, loc, origShape[tensorDim.value()], groupOutputShape, iterDim); + if (failed(groupShape)) { + return std::nullopt; + } + + expandedShapes[tensorDim.value()] = std::move(groupShape.value()); + } + + // Build reassociation indices and expanded shape in tensor dimension order. 
+ SmallVector reassociation; + SmallVector expandedShape; + for (auto [tensorDim, expanded] : llvm::enumerate(expandedShapes)) { + ReassociationIndices &indices = reassociation.emplace_back(); + auto addDim = [&](OpFoldResult dim) { + indices.push_back(expandedShape.size()); + expandedShape.push_back(dim); + }; + if (expanded.empty()) { + addDim(origShape[tensorDim]); + } else { + llvm::for_each(expanded, addDim); + } + } + + // If no expansion is needed, return early. + if (llvm::equal(origShape, expandedShape)) { + return std::nullopt; + } + + auto staticShape = llvm::map_to_vector(expandedShape, [](OpFoldResult ofr) { + return getConstantIntValue(ofr).value(); + }); + + auto expandedType = RankedTensorType::get( + staticShape, tensorType.getElementType(), tensorType.getEncoding()); + + auto expandOp = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, v, + reassociation, expandedShape); + Value barrier = IREE::Util::OptimizationBarrierOp::create( + rewriter, loc, expandOp.getResult()) + .getResult(0); + auto collapseOp = tensor::CollapseShapeOp::create(rewriter, loc, tensorType, + barrier, reassociation); + + return ReshapeOps{expandOp, collapseOp}; +} + +static LogicalResult expandIterationSpace(RewriterBase &rewriter, + linalg::LinalgOp op) { + auto loweringConfig = getLoweringConfig(op); + if (!loweringConfig) { + return success(); + } + auto config = IREE::GPU::getDimensionExpansion(loweringConfig); + if (!config) { + return success(); + } + + LDBG() << "Expanding dimensions for op: " << *op; + + for (OpOperand &operand : op->getOpOperands()) { + AffineMap indexingMap = op.getMatchingIndexingMap(&operand); + std::optional reshapes = createDimensionExpansionOps( + rewriter, config, operand.get(), indexingMap, op); + if (reshapes.has_value()) { + rewriter.modifyOpInPlace( + op, [&]() { operand.set(reshapes.value().collapseShapeOp); }); + } + } + + return success(); +} + +void GPUExpandDimensionsPass::runOnOperation() { + Operation *operation = 
getOperation(); + MLIRContext *context = &getContext(); + IRRewriter rewriter(context); + + SmallVector worklist; + operation->walk([&](linalg::LinalgOp op) { + if (auto cfg = getLoweringConfig(op)) { + if (IREE::GPU::getDimensionExpansion(cfg)) { + worklist.push_back(op); + } + } + }); + + for (linalg::LinalgOp op : worklist) { + rewriter.setInsertionPoint(op); + if (failed(expandIterationSpace(rewriter, op))) { + return signalPassFailure(); + } + } + + LDBG() << "After expanding dimensions: " << *operation; + + ConfigTrackingListener listener; + GreedyRewriteConfig config; + config.setListener(&listener); + + { + RewritePatternSet bubbleExpandShapePatterns(context); + linalg::ControlFusionFn controlFn = [](OpOperand *opOperand) { + return !isa_and_nonnull( + opOperand->get().getDefiningOp()); + }; + linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, + controlFn); + IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); + tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); + tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns); + linalg::FillOp::getCanonicalizationPatterns( + bubbleExpandShapePatterns, bubbleExpandShapePatterns.getContext()); + memref::populateResolveRankedShapedTypeResultDimsPatterns( + bubbleExpandShapePatterns); + if (failed(applyPatternsGreedily( + operation, std::move(bubbleExpandShapePatterns), config))) { + operation->emitOpError( + "failed in application of bubble up expand shape patterns"); + return signalPassFailure(); + } + } + + LDBG() << "After reshape propagation: " << *operation; + + { + RewritePatternSet removeBarrierOpsPatterns(context); + populateRemoveOptimizationBarrierPatterns(removeBarrierOpsPatterns); + tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + removeBarrierOpsPatterns, context); + 
tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns); + linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + memref::populateResolveRankedShapedTypeResultDimsPatterns( + removeBarrierOpsPatterns); + if (failed(applyPatternsGreedily(operation, + std::move(removeBarrierOpsPatterns)))) { + operation->emitOpError("failed in cleanup patterns"); + return signalPassFailure(); + } + } + + return; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index 78352069d7a6..a86a1284a28a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -383,6 +383,14 @@ def GPUApplyPaddingLevelPass : ]; } +def GPUExpandDimensionsPass : + InterfacePass<"iree-codegen-gpu-expand-dimensions", "mlir::FunctionOpInterface"> { + let summary = "Pass to expand tensor op dims based on `expand_dims` lowering_config"; + let dependentDialects = [ + "::mlir::iree_compiler::IREE::Util::UtilDialect" + ]; +} + def GPUTensorTileToSerialLoopsPass : InterfacePass<"iree-codegen-gpu-tensor-tile-to-serial-loops", "mlir::FunctionOpInterface"> { let summary = "Pass to tile reduction dimensions for certain GPU ops"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index c36e063de06b..2673c42bbdf5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -36,6 +36,7 @@ iree_lit_test_suite( "gpu_distribute_forall.mlir", "gpu_distribute_scf_for.mlir", "gpu_distribute_shared_memory.mlir", + "gpu_expand_dimensions.mlir", "gpu_fuse_and_hoist_forall.mlir", "gpu_generalize_named_ops.mlir", "gpu_greedily_distribute_to_threads.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt 
b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 002a332570b4..4ce6c005b783 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -32,6 +32,7 @@ iree_lit_test_suite( "gpu_distribute_forall.mlir" "gpu_distribute_scf_for.mlir" "gpu_distribute_shared_memory.mlir" + "gpu_expand_dimensions.mlir" "gpu_fuse_and_hoist_forall.mlir" "gpu_generalize_named_ops.mlir" "gpu_greedily_distribute_to_threads.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir new file mode 100644 index 000000000000..d30de91e5fc3 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_expand_dimensions.mlir @@ -0,0 +1,93 @@ +// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-expand-dimensions))" | FileCheck %s + +func.func @expand_matvec(%a: tensor<4x16384xf16>, %b: tensor<1x16384xf16>) -> tensor<4x1xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x1xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x1xf32>) -> tensor<4x1xf32> + %result = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%a, %b : tensor<4x16384xf16>, tensor<1x16384xf16>) + outs(%fill : tensor<4x1xf32>) + attrs = { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1], [2, 3]], output_shape = [?, ?, ?, 8]>, + lane_basis = [[1, 1, 64, 1], [0, 1, 2, 3]], + partial_reduction = [0, 0, 64, 0], + subgroup_basis = [[1, 1, 1, 1], [0, 1, 2, 3]], + thread = [0, 0, 1, 8], + workgroup = [4, 1, 0, 0]}>} { + ^bb0(%in: f16, %in_0: f16, %out: f32): + %0 = arith.extf %in : f16 
to f32 + %1 = arith.extf %in_0 : f16 to f32 + %2 = arith.mulf %0, %1 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<4x1xf32> + return %result : tensor<4x1xf32> +} + +// CHECK-LABEL: func.func @expand_matvec +// CHECK: %[[A_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2]] output_shape [4, 2048, 8] : tensor<4x16384xf16> into tensor<4x2048x8xf16> +// CHECK: %[[B_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2]] output_shape [1, 2048, 8] : tensor<1x16384xf16> into tensor<1x2048x8xf16> +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"] +// CHECK-SAME: ins(%[[A_EXPAND]], %[[B_EXPAND]] : tensor<4x2048x8xf16>, tensor<1x2048x8xf16>) + +// ----- + +func.func @expand_multiple_dims(%a: tensor<4x16384xf16>, %b: tensor<4x16384xf16>) -> tensor<4x16384xf16> { + %empty = tensor.empty() : tensor<4x16384xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2, 3]], output_shape = [?, ?, 2, 4]> + }>} + ins(%a, %b : tensor<4x16384xf16>, tensor<4x16384xf16>) outs(%empty : tensor<4x16384xf16>) -> tensor<4x16384xf16> + return %result : tensor<4x16384xf16> +} + +// CHECK-LABEL: func.func @expand_multiple_dims +// CHECK: %[[A_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]] output_shape [4, 2048, 2, 4] : tensor<4x16384xf16> into tensor<4x2048x2x4xf16> +// CHECK: %[[B_EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[}}[0], [1, 2, 3]] output_shape [4, 2048, 2, 4] : tensor<4x16384xf16> into tensor<4x2048x2x4xf16> +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[A_EXPAND]], %[[B_EXPAND]] : tensor<4x2048x2x4xf16>, tensor<4x2048x2x4xf16>) + +// ----- + +// Verify that dynamic dimensions are gracefully handled (no expansion occurs). 
+func.func @no_expand_dynamic_dims(%a: tensor<4x?xf16>, %b: tensor<4x?xf16>) -> tensor<4x128xf16> { + %empty = tensor.empty() : tensor<4x128xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2]], output_shape = [?, ?, 8]> + }>} + ins(%a, %b : tensor<4x?xf16>, tensor<4x?xf16>) outs(%empty : tensor<4x128xf16>) -> tensor<4x128xf16> + return %result : tensor<4x128xf16> +} + +// CHECK-LABEL: func.func @no_expand_dynamic_dim +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.add +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x?xf16>, tensor<4x?xf16>) + +// ----- + +// Verify that non-divisible dimensions are gracefully handled (no expansion occurs). +func.func @no_expand_not_divisible(%a: tensor<4x127xf16>, %b: tensor<4x127xf16>) -> tensor<4x127xf16> { + %empty = tensor.empty() : tensor<4x127xf16> + %result = linalg.add { + lowering_config = #iree_gpu.lowering_config<{ + expand_dims = #iree_gpu.expand_dims<[[0], [1, 2]], output_shape = [?, ?, 8]> + }>} + ins(%a, %b : tensor<4x127xf16>, tensor<4x127xf16>) outs(%empty : tensor<4x127xf16>) -> tensor<4x127xf16> + return %result : tensor<4x127xf16> +} + +// CHECK-LABEL: func.func @no_expand_not_divisible +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.add +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x127xf16>, tensor<4x127xf16>) diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp index 92ae83c6c9cb..fd4b75a4e73c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallVectorExtras.h" 
#include "mlir/Analysis/SliceAnalysis.h" @@ -817,4 +818,23 @@ void populateSwapExtractWithCollapsePattern(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +namespace { + +struct RemoveOptimizationBarrier final + : public OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(barrierOp, barrierOp.getOperands()); + return success(); + } +}; + +} // namespace + +void populateRemoveOptimizationBarrierPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h index b6c0067857b5..413cfb17a584 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h @@ -211,6 +211,20 @@ void populateCombineRelayoutOpPatterns( /// Populate patterns to fuse tilable consumers of forall ops into it. void populateFuseTilableForallConsumersPattern(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// Utilities for iteration space expansion transformations +//===----------------------------------------------------------------------===// + +/// Helper struct to hold the expand/collapse shape ops created for dimension +/// expansion or blocking transformations. +struct ReshapeOps { + tensor::ExpandShapeOp expandShapeOp; + tensor::CollapseShapeOp collapseShapeOp; +}; + +/// Populate patterns to remove optimization barriers. 
+void populateRemoveOptimizationBarrierPatterns(RewritePatternSet &patterns); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp index a04af998e191..b8c27fd9cf06 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp @@ -151,4 +151,11 @@ std::optional> getPaddingList(LoweringConfigAttr config, return getIntegerVector(array); } +constexpr StringLiteral kDimensionExpansionName = "expand_dims"; + +DimensionExpansionAttr getDimensionExpansion(LoweringConfigAttr config) { + return config.getAttributes().getAs( + kDimensionExpansionName); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h index ad175556116d..e5b8b4730715 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h @@ -62,6 +62,9 @@ IREE::GPU::LoweringConfigAttr setPromotedOperandsList( std::optional> getPaddingList(LoweringConfigAttr config, bool paddingConv = false); +/// Helper to retrieve dimension expansion config from lowering config. 
+DimensionExpansionAttr getDimensionExpansion(LoweringConfigAttr config); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPULOWERINGCONFIGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 9965b093eff2..6d548c86c46a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" #include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h" +#include "iree/compiler/Utils/EncodingUtils.h" #include "iree/compiler/Utils/Indexing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" @@ -2309,6 +2310,89 @@ GPUPipelineOptionsAttr GPUPipelineOptionsAttr::get( } //===----------------------------------------------------------------------===// +// DimensionExpansionAttr +//===----------------------------------------------------------------------===// + +DimensionExpansionAttr +DimensionExpansionAttr::get(MLIRContext *context, + ArrayRef reassociations, + ArrayRef outputShape) { + Builder b(context); + SmallVector reassociationAttrs; + for (const ReassociationIndices &indices : reassociations) { + SmallVector indexAttrs; + for (int64_t idx : indices) { + indexAttrs.push_back(b.getI64IntegerAttr(idx)); + } + reassociationAttrs.push_back(b.getArrayAttr(indexAttrs)); + } + ArrayAttr reassociationAttr = b.getArrayAttr(reassociationAttrs); + DenseI64ArrayAttr outputShapeAttr = b.getDenseI64ArrayAttr(outputShape); + return get(context, reassociationAttr, outputShapeAttr); +} + +LogicalResult +DimensionExpansionAttr::verify(function_ref emitError, + ArrayAttr reassociations, + DenseI64ArrayAttr outputShape) { + if (reassociations.empty()) { + return emitError() << 
"reassociations cannot be empty"; + } + + int64_t nextExpected = 0; + + for (auto [groupIdx, attr] : llvm::enumerate(reassociations)) { + auto indexArray = dyn_cast(attr); + if (!indexArray) { + return emitError() << "reassociation at index " << groupIdx + << " must be an array"; + } + + if (indexArray.empty()) { + return emitError() << "reassociation group " << groupIdx + << " cannot be empty"; + } + + int numDynamicDims = 0; + for (auto [innerIdx, idxAttr] : llvm::enumerate(indexArray)) { + auto intAttr = dyn_cast(idxAttr); + if (!intAttr) { + return emitError() << "reassociation index at [" << groupIdx << "][" + << innerIdx << "] must be an integer"; + } + + int64_t idx = intAttr.getInt(); + if (idx != nextExpected) { + return emitError() << "reassociation indices must form contiguous " + << "sequence; expected dimension " << nextExpected + << " at [" << groupIdx << "][" << innerIdx + << "], got " << idx; + } + + if (outputShape[idx] == ShapedType::kDynamic) { + numDynamicDims++; + } + + nextExpected++; + } + + if (numDynamicDims > 1) { + return emitError() + << "reassociation group " << groupIdx + << " has multiple dynamic dimensions; at most 1 allowed"; + } + } + + ArrayRef outputShapeArray = outputShape.asArrayRef(); + if (nextExpected != static_cast(outputShapeArray.size())) { + return emitError() << "reassociations cover " << nextExpected + << " dimensions, but output_shape has rank " + << outputShapeArray.size(); + } + + return success(); +} + // Index Hint Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index f14a5ef98c1c..9c93c3d43c4d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -960,6 +960,62 @@ def IREEGPU_GPUPipelineOptionsAttr : AttrDef { + let mnemonic = 
"expand_dims"; + let cppNamespace = "::mlir::iree_compiler::IREE::GPU"; + + let summary = [{Attribute for describing static dimension expansion.}]; + let description = [{ + This attribute describes how dimensions in an iteration space should be + expanded. Each original dimension can either remain unchanged or be + split into multiple dimensions. The semantics are similar to the familiar + `tensor.expand_shape` operation. + + The reassociations parameter specifies the mapping from original dimensions + to expanded dimensions. For example, [[0], [1], [2, 3]] means: + - Original dimension 0 maps to output dimension 0 + - Original dimension 1 maps to output dimension 1 + - Original dimension 2 is split into output dimensions 2 and 3 + + The output_shape parameter specifies the sizes of the expanded dimensions. + If the size is ShapedType::kDynamic, the size is determined from the product + of the rest of the static tile sizes in the respective reassociation group. + There can be at most one dynamic size per reassociation group. 
+ + Example: #iree_gpu.expand_dims<[[0], [1], [2, 3]], output_shape = [?, ?, ?, 8]> + }]; + + let parameters = (ins + "ArrayAttr":$reassociations, + "DenseI64ArrayAttr":$output_shape + ); + + let builders = [ + AttrBuilder<(ins + "ArrayRef":$reassociations, + "ArrayRef":$outputShape)> + ]; + + let assemblyFormat = "`<` $reassociations `,` `output_shape` `=` custom($output_shape) `>`"; + + let extraClassDeclaration = [{ + SmallVector getReassociationIndices() { + return llvm::to_vector<4>(llvm::map_range( + getReassociations().getAsRange(), + [](ArrayAttr arrayAttr) -> ReassociationIndices { + return llvm::to_vector<2>(llvm::map_range( + arrayAttr.getAsRange(), + [](IntegerAttr idx) { return idx.getInt(); })); + })); + } + }]; + + let genVerifyDecl = 1; +} + // Lane Index Hint Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp index 6e432f24518d..78828768eedf 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" @@ -294,10 +296,44 @@ getVectorDistributeReductionConfig( int subgroup = partialReductionSize / subgroupStride; int64_t subgroupBasis = (subgroup == 0) ? 
1 : subgroup; - partialReductionTileSizes[lastReductionDim] = partialReductionSize; - threadTileSizes[lastReductionDim] = threadLoads; - threadCounts[lastReductionDim] = threadBasis; - subGroupCounts[lastReductionDim] = subgroupBasis; + SmallVector reassociations; + SmallVector outputShape; + + // We require the reduction dimension to be evenly divisible by threadLoads + // because the current expansion strategy doesn't support padding. + if (ShapedType::isStaticShape(bounds) && threadLoads > 1 && + bounds[lastReductionDim] % threadLoads == 0) { + workgroupTileSizes.push_back(0); + partialReductionTileSizes.push_back(0); + threadTileSizes.push_back(0); + threadCounts.push_back(1); + subGroupCounts.push_back(1); + mapping.push_back(mapping.size()); + + int64_t outer = lastReductionDim; + int64_t inner = lastReductionDim + 1; + + for (int64_t i = 0; i < op.getNumLoops(); ++i) { + if (i == lastReductionDim) { + int64_t idx = outputShape.size(); + reassociations.push_back({idx, idx + 1}); + outputShape.append({ShapedType::kDynamic, threadLoads}); + } else { + reassociations.push_back({static_cast(outputShape.size())}); + outputShape.push_back(ShapedType::kDynamic); + } + } + + partialReductionTileSizes[outer] = partialReductionSize / threadLoads; + threadTileSizes[inner] = threadLoads; + threadCounts[outer] = threadBasis; + subGroupCounts[outer] = subgroupBasis; + } else { + partialReductionTileSizes[lastReductionDim] = partialReductionSize; + threadTileSizes[lastReductionDim] = threadLoads; + threadCounts[lastReductionDim] = threadBasis; + subGroupCounts[lastReductionDim] = subgroupBasis; + } ArrayAttr subgroupBasisAttr = b.getArrayAttr( {b.getI64ArrayAttr(subGroupCounts), b.getI64ArrayAttr(mapping)}); @@ -305,13 +341,20 @@ getVectorDistributeReductionConfig( ArrayAttr threadBasisAttr = b.getArrayAttr( {b.getI64ArrayAttr(threadCounts), b.getI64ArrayAttr(mapping)}); - NamedAttribute configAttrs[] = { - NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)), 
- NamedAttribute("partial_reduction", + SmallVector configAttrs = { + b.getNamedAttr("workgroup", b.getI64ArrayAttr(workgroupTileSizes)), + b.getNamedAttr("partial_reduction", b.getI64ArrayAttr(partialReductionTileSizes)), - NamedAttribute("thread", b.getI64ArrayAttr(threadTileSizes)), - NamedAttribute("lane_basis", threadBasisAttr), - NamedAttribute("subgroup_basis", subgroupBasisAttr)}; + b.getNamedAttr("thread", b.getI64ArrayAttr(threadTileSizes)), + b.getNamedAttr("lane_basis", threadBasisAttr), + b.getNamedAttr("subgroup_basis", subgroupBasisAttr), + }; + + if (!reassociations.empty()) { + auto dimExpandAttr = + DimensionExpansionAttr::get(context, reassociations, outputShape); + configAttrs.emplace_back(b.getNamedAttr("expand_dims", dimExpandAttr)); + } auto configDict = b.getDictionaryAttr(configAttrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 5a80b5335d49..2f97bcabe869 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -736,6 +736,8 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); funcPassManager.addPass(createGPUPromoteMatmulOperandsPass()); + funcPassManager.addPass(createGPUExpandDimensionsPass()); + // Tile to reduction loops. 
{ GPUApplyTilingLevelPassOptions options; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir index 80be4aa67df1..d0f5410dc619 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir @@ -84,11 +84,12 @@ func.func @reduction_with_no_consumer() { // CHECK-LABEL: func.func @reduction_with_no_consumer // CHECK: lowering_config = #iree_gpu.lowering_config -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3] -// CHECK-SAME: partial_reduction = [0, 0, 1, 4096] -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 8], [0, 1, 2, 3] -// CHECK-SAME: thread = [0, 0, 1, 8], -// CHECK-SAME: workgroup = [1, 1, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]> +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 4]{{\]}} +// CHECK-SAME: partial_reduction = [0, 0, 1, 512, 0] +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 8, 1], [0, 1, 2, 3, 4]{{\]}} +// CHECK-SAME: thread = [0, 0, 1, 1, 8] +// CHECK-SAME: workgroup = [1, 1, 0, 0, 0] // ----- @@ -150,20 +151,22 @@ func.func @test_multiple_reduction() { // CHECK-SAME: ins(%{{.*}} : tensor<2x32x10x16384xf32>) // CHECK-SAME: outs({{.*}}: tensor<2x32xf32>) // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3]], -// CHECK-SAME: partial_reduction = [0, 0, 1, 8192], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16], [0, 1, 2, 3]], -// CHECK-SAME: thread = [0, 0, 1, 8], -// CHECK-SAME: workgroup = [1, 1, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]>, +// CHECK-SAME: 
lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 1, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 1, 8], +// CHECK-SAME: workgroup = [1, 1, 0, 0, 0] // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map1, #map1], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]} // CHECK-SAME: ins{{.*}}, {{.*}} : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) // CHECK-SAME: outs(%{{.*}} : tensor<2x32xf32>) // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64], [0, 1, 2, 3]], -// CHECK-SAME: partial_reduction = [0, 0, 1, 8192], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16], [0, 1, 2, 3]], -// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2], [3, 4]{{\]}}, output_shape = [?, ?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 1, 64, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 1, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16, 1], [0, 1, 2, 3, 4]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 1, 8], // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} // CHECK-SAME: ins({{.*}}, %{{.*}}, {{.*}} : tensor<2x32x10x16384xf16>, tensor<2x32xf32>, tensor<2x32xf32>) @@ -250,11 +253,12 @@ func.func @test_multiple_stores(%arg0: !iree_tensor_ext.dispatch.tensor, +// CHECK-SAME: lane_basis = {{\[}}[1, 64, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 1024, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 16, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: thread = [0, 1, 4], +// CHECK-SAME: workgroup = [1, 0, 0] // ----- @@ -291,9 +295,9 @@ func.func @test_gather_config(%arg0: !iree_tensor_ext.dispatch.tensor, +// CHECK-SAME: lane_basis = {{\[}}[1, 
1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [4, 1, 0, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir index 08b764e77fca..56767f219292 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir @@ -119,11 +119,12 @@ func.func @vmt1() attributes {hal.executable.target = #executable_target_rocm_hs // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [1, 8, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [1, 8, 0, 0] // ----- @@ -162,11 +163,12 @@ func.func @matvec_like_no_m_dim() attributes {hal.executable.target = #executabl // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 64], [0, 1]], -// CHECK-SAME: partial_reduction = [0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1], [0, 1]], -// CHECK-SAME: thread = [0, 8], -// CHECK-SAME: workgroup = [8, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1, 
2]{{\]}}, output_shape = [?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 64, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]{{\]}}, +// CHECK-SAME: thread = [0, 1, 8], +// CHECK-SAME: workgroup = [8, 0, 0] // ----- @@ -204,11 +206,12 @@ func.func @matvec_unit_n_dim() attributes {hal.executable.target = #executable_t // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [8, 1, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [8, 1, 0, 0] // ----- @@ -248,11 +251,12 @@ func.func @vmt2() attributes {hal.executable.target = #executable_target_rocm_hs // CDNA3-SAME: translation_info = #[[$TRANSLATION]] // CDNA3: linalg.generic // CDNA3-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CDNA3-SAME: lane_basis = {{\[}}[1, 1, 32], [0, 1, 2]], -// CDNA3-SAME: partial_reduction = [0, 0, 256], -// CDNA3-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CDNA3-SAME: thread = [0, 0, 8], -// CDNA3-SAME: workgroup = [1, 4, 0] +// CDNA3-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CDNA3-SAME: lane_basis = {{\[}}[1, 1, 32, 1], [0, 1, 2, 3]{{\]}}, +// CDNA3-SAME: partial_reduction = [0, 0, 32, 0], +// CDNA3-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// 
CDNA3-SAME: thread = [0, 0, 1, 8], +// CDNA3-SAME: workgroup = [1, 4, 0, 0] // ----- @@ -308,11 +312,12 @@ func.func @i4_dequant_matvec() { // CHECK: linalg.generic // CHECK: linalg.generic // CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 1, 128], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 1, 2], -// CHECK-SAME: workgroup = [8, 0, 0] +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 2]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 1, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 1, 1, 2], +// CHECK-SAME: workgroup = [8, 0, 0, 0] // ----- @@ -353,11 +358,12 @@ func.func @skinny_mmt_lhs_is_vector() { // CHECK: linalg.matmul // CHECK-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]] // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [2, 1, 0]}>} +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [2, 1, 0, 0]}>} // ----- @@ -395,11 +401,12 @@ func.func @skinny_mmt_lhs_is_matrix() { // CHECK: linalg.matmul // CHECK-SAME: indexing_maps // CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{ -// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64], [0, 1, 2]], -// 
CHECK-SAME: partial_reduction = [0, 0, 512], -// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1], [0, 1, 2]], -// CHECK-SAME: thread = [0, 0, 8], -// CHECK-SAME: workgroup = [8, 1, 0]}>} +// CHECK-SAME: expand_dims = #iree_gpu.expand_dims<{{\[}}[0], [1], [2, 3]{{\]}}, output_shape = [?, ?, ?, 8]>, +// CHECK-SAME: lane_basis = {{\[}}[1, 1, 64, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: partial_reduction = [0, 0, 64, 0], +// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1], [0, 1, 2, 3]{{\]}}, +// CHECK-SAME: thread = [0, 0, 1, 8], +// CHECK-SAME: workgroup = [8, 1, 0, 0]}>} // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir index b728b6d9e7b6..86c4b0e4b232 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir @@ -37,21 +37,21 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info : vector<1x1x4xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1x1x1x4xf32> +// CHECK-DAG: %[[CST_ACC:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf32> // CHECK-DAG: gpu.thread_id x -// CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c10240 step %c1024 iter_args(%[[A0:.+]] = %[[CST]]) -> (vector<1x1x4xf32>) { -// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<512x10240xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<4xf32> into vector<1x1x4xf32> -// CHECK: %[[ADD:.+]] = arith.addf %[[STRIDED]], %[[A0]] : vector<1x1x4xf32> -// CHECK: scf.yield %[[ADD]] : vector<1x1x4xf32> +// CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c2560 step %c256 iter_args(%[[A0:.+]] = %[[CST_ACC]]) -> (vector<1x1x1xf32>) { +// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : 
memref<512x10240xf32, {{.*}}>, vector<1x4xf32> +// CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<1x4xf32> into vector<1x1x1x1x1x4xf32> +// CHECK: %[[REDUCE:.+]] = vector.multi_reduction , %[[STRIDED]], %[[CST_ACC]] [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: %[[ADD:.+]] = arith.addf %[[REDUCE]], %[[A0]] : vector<1x1x1xf32> +// CHECK: scf.yield %[[ADD]] : vector<1x1x1xf32> // CHECK: } // CHECK: gpu.subgroup_reduce add {{.*}} cluster(size = 32) : (f32) -> f32 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<10xf32, #gpu.address_space> -// CHECK: vector.transfer_write %{{.*}}, %[[ALLOC]]{{.*}} : vector<1xf32> // CHECK: gpu.barrier // CHECK: vector.transfer_read %[[ALLOC]]{{.*}} // CHECK: gpu.subgroup_reduce add {{.*}} cluster(size = 8) : (f32) -> f32 -// CHECK: vector.transfer_write {{.*}} : vector, memref<512xf32, #hal.descriptor_type> // ----- @@ -103,15 +103,14 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, -// CHECK: arith.addf {{.*}} : vector<1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.addf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield // CHECK: gpu.subgroup_reduce -// CHECK: vector.transfer_write {{.*}} : vector<1xf32 // CHECK: gpu.subgroup_reduce // CHECK: arith.divf {{.*}} : vector -// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, {{.*}} // CHECK: return // ----- @@ -144,30 +143,25 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, -// 
CHECK: arith.maxnumf {{.*}} : vector<1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.transfer_write // CHECK: gpu.barrier // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.broadcast %{{.*}} : f32 to vector<1x1x4xf32> -// CHECK: scf.for {{.*}} -> (vector<1x1x4xf32>) { +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp +// CHECK: vector.multi_reduction // CHECK: arith.addf // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce add -// CHECK: vector.transfer_write // CHECK: gpu.barrier -// CHECK: vector.transfer_read // CHECK: gpu.subgroup_reduce add -// CHECK: vector.broadcast -// CHECK: scf.for +// CHECK: scf.forall // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp @@ -206,23 +200,22 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) { // CHECK: #[[TRANSLATION_INFO:.+]] = #iree_codegen.translation_info (vector<1x1x4xf32>) { -// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, -// CHECK: arith.maxnumf {{.*}} : vector<1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { +// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, {{.*}}>, vector<1x4xf32> +// CHECK: vector.multi_reduction , {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32> +// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32> // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce maxnumf -// CHECK: vector.broadcast %{{.*}} : f32 to vector<1x1x4xf32> -// CHECK: scf.for {{.*}} -> (vector<1x1x4xf32>) { +// CHECK: vector.broadcast %{{.*}} : f32 
to vector<1x1x1x1x1x4xf32> +// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) { // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp +// CHECK: vector.multi_reduction // CHECK: arith.addf // CHECK: scf.yield -// CHECK: vector.multi_reduction // CHECK: gpu.subgroup_reduce add -// CHECK: vector.broadcast -// CHECK: scf.for +// CHECK: scf.forall // CHECK: vector.transfer_read // CHECK: arith.subf // CHECK: math.exp @@ -523,9 +516,13 @@ hal.executable private @i4_dequant_matvec { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x4xf16> -// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x4xf16>) -// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1x1x4xf16> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1x1x4xf16> - -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2] : vector<1x1x4xf16> to f16 +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf16> +// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x1xf16>) +// CHECK: vector.transfer_read {{.*}} : memref<4096x32x128xi4, {{.*}}>, vector<1x4xi4> +// CHECK: arith.extui %{{.*}} : vector<1x1x1x1x1x4xi4> to vector<1x1x1x1x1x4xi32> +// CHECK: arith.uitofp %{{.*}} : vector<1x1x1x1x1x4xi32> to vector<1x1x1x1x1x4xf16> +// CHECK: arith.subf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16> +// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16> +// CHECK: vector.contract {{.*}} : vector<1x1x1x1x1x4xf16>, vector<1x1x1x1x1x4xf16> into vector<1x1x1xf16> + +// CHECK: vector.extract {{.*}} : f16 from vector<1x1x1xf16> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir index 
b9ec424ce9c6..d48dc93ec947 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir @@ -138,16 +138,16 @@ hal.executable private @i4_dequant_matvec { // RDNA3-DAG: %[[C0:.+]] = arith.constant 0 : index // RDNA3-DAG: %[[C32:.+]] = arith.constant 32 : index // RDNA3-DAG: %[[C1:.+]] = arith.constant 1 : index -// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x4xf16> -// RDNA3: %[[FOR:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<4x1x1x1x1x4xf16>) -// RDNA3: %{{.*}} = arith.extui %{{.*}} : vector<4x1x1x1x1x4xi4> to vector<4x1x1x1x1x4xi32> -// RDNA3: %{{.*}} = arith.uitofp %{{.*}} : vector<4x1x1x1x1x4xi32> to vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.subf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> -// RDNA3: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x4xf16> +// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x1xf16> +// RDNA3: %[[FOR:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<4x1x1x1x1x1xf16>) +// RDNA3: memref.expand_shape {{.*}} : memref<4x1x128xi4, {{.*}}> into memref<4x1x32x4xi4, {{.*}}> +// RDNA3: %{{.*}} = arith.extui %{{.*}} : vector<4x1x1x1x1x1x1x1x4xi4> to vector<4x1x1x1x1x1x1x1x4xi32> +// RDNA3: %{{.*}} = arith.uitofp %{{.*}} : vector<4x1x1x1x1x1x1x1x4xi32> to vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: %{{.*}} = arith.subf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x1x1x1x4xf16> +// RDNA3: vector.contract {{.*}} : vector<1x1x1x1x1x4xf16>, vector<4x1x1x1x1x1x1x1x4xf16> into vector<4x1x1x1x1x1xf16> -// RDNA3: %{{.*}} = vector.multi_reduction 
, %{{.*}}, %{{.*}} [1, 3, 5] : vector<4x1x1x1x1x4xf16> to vector<4x1x1xf16> +// RDNA3: vector.shape_cast %{{.*}} : vector<4x1x1x1x1x1xf16> to vector<4x1x1xf16> // ----- @@ -252,13 +252,13 @@ hal.executable private @matvec_fp16 { // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index -// CHECK-DAG: %[[C4096:.+]] = arith.constant 4096 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x1x1x1x1x8xf16> -// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4096]] step %[[C512]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<8x1x1x1x1x8xf16>) -// CHECK: {{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<8x1x1x1x1x8xf16> -// CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<8x1x1x1x1x8xf16> +// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x1x1x1x1x1xf16> +// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C512]] step %[[C64]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<8x1x1x1x1x1xf16>) +// CHECK: memref.expand_shape {{.*}} : memref<8x512xf16, {{.*}}> into memref<8x64x8xf16, {{.*}}> +// CHECK: vector.contract {{.*}} : vector<1x1x1x1x1x8xf16>, vector<8x1x1x1x1x1x1x1x8xf16> into vector<8x1x1x1x1x1xf16> -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<8x1x1x1x1x8xf16> to vector<8x1x1xf16> +// CHECK: vector.shape_cast %{{.*}} : vector<8x1x1x1x1x1xf16> to vector<8x1x1xf16> // ----- @@ -304,14 +304,14 @@ hal.executable private @matvec_fp16 { // RDNA3: func.func @matvec_fp16() // RDNA3-SAME: translation_info = #[[$TRANSLATION]] // RDNA3-DAG: %[[C0:.+]] = arith.constant 0 : index -// RDNA3-DAG: %[[C256:.+]] = arith.constant 256 : index -// RDNA3-DAG: %[[C4096:.+]] = arith.constant 4096 : index -// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x8xf16> -// RDNA3: scf.for %{{.+}} = %[[C0]] to %[[C4096]] step %[[C256]] 
iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<4x1x1x1x1x8xf16>) -// RDNA3: {{.*}} = arith.mulf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x8xf16> -// RDNA3: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4x1x1x1x1x8xf16> +// RDNA3-DAG: %[[C512:.+]] = arith.constant 512 : index +// RDNA3-DAG: %[[C32:.+]] = arith.constant 32 : index +// RDNA3-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4x1x1x1x1x1xf16> +// RDNA3: scf.for %{{.+}} = %[[C0]] to %[[C512]] step %[[C32]] iter_args(%[[ARG:.+]] = %[[CST]]) -> (vector<4x1x1x1x1x1xf16>) +// RDNA3: memref.expand_shape {{.*}} : memref<4x256xf16, {{.*}}> into memref<4x32x8xf16, {{.*}}> +// RDNA3: vector.contract {{.*}} : vector<1x1x1x1x1x8xf16>, vector<4x1x1x1x1x1x1x1x8xf16> into vector<4x1x1x1x1x1xf16> -// RDNA3: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<4x1x1x1x1x8xf16> to vector<4x1x1xf16> +// RDNA3: vector.shape_cast %{{.*}} : vector<4x1x1x1x1x1xf16> to vector<4x1x1xf16> // ----- diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel index 9b506a2a873a..49e728b41187 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel @@ -70,6 +70,7 @@ iree_compiler_cc_library( ":EncodingTypesGen", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", "//compiler/src/iree/compiler/Dialect/TensorExt/IR", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:DialectUtils", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt index 708af4889087..968fbb11d817 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt @@ -54,6 +54,7 @@ iree_cc_library( MLIRTensorUtils iree::compiler::Dialect::LinalgExt::Utils 
iree::compiler::Dialect::TensorExt::IR + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 3e84bdf07581..895fbf40bcb3 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" +#include "iree/compiler/Utils/EncodingUtils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -257,60 +258,6 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, b.getTypeArrayAttr(elemTypes), mapsAttr, iterationSizesAttr); } -/// Parse a list of integer values and/or dynamic values ('?') -static FailureOr> -parseDynamicI64IntegerList(AsmParser &parser) { - SmallVector integerVals; - if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { - int64_t value = ShapedType::kDynamic; - if (failed(parser.parseOptionalQuestion()) && - failed(parser.parseInteger(value))) { - return failure(); - } - integerVals.push_back(value); - return success(); - }))) { - return failure(); - } - return integerVals; -} - -/// Utility to parse an array of integer and/or dynamic values (`?`). 
-static ParseResult parseDynamicI64ArrayAttr(AsmParser &p, ArrayAttr &attr) { - FailureOr> integerVals = parseDynamicI64IntegerList(p); - if (failed(integerVals)) { - return failure(); - } - auto integerValsAttr = - llvm::map_to_vector(integerVals.value(), [&](int64_t val) -> Attribute { - return IntegerAttr::get(IntegerType::get(p.getContext(), 64), val); - }); - attr = ArrayAttr::get(p.getContext(), integerValsAttr); - return success(); -} - -/// Print a list of integer values and/or dynamic values ('?') -static void printDynamicI64IntegerList(AsmPrinter &printer, - ArrayRef vals) { - printer << "["; - llvm::interleaveComma(vals, printer, [&](int64_t val) { - if (ShapedType::isDynamic(val)) { - printer << "?"; - } else { - printer << val; - } - }); - printer << "]"; -} - -/// Utility to print an array of integer and/or dynamic values. Dynamic values -/// are printed as `?`. -static void printDynamicI64ArrayAttr(AsmPrinter &p, ArrayAttr attrs) { - SmallVector intVals = llvm::map_to_vector( - attrs, [&](Attribute attr) { return cast(attr).getInt(); }); - return printDynamicI64IntegerList(p, intVals); -} - LogicalResult EncodingAttr::verify(function_ref emitError, IntegerAttr operandIndexAttr, diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel index 36f0d197021a..9b5109c2a4f7 100644 --- a/compiler/src/iree/compiler/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/BUILD.bazel @@ -34,6 +34,7 @@ iree_compiler_cc_library( name = "Utils", srcs = [ "ConversionUtils.cpp", + "EncodingUtils.cpp", "EquivalenceUtils.cpp", "FlatbufferUtils.cpp", "Indexing.cpp", @@ -49,6 +50,7 @@ iree_compiler_cc_library( hdrs = [ "ConversionUtils.h", "EmbeddedDataDirectory.h", + "EncodingUtils.h", "EquivalenceUtils.h", "FlatbufferUtils.h", "Folding.h", diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt index 6dd027fc39a4..b7f0fee64fdf 100644 --- 
a/compiler/src/iree/compiler/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt @@ -16,6 +16,7 @@ iree_cc_library( HDRS "ConversionUtils.h" "EmbeddedDataDirectory.h" + "EncodingUtils.h" "EquivalenceUtils.h" "FlatbufferUtils.h" "Folding.h" @@ -35,6 +36,7 @@ iree_cc_library( "TracingUtils.h" SRCS "ConversionUtils.cpp" + "EncodingUtils.cpp" "EquivalenceUtils.cpp" "FlatbufferUtils.cpp" "Indexing.cpp" diff --git a/compiler/src/iree/compiler/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Utils/EncodingUtils.cpp new file mode 100644 index 000000000000..d6a5c51d6291 --- /dev/null +++ b/compiler/src/iree/compiler/Utils/EncodingUtils.cpp @@ -0,0 +1,86 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Utils/EncodingUtils.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir::iree_compiler { + +/// Parse a list of integer values and/or dynamic values ('?') +FailureOr> parseDynamicI64IntegerList(AsmParser &parser) { + SmallVector integerVals; + if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { + int64_t value = ShapedType::kDynamic; + if (failed(parser.parseOptionalQuestion()) && + failed(parser.parseInteger(value))) { + return failure(); + } + integerVals.push_back(value); + return success(); + }))) { + return failure(); + } + return integerVals; +} + +/// Print a list of integer values and/or dynamic values ('?') +void printDynamicI64IntegerList(AsmPrinter &printer, ArrayRef vals) { + printer << "["; + llvm::interleaveComma(vals, printer, [&](int64_t val) { + if (ShapedType::isDynamic(val)) { + printer << "?"; + } else { + printer << val; + } + }); + printer << "]"; +} + +/// Parse a list of integer values and/or dynamic values ('?') into an ArrayAttr +ParseResult 
parseDynamicI64ArrayAttr(AsmParser &parser, ArrayAttr &attr) { + FailureOr> integerVals = + parseDynamicI64IntegerList(parser); + if (failed(integerVals)) { + return failure(); + } + auto integerValsAttr = + llvm::map_to_vector(integerVals.value(), [&](int64_t val) -> Attribute { + return IntegerAttr::get(IntegerType::get(parser.getContext(), 64), val); + }); + attr = ArrayAttr::get(parser.getContext(), integerValsAttr); + return success(); +} + +/// Print an ArrayAttr of integer values and/or dynamic values ('?') +void printDynamicI64ArrayAttr(AsmPrinter &printer, ArrayAttr attrs) { + SmallVector intVals = llvm::map_to_vector( + attrs, [&](Attribute attr) { return cast(attr).getInt(); }); + return printDynamicI64IntegerList(printer, intVals); +} + +/// Parse a list of integer values and/or dynamic values ('?') into a +/// DenseI64ArrayAttr +ParseResult parseDynamicI64DenseArrayAttr(AsmParser &parser, + DenseI64ArrayAttr &attr) { + FailureOr> integerVals = + parseDynamicI64IntegerList(parser); + if (failed(integerVals)) { + return failure(); + } + attr = DenseI64ArrayAttr::get(parser.getContext(), *integerVals); + return success(); +} + +/// Print a DenseI64ArrayAttr as a list of integer values and/or dynamic values +/// ('?') +void printDynamicI64DenseArrayAttr(AsmPrinter &printer, + DenseI64ArrayAttr attr) { + printDynamicI64IntegerList(printer, attr.asArrayRef()); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Utils/EncodingUtils.h b/compiler/src/iree/compiler/Utils/EncodingUtils.h new file mode 100644 index 000000000000..a6e52f07ab8c --- /dev/null +++ b/compiler/src/iree/compiler/Utils/EncodingUtils.h @@ -0,0 +1,39 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_UTILS_ENCODINGUTILS_H_ +#define IREE_COMPILER_UTILS_ENCODINGUTILS_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir::iree_compiler { + +/// Parse a list of integer values and/or dynamic values ('?') +FailureOr> parseDynamicI64IntegerList(AsmParser &parser); + +/// Print a list of integer values and/or dynamic values ('?') +void printDynamicI64IntegerList(AsmPrinter &printer, ArrayRef vals); + +/// Parse a list of integer values and/or dynamic values ('?') into an ArrayAttr +ParseResult parseDynamicI64ArrayAttr(AsmParser &parser, ArrayAttr &attr); + +/// Print an ArrayAttr of integer values and/or dynamic values ('?') +void printDynamicI64ArrayAttr(AsmPrinter &printer, ArrayAttr attrs); + +/// Parse a list of integer values and/or dynamic values ('?') into a +/// DenseI64ArrayAttr +ParseResult parseDynamicI64DenseArrayAttr(AsmParser &parser, + DenseI64ArrayAttr &attr); + +/// Print a DenseI64ArrayAttr as a list of integer values and/or dynamic values +/// ('?') +void printDynamicI64DenseArrayAttr(AsmPrinter &printer, DenseI64ArrayAttr attr); + +} // namespace mlir::iree_compiler + +#endif // IREE_COMPILER_UTILS_ENCODINGUTILS_H_ From 7c47359b44b0b7f4e38ef04fb357f847bcd3c769 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 14 Jan 2026 13:08:28 +0530 Subject: [PATCH 32/71] [LinalgExt] Implement generateScalarImplementation for map_gather (#23070) -- This commit implements `generateScalarImplementation` for `map_gather` op. 
Signed-off-by: Abhishek Varma --- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 66 ++++++++++++------- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 12 +++- .../LinalgExt/IR/TilingInterfaceImpl.cpp | 65 ++++++++++++++++++ .../Transforms/test/convert_to_loops.mlir | 46 +++++++++++++ 4 files changed, 163 insertions(+), 26 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 22728e5167e2..6d2badfdb75c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -656,6 +656,45 @@ void MapGatherOp::insertTransformationAtStart( transformBody.eraseArguments(0, oldOutputIndices.size()); } +/// Shared implementation for inlining the transformation body of map_gather +/// and map_scatter ops. +static void inlineMapGatherScatterBodyImpl( + OpBuilder &b, Location loc, Region &transformRegion, + ValueRange transformBodyIndices, + function_ref)> bodyBuilder) { + Block &transformBlock = transformRegion.front(); + IRMapping mapping; + // Map the induction variables of the loop nest to the block arguments of the + // transformation body. + for (auto [idx, arg] : llvm::enumerate(transformBlock.getArguments())) { + mapping.map(arg, transformBodyIndices[idx]); + } + // Clone the operations within the transformation body to the current + // insertion point, and map their results to the new cloned operations' + // results. + for (Operation &op : transformBlock.without_terminator()) { + Operation *clonedOp = b.clone(op, mapping); + for (auto [result, clonedResult] : + llvm::zip_equal(op.getResults(), clonedOp->getResults())) { + mapping.map(result, clonedResult); + } + } + + // Get the cloned values that were yielded by the transformation body to pass + // to the bodyBuilder. 
+ SmallVector mappedYieldedValues = llvm::map_to_vector( + transformBlock.getTerminator()->getOperands(), + [&](Value operand) -> Value { return mapping.lookupOrDefault(operand); }); + bodyBuilder(b, loc, mappedYieldedValues); +} + +void MapGatherOp::inlineMapGatherBody( + OpBuilder &b, Location loc, ValueRange transformBodyIndices, + function_ref)> bodyBuilder) { + inlineMapGatherScatterBodyImpl(b, loc, getTransformationRegion(), + transformBodyIndices, bodyBuilder); +} + //===----------------------------------------------------------------------===// // MapScatterOp //===----------------------------------------------------------------------===// @@ -787,31 +826,8 @@ void MapScatterOp::insertTransformationAtStart( void MapScatterOp::inlineMapScatterBody( OpBuilder &b, Location loc, ValueRange transformBodyIndices, function_ref)> bodyBuilder) { - Block &transformBlock = getTransformationRegion().front(); - IRMapping mapping; - // Map the induction variables of the loop nest to the block arguments of the - // transformation body. The induction variables are the indices looping over - // the elements of input operand. - for (auto [idx, arg] : llvm::enumerate(transformBlock.getArguments())) { - mapping.map(arg, transformBodyIndices[idx]); - } - // Clone the operations within the transformation body to the current - // insertion point, and map their results to the new cloned operations' - // results. - for (Operation &op : transformBlock.without_terminator()) { - Operation *clonedOp = b.clone(op, mapping); - for (auto [result, clonedResult] : - llvm::zip_equal(op.getResults(), clonedOp->getResults())) { - mapping.map(result, clonedResult); - } - } - - // Get the cloned values that were yielded by the transformation body to pass - // to the bodyBuilder. 
- SmallVector mappedYieldedValues = llvm::map_to_vector( - transformBlock.getTerminator()->getOperands(), - [&](Value operand) -> Value { return mapping.lookupOrDefault(operand); }); - bodyBuilder(b, loc, mappedYieldedValues); + inlineMapGatherScatterBodyImpl(b, loc, getTransformationRegion(), + transformBodyIndices, bodyBuilder); } bool MapScatterOp::isIdentity() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 01510e5e7703..f09805746254 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -280,7 +280,8 @@ def IREELinalgExt_MapGatherOp : IREELinalgExt_Op<"map_gather", "getTiledImplementation", "generateResultTileValue", "getIterationDomainTileFromOperandTiles", - "getTiledImplementationFromOperandTiles"]>]> { + "getTiledImplementationFromOperandTiles", + "generateScalarImplementation"]>]> { let summary = [{Gather with a mapping from result indices to source indices.}]; let description = [{ Takes two operands, `source` and `output`, and reads every element from @@ -337,6 +338,15 @@ def IREELinalgExt_MapGatherOp : IREELinalgExt_Op<"map_gather", transformationBuilder, int64_t numOutputIndices); + // Inline the transformation region of the map_gather op without its + // terminator, replacing the block arguments with the passed + // `transformBodyIndices`. The `bodyBuilder` function is called with the + // cloned `Value`s that would have been yielded by the terminator of + // the inlined transformation body (source indices and padding value). 
+ void inlineMapGatherBody( + OpBuilder &b, Location loc, ValueRange transformBodyIndices, + function_ref)> bodyBuilder); + // Method to implement for specifying output range for // DestinationStyleOpInterface MutableOperandRange getDpsInitsMutable() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp index f3b71954fb0c..306eff61f43d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp @@ -565,6 +565,71 @@ FailureOr MapGatherOp::getTiledImplementationFromOperandTiles( return getTiledImplementation(b, mappedOffsets, mappedSizes); } +/// The body of the transformation_region is inlined, and the yielded indices +/// are used to read values from the source and write to the output. Bounds +/// checking is performed on the source indices, and the padding value is used +/// if the indices are out of bounds. +LogicalResult MapGatherOp::generateScalarImplementation(OpBuilder &b, + Location loc, + ValueRange ivs) { + // The scalar implementation is currently only implemented for buffer + // semantics. + if (!hasPureBufferSemantics()) { + return failure(); + } + + auto bodyBuilder = [&](OpBuilder nestedBuilder, Location nestedLoc, + ArrayRef yieldedValues) { + // The last yielded Value is the padding, the rest are source indices. + Value paddingValue = yieldedValues.back(); + ArrayRef loadIndices = yieldedValues.drop_back(); + + // Check bounds for each source dimension. Start with true so that + // for 0-D sources, inBounds is always true. 
+ Value inBounds = nestedBuilder.createOrFold( + nestedLoc, /*value=*/1, /*width=*/1); + Value zero = + nestedBuilder.createOrFold(nestedLoc, 0); + for (auto [dim, idx] : llvm::enumerate(loadIndices)) { + Value dimSize = + memref::DimOp::create(nestedBuilder, nestedLoc, getSource(), dim); + + // Check: idx >= 0 + Value geZero = arith::CmpIOp::create( + nestedBuilder, nestedLoc, arith::CmpIPredicate::sge, idx, zero); + // Check: idx < dimSize + Value ltDim = arith::CmpIOp::create( + nestedBuilder, nestedLoc, arith::CmpIPredicate::slt, idx, dimSize); + // Combine: idx >= 0 && idx < dimSize + Value dimInBounds = + arith::AndIOp::create(nestedBuilder, nestedLoc, geZero, ltDim); + + inBounds = arith::AndIOp::create(nestedBuilder, nestedLoc, inBounds, + dimInBounds); + } + + // Create if-else: if in bounds, load from source; else use padding. + // The if yields the value to store. + auto ifOp = scf::IfOp::create(nestedBuilder, nestedLoc, + TypeRange{paddingValue.getType()}, inBounds, + /*addThenBlock=*/true, /*addElseBlock=*/true); + { + auto thenBuilder = ifOp.getThenBodyBuilder(); + Value loaded = memref::LoadOp::create(thenBuilder, nestedLoc, getSource(), + loadIndices); + scf::YieldOp::create(thenBuilder, nestedLoc, loaded); + } + { + auto elseBuilder = ifOp.getElseBodyBuilder(); + scf::YieldOp::create(elseBuilder, nestedLoc, paddingValue); + } + memref::StoreOp::create(nestedBuilder, nestedLoc, ifOp.getResult(0), + getOutput(), ivs); + }; + inlineMapGatherBody(b, loc, ivs, bodyBuilder); + return success(); +} + //===----------------------------------------------------------------------===// // MapScatterOp //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir index 32c39875fa04..06ed7cd81d59 100644 --- 
a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir @@ -1703,3 +1703,49 @@ func.func @map_scatter_memref( // CHECK-NEXT: %[[INPUT_ELEM:.+]] = memref.load %[[INPUT]][%[[IV]]] // CHECK-NEXT: memref.store %[[INPUT_ELEM]], %[[OUTPUT]] // CHECK-SAME: [%[[OUT_IDX]]#0, %[[OUT_IDX]]#1] : memref + +// ----- + +func.func @map_gather_memref( + %source: memref, %output: memref +) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = memref.dim %source, %c0 : memref + %dim1 = memref.dim %source, %c1 : memref + iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + %src_idx:2 = affine.delinearize_index %idx0 into (%dim0, %dim1) : index, index + %pad = arith.constant 0.0 : f32 + iree_linalg_ext.yield %src_idx#0, %src_idx#1, %pad : index, index, f32 + } : memref into memref + return +} +// CHECK: func @map_gather_memref +// CHECK-SAME: %[[SOURCE:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.{{0+}}e+00 : f32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[SRC_D0:.+]] = memref.dim %[[SOURCE]], %[[C0]] +// CHECK-DAG: %[[SRC_D1:.+]] = memref.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: %[[OUT_D0:.+]] = memref.dim %[[OUTPUT]], %[[C0]] +// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[OUT_D0]] step %[[C1]] +// CHECK: %[[SRC_IDX:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[SRC_D0]], %[[SRC_D1]]) : index, index +// CHECK-DAG: %[[BOUND_D0:.+]] = memref.dim %[[SOURCE]], %[[C0]] +// CHECK-DAG: %[[GE_ZERO_0:.+]] = arith.cmpi sge, %[[SRC_IDX]]#0, %[[C0]] : index +// CHECK-DAG: %[[LT_DIM_0:.+]] = arith.cmpi slt, %[[SRC_IDX]]#0, %[[BOUND_D0]] : index +// CHECK-DAG: %[[IN_BOUNDS_0:.+]] = arith.andi %[[GE_ZERO_0]], %[[LT_DIM_0]] +// CHECK-DAG: %[[BOUND_D1:.+]] = memref.dim %[[SOURCE]], %[[C1]] +// CHECK-DAG: 
%[[GE_ZERO_1:.+]] = arith.cmpi sge, %[[SRC_IDX]]#1, %[[C0]] : index +// CHECK-DAG: %[[LT_DIM_1:.+]] = arith.cmpi slt, %[[SRC_IDX]]#1, %[[BOUND_D1]] : index +// CHECK-DAG: %[[IN_BOUNDS_1:.+]] = arith.andi %[[GE_ZERO_1]], %[[LT_DIM_1]] +// CHECK-DAG: %[[IN_BOUNDS:.+]] = arith.andi %[[IN_BOUNDS_0]], %[[IN_BOUNDS_1]] +// CHECK: %[[IF_RESULT:.+]] = scf.if %[[IN_BOUNDS]] -> (f32) { +// CHECK: %[[SOURCE_ELEM:.+]] = memref.load %[[SOURCE]] +// CHECK-SAME: [%[[SRC_IDX]]#0, %[[SRC_IDX]]#1] : memref +// CHECK: scf.yield %[[SOURCE_ELEM]] : f32 +// CHECK: } else { +// CHECK: scf.yield %[[PAD]] : f32 +// CHECK: } +// CHECK: memref.store %[[IF_RESULT]], %[[OUTPUT]][%[[IV]]] From c4b6725273d11089c7ddc171a8dae21e045c62f5 Mon Sep 17 00:00:00 2001 From: Zhewen Yu Date: Wed, 14 Jan 2026 09:35:07 +0000 Subject: [PATCH 33/71] [GPU] Account for scale operands in shared memory calculation (#23111) This PR fixes the shared memory calculation for scaled matmul operations by properly accounting for scale operands in the GPU heuristics. 
--------- Signed-off-by: Yu-Zhewen --- .../Codegen/Common/GPU/GPUHeuristics.cpp | 47 ++++++++++++++---- .../Codegen/Common/GPU/GPUHeuristics.h | 13 +++-- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 49 ++++++++++--------- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 27 ++++++---- .../ROCDL/config_tile_and_fuse_gfx950.mlir | 26 ++++++++++ .../compiler/Codegen/SPIRV/KernelConfig.cpp | 6 +-- 6 files changed, 120 insertions(+), 48 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp index 70a501317a3e..8f91170b2f37 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Remarks.h" #define DEBUG_TYPE "iree-codegen-gpu-heuristics" @@ -65,13 +66,28 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const GemmSize &gemmSize) { static int64_t calculateOperandsSharedMemoryUsedInBytes( const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth, + int64_t lhsScaleBitwidth = 0, int64_t rhsScaleBitwidth = 0, int64_t numRhs = 1) { int64_t tileM = schedule.getTotalMSize() * schedule.getTotalMTileSize() * schedule.getTotalMSubgroupCount(); int64_t tileN = schedule.getTotalNSize() * schedule.getTotalNTileSize() * schedule.getTotalNSubgroupCount(); + + // For scaled matmul, the K dimension is split into Ko (outer) and Kb (block), + // where elements in a Kb block share the same scale. For lhs and rhs we + // account for both Ko and Kb, while for scale operands, only Ko. For regular + // matmul, scale bitwidth is 0 so the scale terms below have no effect. 
int64_t tileK = schedule.getTotalKSize() * schedule.getTotalKTileSize(); - return (tileM * tileK * lhsBitwidth + numRhs * tileN * tileK * rhsBitwidth) / + int64_t tileKb = schedule.kSizes.back() * schedule.kTileSizes.back(); + int64_t tileKo = tileK / tileKb; + + int64_t lhsSharedMemoryUsed = tileM * tileK * lhsBitwidth; + int64_t rhsSharedMemoryUsed = numRhs * tileN * tileK * rhsBitwidth; + int64_t aScaleSharedMemoryUsed = tileM * tileKo * lhsScaleBitwidth; + int64_t bScaleSharedMemoryUsed = numRhs * tileN * tileKo * rhsScaleBitwidth; + + return (lhsSharedMemoryUsed + rhsSharedMemoryUsed + aScaleSharedMemoryUsed + + bScaleSharedMemoryUsed) / 8; } @@ -647,9 +663,9 @@ static int64_t adjustSeedsForWgpCount(const GPUMatmulShapeType &problem, FailureOr deduceMMASchedule( const GPUMatmulShapeType &problem, ArrayRef intrinsics, const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes, - int64_t subgroupSize, std::optional wgpCount, bool transposedLhs, - bool transposedRhs, bool canUpcastAcc, bool mustBeAligned, - bool doCPromotion, int64_t splitReductionTripCnt) { + int64_t subgroupSize, std::optional wgpCount, Location loc, + bool transposedLhs, bool transposedRhs, bool canUpcastAcc, + bool mustBeAligned, bool doCPromotion, int64_t splitReductionTripCnt) { SmallVector sortedIntrinsics = sortMMAIntrinsics(problem, intrinsics); @@ -673,14 +689,19 @@ FailureOr deduceMMASchedule( LDBG() << "Chosen MMA schedule:\n" << schedule; auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool { - int64_t lhsBitwidth = intrinsic.aType.getIntOrFloatBitWidth(); - int64_t rhsBitwidth = intrinsic.bType.getIntOrFloatBitWidth(); - int64_t resultBitwidth = intrinsic.cType.getIntOrFloatBitWidth(); + int64_t lhsBitwidth = problem.aType.getIntOrFloatBitWidth(); + int64_t rhsBitwidth = problem.bType.getIntOrFloatBitWidth(); + int64_t resultBitwidth = problem.cType.getIntOrFloatBitWidth(); + int64_t lhsScaleBitwidth = + problem.aScaleType ? 
problem.aScaleType.getIntOrFloatBitWidth() : 0; + int64_t rhsScaleBitwidth = + problem.bScaleType ? problem.bScaleType.getIntOrFloatBitWidth() : 0; bool isAligned = isValidMMASchedule(problem, schedule, mustBeAligned, subgroupSize, transposedLhs, transposedRhs); int64_t sharedMemoryUsed = calculateOperandsSharedMemoryUsedInBytes( - schedule, lhsBitwidth, rhsBitwidth, problem.numHorizontallyFusedOps); + schedule, lhsBitwidth, rhsBitwidth, lhsScaleBitwidth, + rhsScaleBitwidth, problem.numHorizontallyFusedOps); // Add accumulator/result memory when it uses shared memory (LDS): // - Result needs padding in shared memory, OR // - matmul_accumulate loads accumulator from global memory via shared mem @@ -694,7 +715,15 @@ FailureOr deduceMMASchedule( LDBG() << "Available Shared Memory: " << sharedMemLimitInBytes << " bytes" << "Predicted Shared Memory Used by Schedule: " << sharedMemoryUsed << " bytes"; - return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; + + bool isValid = isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; + if (isValid) { + // Only emit remark for the shared memory usage of the valid schedule. + remark::analysis(loc, remark::RemarkOpts::name("SharedMemoryUsage") + .category("deduceMMASchedule")) + << std::to_string(sharedMemoryUsed); + } + return isValid; }; return fitScheduleInSharedMemory(schedule, isValidSchedule); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h index d613fb4e8c6d..16902b4064b6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -20,9 +20,13 @@ struct GPUMatmulShapeType { SmallVector nSizes; SmallVector kSizes; SmallVector batchSizes; + Type aType; Type bType; Type cType; + Type aScaleType; + Type bScaleType; + GemmSize gemmSize = GemmSize::NotSet; // Number of horizontally fused operations. 
@@ -34,11 +38,14 @@ struct GPUMatmulShapeType { int64_t numHorizontallyFusedOps = 1) : mSizes({m}), nSizes({n}), kSizes({k}), batchSizes({}), aType(a), bType(b), cType(c), numHorizontallyFusedOps(numHorizontallyFusedOps) {} + GPUMatmulShapeType(ArrayRef m, ArrayRef n, ArrayRef k, ArrayRef batch, Type a, - Type b, Type c, int64_t numHorizontallyFusedOps = 1) + Type b, Type c, Type aScale = nullptr, + Type bScale = nullptr, int64_t numHorizontallyFusedOps = 1) : mSizes(m), nSizes(n), kSizes(k), batchSizes(batch), aType(a), bType(b), - cType(c), numHorizontallyFusedOps(numHorizontallyFusedOps) {} + cType(c), aScaleType(aScale), bScaleType(bScale), + numHorizontallyFusedOps(numHorizontallyFusedOps) {} }; /// Struct containing information about a GPU MMA intrinsic type. @@ -147,7 +154,7 @@ struct GPUMMASchedule { FailureOr deduceMMASchedule( const GPUMatmulShapeType &problem, ArrayRef intrinsics, const GPUMMAHeuristicSeeds &seeds, int64_t sharedMemLimitInBytes, - int64_t subgroupSize, std::optional cuCount, + int64_t subgroupSize, std::optional cuCount, Location loc, bool transposedLhs = false, bool transposedRhs = false, bool canUpcastAcc = false, bool mustBeAligned = true, bool doCPromotion = false, int64_t splitReductionTripCnt = 0); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 18ea5f946f75..555635a3e3dc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -338,7 +338,7 @@ getContractionHeuristicSeeds(GPUMatmulShapeType problem, bool isGemm, /// due to padding requirements or because the operation has an existing /// accumulator that needs to be loaded from global memory (matmul_accumulate). 
static std::optional getMmaScheduleFromProblemAndTarget( - IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, + IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, Location loc, bool transposedLhs, bool transposedRhs, bool isGemm, bool mustBeAligned = true, bool doCPromotion = false, bool scaled = false, int64_t splitReductionTripCnt = 0) { @@ -433,7 +433,7 @@ static std::optional getMmaScheduleFromProblemAndTarget( // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule( problem, intrinsics, seeds, maxSharedMemoryBytes, targetSubgroupSize, - wgpCount, transposedLhs, transposedRhs, /*canUpcastAcc=*/false, + wgpCount, loc, transposedLhs, transposedRhs, /*canUpcastAcc=*/false, /*mustBeAligned=*/mustBeAligned, doCPromotion, splitReductionTripCnt); return schedule; } @@ -774,24 +774,24 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( assert((operands.size() == 3 || scaled) && "expected 3 operands"); assert((operands.size() == 5 || !scaled) && "expected 5 operands"); - Value lhs = operands[0]; - Value rhs = operands[1]; - - Value init = operands[2]; + Type lhsElemType = getElementTypeOrSelf(operands[0]); + Type rhsElemType = getElementTypeOrSelf(operands[1]); + Type initElemType = getElementTypeOrSelf(operands[2]); + Type lhsScaleType; + Type rhsScaleType; if (scaled) { - init = operands[4]; assert(llvm::all_of(operands, [](Value a) { return isa(a.getType()); }) && "All operands must be a shaped type"); - assert(*getRank(lhs) > *getRank(operands[2]) && - *getRank(rhs) > *getRank(operands[3]) && + assert(*getRank(operands[0]) > *getRank(operands[2]) && + *getRank(operands[1]) > *getRank(operands[3]) && "Expected operand #0 (lhs) and operand #1 (rhs) to have a greater " "rank than their corresponding scales, operand #2 (lhs_scale) and " "operand #3 (rhs_scale)"); + lhsScaleType = getElementTypeOrSelf(operands[2]); + rhsScaleType = getElementTypeOrSelf(operands[3]); + initElemType = 
getElementTypeOrSelf(operands[4]); } - Type lhsElemType = getElementTypeOrSelf(lhs); - Type rhsElemType = getElementTypeOrSelf(rhs); - Type initElemType = getElementTypeOrSelf(init); // Intentionally padded GEMM proved to be beneficial for performance for // the following layouts: 1) [M, K] x [K, N] 2) [M, K] x [N, K] // Therefore we disallow padding only when LHS is transposed. @@ -801,7 +801,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( getDimBoundsNoPad(batchDims), lhsElemType, rhsElemType, - initElemType}; + initElemType, + lhsScaleType, + rhsScaleType}; // Accumulator needs shared memory if: // - Padding requires C promotion, OR @@ -810,8 +812,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( (couldNeedPadding && CPromoteIfPadding) || hasExistingAccumulator; bool mustBeAligned = true; + Location loc = operands[0].getLoc(); std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, isGemm, + target, problem, loc, transposedLhs, transposedRhs, isGemm, /*mustBeAligned=*/true, doCPromotion, scaled, splitReductionTripCnt); if (!schedule && canSupportUnaligned) { @@ -821,8 +824,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( // accumulator. bool doCPromotionUnaligned = CPromoteIfPadding || hasExistingAccumulator; schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, isGemm, mustBeAligned, - doCPromotionUnaligned, scaled, splitReductionTripCnt); + target, problem, loc, transposedLhs, transposedRhs, isGemm, + mustBeAligned, doCPromotionUnaligned, scaled, splitReductionTripCnt); } if (!schedule) { @@ -890,7 +893,7 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( // Attach the MMA schedule as an attribute to the entry point export function // for later access in the pipeline. 
- MLIRContext *context = lhs.getContext(); + MLIRContext *context = target.getContext(); Builder b(context); SmallVector attrs = { {"workgroup", b.getI64ArrayAttr(workgroupTileSizes)}, @@ -1873,17 +1876,17 @@ setDirectConvolutionLoweringConfig(IREE::GPU::TargetAttr target, bool transposedRhs = rhsKPos > nPos; bool mustBeAligned = true; std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, /*isGemm=*/false, - mustBeAligned, /*doCPromotion=*/false, /*scaled=*/false, - splitReductionTripCnt); + target, problem, linalgOp.getLoc(), transposedLhs, transposedRhs, + /*isGemm=*/false, mustBeAligned, /*doCPromotion=*/false, + /*scaled=*/false, splitReductionTripCnt); if (!schedule && canSupportUnaligned) { LDBG() << "Attempting to deduce unaligned TileAndFuse MMA schedule"; mustBeAligned = false; schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs, /*isGemm=*/false, - mustBeAligned, /*doCPromotion=*/false, /*scaled=*/false, - splitReductionTripCnt); + target, problem, linalgOp.getLoc(), transposedLhs, transposedRhs, + /*isGemm=*/false, mustBeAligned, /*doCPromotion=*/false, + /*scaled=*/false, splitReductionTripCnt); } if (!schedule) { LDBG() << "Failed to deduce TileAndFuse MMA schedule"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index aced4d9490f0..a9c25fd7d343 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -330,12 +330,12 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, // First try to find a schedule with an exactly matching intrinsic. 
FailureOr schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount); + targetSubgroupSize, wgpCount, op.getLoc()); if (failed(schedule)) { // Then try again by allowing upcasting accumulator. schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount, + targetSubgroupSize, wgpCount, op.getLoc(), /*transposedLhs*/ false, /*transposedRhs*/ false, /*canUpcastAcc=*/true); } @@ -515,9 +515,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // all instances of schedule->m/nSubgroupCounts[0], // schedule->m/n/kTileSizes[0] and schedule->m/n/kSizes[0] need to use the // full list of sizes instead of just the first element. - GPUMatmulShapeType problem{ - {bounds[mDim]}, {bounds[nDim]}, {bounds[kDim]}, getDimBounds(batchDims), - lhsElemType, rhsElemType, initElemType, numHorizontallyFusedOps}; + GPUMatmulShapeType problem{{bounds[mDim]}, + {bounds[nDim]}, + {bounds[kDim]}, + getDimBounds(batchDims), + lhsElemType, + rhsElemType, + initElemType, + /*aScaleType=*/nullptr, + /*bScaleType=*/nullptr, + numHorizontallyFusedOps}; // Helper fn to store mma information. auto storeMmaInfo = [](IREE::GPU::MmaInterfaceAttr mma, @@ -582,13 +589,13 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, - targetSubgroupSize, wgpCount); + targetSubgroupSize, wgpCount, op.getLoc()); if (!schedule) { // Then try again by allowing upcasting accumulator. 
- schedule = deduceMMASchedule(problem, intrinsics, seeds, - maxSharedMemoryBytes, targetSubgroupSize, - wgpCount, transposedLhs, transposedRhs, - /*canUpcastAcc=*/true); + schedule = + deduceMMASchedule(problem, intrinsics, seeds, maxSharedMemoryBytes, + targetSubgroupSize, wgpCount, op.getLoc(), + transposedLhs, transposedRhs, /*canUpcastAcc=*/true); } if (!schedule) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir index 7e9c14c0f678..824caf09bd88 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir @@ -3,6 +3,12 @@ // RUN: --iree-codegen-llvmgpu-use-igemm=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s +// RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx950 \ +// RUN: --iree-codegen-llvmgpu-use-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \ +// RUN: --iree-codegen-llvmgpu-use-igemm=false \ +// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \ +// RUN: --remarks-filter=".*" %s 2>&1 | FileCheck %s --check-prefix=CHECK-REMARKS + #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> #rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)> #scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)> @@ -35,6 +41,10 @@ func.func @scaled_matmul( // CHECK-SAME: subgroup = [4, 8, 0, 0] // CHECK-SAME: workgroup = [256, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=34816 + // ----- #lhs_map = affine_map<(B, M, N, Ko, Kb) -> (B, M, Ko, Kb)> @@ -70,6 +80,10 @@ func.func @scaled_matmul_with_batch( // CHECK-SAME: subgroup = [0, 4, 8, 0, 0] // CHECK-SAME: workgroup = 
[1, 256, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=34816 + // ----- #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> @@ -132,6 +146,10 @@ func.func @scaled_matmul_with_dynamic_batch( // CHECK-SAME: subgroup = [0, 4, 4, 0, 0] // CHECK-SAME: workgroup = [1, 128, 256, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=26112 + // ----- #lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)> @@ -166,6 +184,10 @@ func.func @small_scaled_matmul( // CHECK-SAME: subgroup = [1, 1, 0, 0] // CHECK-SAME: workgroup = [16, 16, 0, 0] +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=2176 + // ----- module { @@ -273,3 +295,7 @@ func.func @scaled_matmul_accumulate( // CHECK-SAME: reduction = [0, 0, 1, 1] // CHECK-SAME: subgroup = [2, 8, 0, 0] // CHECK: workgroup = [128, 256, 0, 0] + +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=157184 diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 3334ce7885ca..a4289924ff6f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -926,9 +926,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, bool transposedRhs = nIndex != cast(maps[1].getResults().back()).getPosition(); - FailureOr schedule = - deduceMMASchedule(problem, intrinsics, seeds, sharedMemoryLimitInBytes, - subgroupSize, transposedLhs, transposedRhs); + FailureOr schedule = deduceMMASchedule( + problem, intrinsics, seeds, sharedMemoryLimitInBytes, subgroupSize, + /*cuCount=*/std::nullopt, op.getLoc(), transposedLhs, transposedRhs); if 
(failed(schedule)) return failure(); assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension"); From d5854659a42e11474de0d11489f1d0b396e9facd Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 14 Jan 2026 08:09:20 -0800 Subject: [PATCH 34/71] Adding vm.optimization_barrier and on-demand ordinal analysis. (#23115) Moves the optimization barrier into the VM dialect so that VM modules are self-contained and don't need to pull in util dialect ops. A new VM DropOptimizationBarriers pass strips them prior to serialization. OrdinalAnalysis now computes function/global/rodata ordinals on-demand during bytecode emission instead of storing them as attributes on ops to keep the IR cleaner/prevent drift (avoids needing to keep attributes in sync across passes). EmitC has issues but (always has, it's been silently folding for ages). ci-extra: all --- build_tools/cmake/iree_c_module.cmake | 17 ++ .../bindings/c/iree/compiler/embedding_api.h | 3 + .../plugins/target/VMVX/test/smoketest.mlir | 3 + .../iree/compiler/API/Internal/BUILD.bazel | 1 + .../iree/compiler/API/Internal/CMakeLists.txt | 1 + .../compiler/API/Internal/CompilerDriver.cpp | 6 + .../Util/Transforms/DropCompilerHints.cpp | 4 + .../compiler/Dialect/VM/Analysis/BUILD.bazel | 15 + .../Dialect/VM/Analysis/CMakeLists.txt | 14 + .../Dialect/VM/Analysis/OrdinalAnalysis.cpp | 108 +++++++ .../Dialect/VM/Analysis/OrdinalAnalysis.h | 78 +++++ .../VM/Conversion/UtilToVM/Patterns.cpp | 17 ++ .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 6 + .../src/iree/compiler/Dialect/VM/IR/VMOps.cpp | 12 + .../src/iree/compiler/Dialect/VM/IR/VMOps.td | 43 +++ .../Dialect/VM/Target/Bytecode/BUILD.bazel | 4 +- .../VM/Target/Bytecode/BytecodeEncoder.cpp | 18 +- .../VM/Target/Bytecode/BytecodeEncoder.h | 5 +- .../Target/Bytecode/BytecodeModuleTarget.cpp | 109 ++----- .../VM/Target/Bytecode/BytecodeModuleTarget.h | 3 - .../Dialect/VM/Target/Bytecode/CMakeLists.txt | 4 +- .../Bytecode/test/constant_encoding.mlir | 35 +-- 
.../VM/Target/Bytecode/test/dependencies.mlir | 10 + .../Dialect/VM/Target/C/CModuleTarget.cpp | 19 +- .../VM/Target/C/test/control_flow.mlir | 5 +- .../Dialect/VM/Transforms/BUILD.bazel | 1 + .../Dialect/VM/Transforms/CMakeLists.txt | 1 + .../Dialect/VM/Transforms/Conversion.cpp | 2 + .../Transforms/DropOptimizationBarriers.cpp | 28 ++ .../VM/Transforms/MaterializeRefDiscards.cpp | 32 ++- .../compiler/Dialect/VM/Transforms/Passes.cpp | 10 + .../compiler/Dialect/VM/Transforms/Passes.td | 10 + .../test/materialize_ref_discards.mlir | 271 +++++++++++++++--- .../iree/compiler/Tools/iree_compile_lib.cc | 12 +- runtime/src/iree/vm/test/arithmetic_ops.mlir | 68 ++--- .../src/iree/vm/test/arithmetic_ops_f32.mlir | 78 ++--- .../src/iree/vm/test/arithmetic_ops_f64.mlir | 78 ++--- .../src/iree/vm/test/arithmetic_ops_i64.mlir | 68 ++--- runtime/src/iree/vm/test/assignment_ops.mlir | 12 +- .../src/iree/vm/test/assignment_ops_f32.mlir | 10 +- .../src/iree/vm/test/assignment_ops_f64.mlir | 10 +- .../src/iree/vm/test/assignment_ops_i64.mlir | 10 +- runtime/src/iree/vm/test/async_ops.mlir | 28 +- runtime/src/iree/vm/test/buffer_ops.mlir | 54 ++-- runtime/src/iree/vm/test/call_ops.mlir | 12 +- runtime/src/iree/vm/test/comparison_ops.mlir | 36 +-- .../src/iree/vm/test/comparison_ops_f32.mlir | 28 +- .../src/iree/vm/test/comparison_ops_f64.mlir | 28 +- .../src/iree/vm/test/comparison_ops_i64.mlir | 36 +-- .../src/iree/vm/test/control_flow_ops.mlir | 20 +- runtime/src/iree/vm/test/conversion_ops.mlir | 12 +- .../src/iree/vm/test/conversion_ops_f32.mlir | 46 +-- .../src/iree/vm/test/conversion_ops_f64.mlir | 18 +- .../src/iree/vm/test/conversion_ops_i64.mlir | 2 +- runtime/src/iree/vm/test/global_ops.mlir | 2 +- runtime/src/iree/vm/test/list_ops.mlir | 4 +- .../src/iree/vm/test/list_variant_ops.mlir | 2 +- runtime/src/iree/vm/test/ref_ops.mlir | 128 ++++----- runtime/src/iree/vm/test/shift_ops.mlir | 6 +- runtime/src/iree/vm/test/shift_ops_i64.mlir | 6 +- 
tests/compiler_driver/streams.mlir | 4 +- tools/iree-dump-module-main.c | 40 ++- tools/test/iree-dump-module.mlir | 12 +- 63 files changed, 1185 insertions(+), 580 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h create mode 100644 compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp diff --git a/build_tools/cmake/iree_c_module.cmake b/build_tools/cmake/iree_c_module.cmake index 3e4ad104e59d..11359824cc57 100644 --- a/build_tools/cmake/iree_c_module.cmake +++ b/build_tools/cmake/iree_c_module.cmake @@ -89,6 +89,14 @@ function(iree_c_module) DEPENDS ${_COMPILE_TOOL} ${_SRC_PATH} ) + # Generated EmitC code may have unused variables from optimization + # barriers and other cases where an SSA value is consumed by an op + # that produces a new value. Suppress this warning for Clang/GCC. + iree_select_compiler_opts(_EMITC_SUPPRESS_OPTS + CLANG_OR_GCC + "-Wno-unused-but-set-variable" + ) + iree_cc_library( NAME ${_RULE_NAME} HDRS "${_RULE_H_FILE_OUTPUT}" @@ -96,12 +104,21 @@ function(iree_c_module) INCLUDES "${CMAKE_CURRENT_BINARY_DIR}" COPTS "-DEMITC_IMPLEMENTATION=\"${_RULE_H_FILE_OUTPUT}\"" + ${_EMITC_SUPPRESS_OPTS} "${_TESTONLY_ARG}" DEPS # Include paths and options for the runtime sources. iree_defs ) + # Apply warning suppression to consumers (tests, etc.) that include the + # generated headers. 
+ iree_package_name(_PACKAGE_NAME) + set(_TARGET "${_PACKAGE_NAME}_${_RULE_NAME}") + if(_EMITC_SUPPRESS_OPTS) + target_compile_options(${_TARGET} INTERFACE ${_EMITC_SUPPRESS_OPTS}) + endif() + if(_RULE_NO_RUNTIME) return() endif() diff --git a/compiler/bindings/c/iree/compiler/embedding_api.h b/compiler/bindings/c/iree/compiler/embedding_api.h index 6da0379e9afd..eaf873dcf35e 100644 --- a/compiler/bindings/c/iree/compiler/embedding_api.h +++ b/compiler/bindings/c/iree/compiler/embedding_api.h @@ -268,6 +268,9 @@ enum iree_compiler_pipeline_t { // This is experimental and this should be changed as we move to a more // cohesive approach for managing compilation phases. IREE_COMPILER_PIPELINE_PRECOMPILE = 2, + // VM transformation pipeline only. Converts from input dialects to the VM + // dialect without serialization. + IREE_COMPILER_PIPELINE_VM = 3, }; IREE_EMBED_EXPORTED bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv, diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir index c4217e983a05..0b2834da1cda 100644 --- a/compiler/plugins/target/VMVX/test/smoketest.mlir +++ b/compiler/plugins/target/VMVX/test/smoketest.mlir @@ -52,9 +52,11 @@ stream.executable public @add_dispatch_0 { // CHECK-DAG: %[[C1_I32:.+]] = vm.const.i32 1 // CHECK-DAG: %[[C1_I64:.+]] = vm.const.i64 1 // CHECK-DAG: %[[C2_I32:.+]] = vm.const.i32 2 +// CHECK: vm.discard.refs %[[SCRATCHPAD]], %[[CONSTANTS]] // CHECK-NEXT: %[[LHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C0_I32]] : (!vm.list, i32) -> !vm.buffer // CHECK-NEXT: %[[RHS_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C1_I32]] : (!vm.list, i32) -> !vm.buffer // CHECK-NEXT: %[[RET_BUF:.+]] = vm.list.get.ref %[[BINDINGS]], %[[C2_I32]] : (!vm.list, i32) -> !vm.buffer +// CHECK-NEXT: vm.discard.refs %[[BINDINGS]] // CHECK: vm.br ^bb1(%[[C0_I64]] : i64) // CHECK-NEXT: ^bb1(%[[IDX:.+]]: i64): // CHECK-NEXT: %slt = vm.cmp.lt.i64.s %[[IDX]], %{{.+}} : i64 @@ -68,6 
+70,7 @@ stream.executable public @add_dispatch_0 { // CHECK-NEXT: %[[NEXT_IDX:.+]] = vm.add.i64 %[[IDX]], %[[C1_I64]] : i64 // CHECK-NEXT: vm.br ^bb1(%[[NEXT_IDX]] : i64) // CHECK-NEXT: ^bb3: +// CHECK: vm.discard.refs %[[LHS_BUF]], %[[RHS_BUF]], %[[RET_BUF]] // CHECK-NEXT: vm.return // CHECK-NEXT: } // CHECK-NEXT: vm.export @add_dispatch_0 diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index 11db7604420b..4d801083ab61 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -27,6 +27,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets", "//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode", "//compiler/src/iree/compiler/Dialect/VM/Target/C", + "//compiler/src/iree/compiler/Dialect/VM/Transforms", "//compiler/src/iree/compiler/Pipelines", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/PluginAPI:PluginManager", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index b68b35a226a6..18272b115a91 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -34,6 +34,7 @@ iree_cc_library( iree::compiler::Dialect::VM::Target::Bytecode iree::compiler::Dialect::VM::Target::C iree::compiler::Dialect::VM::Target::init_targets + iree::compiler::Dialect::VM::Transforms iree::compiler::Pipelines iree::compiler::PluginAPI iree::compiler::PluginAPI::PluginManager diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index ddade384115e..f4cea928b8c4 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -43,6 +43,7 @@ #include "iree/compiler/API/Internal/Diagnostics.h" 
#include "iree/compiler/ConstEval/Passes.h" #include "iree/compiler/Dialect/VM/Target/init_targets.h" +#include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "iree/compiler/Pipelines/Pipelines.h" #include "iree/compiler/PluginAPI/PluginManager.h" #include "iree/compiler/Tools/init_dialects.h" @@ -1053,6 +1054,11 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { *passManager, compileFrom, compileTo); break; } + case IREE_COMPILER_PIPELINE_VM: { + IREE::VM::buildVMTransformPassPipeline(*passManager, + session.vmTargetOptions); + break; + } default: parsedModule->emitError() << "unsupported pipeline type " << (int)pipeline; return false; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index 343e2b29a755..c7d64381bad3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -29,6 +29,10 @@ struct DropCompilerHintsPass op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto op = dyn_cast(genericOp)) { + // TODO(benvanik): #19348 was a terrible approach and this needs to be + // undone. If LLVMGPU wants to keep the hints it should have its own + // codegen op that carries the information. DropCompilerHints is meant + // to drop all compiler hints. 
if (keepAssumeInt) return; op.replaceAllUsesWith(op.getOperands()); diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel index aab9e7200cb6..fbf14d3a3edb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/BUILD.bazel @@ -55,3 +55,18 @@ iree_compiler_cc_library( "@llvm-project//mlir:Support", ], ) + +iree_compiler_cc_library( + name = "OrdinalAnalysis", + srcs = [ + "OrdinalAnalysis.cpp", + ], + hdrs = [ + "OrdinalAnalysis.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/Util/IR", + "//compiler/src/iree/compiler/Dialect/VM/IR", + "@llvm-project//llvm:Support", + ], +) diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt index d0a857a822a1..bec12e5727a3 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/CMakeLists.txt @@ -52,4 +52,18 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + OrdinalAnalysis + HDRS + "OrdinalAnalysis.h" + SRCS + "OrdinalAnalysis.cpp" + DEPS + LLVMSupport + iree::compiler::Dialect::Util::IR + iree::compiler::Dialect::VM::IR + PUBLIC +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp new file mode 100644 index 000000000000..8cbb34d1b8d0 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp @@ -0,0 +1,108 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::iree_compiler::IREE::VM { + +// Returns the size in bytes of the global when stored in memory. +// Valid only for globals using primitive storage. +static size_t getGlobalStorageSize(IREE::Util::GlobalOpInterface globalOp) { + auto storageType = globalOp.getGlobalType(); + assert(storageType.isIntOrFloat()); + assert(storageType.getIntOrFloatBitWidth() % 8 == 0); + return IREE::Util::getRoundedElementByteWidth(storageType); +} + +OrdinalAnalysis::OrdinalAnalysis(IREE::VM::ModuleOp moduleOp) { + // Assign ordinals based on IR order (which should be deterministic). + int nextFuncOrdinal = 0; + int nextImportOrdinal = 0; + int nextExportOrdinal = 0; + int nextGlobalRefOrdinal = 0; + int nextRodataOrdinal = 0; + + // Bucket the primitive global ops by byte size for alignment packing. + SmallVector, 8> primitiveGlobalOps( + sizeof(int64_t) + 1); + + for (auto &op : moduleOp.getBlock().getOperations()) { + if (auto funcOp = dyn_cast(op)) { + ordinals_[&op] = nextFuncOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextExportOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextImportOrdinal++; + } else if (isa(op)) { + ordinals_[&op] = nextRodataOrdinal++; + } else if (auto globalOp = dyn_cast(op)) { + if (isa(globalOp.getGlobalType())) { + ordinals_[&op] = nextGlobalRefOrdinal++; + } else { + // Bucket the primitive global ops by byte size for alignment packing. + size_t storageSize = getGlobalStorageSize(globalOp); + primitiveGlobalOps[storageSize].push_back(globalOp); + } + } + } + + // Assign byte offset values to primitive globals, ensuring that we meet + // natural alignment requirements on each size type. 
+ int nextGlobalBytesOrdinal = 0; + int globalBytes = 0; + for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) { + size_t storageSize = sizeGlobalOps.index(); + if (sizeGlobalOps.value().empty()) + continue; + nextGlobalBytesOrdinal = llvm::alignTo(nextGlobalBytesOrdinal, storageSize); + for (auto &globalOp : sizeGlobalOps.value()) { + ordinals_[globalOp] = nextGlobalBytesOrdinal; + nextGlobalBytesOrdinal += storageSize; + globalBytes = std::max(globalBytes, nextGlobalBytesOrdinal); + } + } + + // Record counts. + counts_.importFuncs = nextImportOrdinal; + counts_.exportFuncs = nextExportOrdinal; + counts_.internalFuncs = nextFuncOrdinal; + counts_.globalBytes = globalBytes; + counts_.globalRefs = nextGlobalRefOrdinal; + counts_.rodatas = nextRodataOrdinal; + counts_.rwdatas = 0; +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::FuncOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::ExportOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::ImportOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(IREE::VM::RodataOp op) const { + return getOrdinal(op.getOperation()); +} + +int64_t +OrdinalAnalysis::getGlobalOrdinal(IREE::Util::GlobalOpInterface op) const { + return getOrdinal(op.getOperation()); +} + +int64_t OrdinalAnalysis::getOrdinal(Operation *op) const { + auto it = ordinals_.find(op); + assert(it != ordinals_.end() && "ordinal not found for operation"); + return it->second; +} + +} // namespace mlir::iree_compiler::IREE::VM diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h new file mode 100644 index 000000000000..67e2580d5e91 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h @@ -0,0 +1,78 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ +#define IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ + +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir::iree_compiler::IREE::VM { + +// Computes ordinal assignments for module-level symbols. +// +// Each ordinal is unique per-category and ordinals are contiguous starting +// from zero. Categories include: +// - Internal functions (vm.func) +// - Import functions (vm.import) +// - Export functions (vm.export) +// - Rodata segments (vm.rodata) +// - Global refs (vm.global.ref) +// - Global bytes (byte offset for primitive globals) +// +// This analysis is computed on-demand when ordinals are needed for +// serialization, avoiding the need to store ordinals as attributes on ops. +class OrdinalAnalysis { +public: + // Summary counts of module-level symbols. + struct OrdinalCounts { + int32_t importFuncs = 0; + int32_t exportFuncs = 0; + int32_t internalFuncs = 0; + int32_t globalBytes = 0; + int32_t globalRefs = 0; + int32_t rodatas = 0; + int32_t rwdatas = 0; // Currently unused, reserved. + }; + + OrdinalAnalysis() = default; + explicit OrdinalAnalysis(IREE::VM::ModuleOp moduleOp); + + OrdinalAnalysis(OrdinalAnalysis &&) = default; + OrdinalAnalysis &operator=(OrdinalAnalysis &&) = default; + OrdinalAnalysis(const OrdinalAnalysis &) = delete; + OrdinalAnalysis &operator=(const OrdinalAnalysis &) = delete; + + // Returns the ordinal counts for the module. + const OrdinalCounts &getCounts() const { return counts_; } + + // Returns the ordinal for a vm.func op. + int64_t getOrdinal(IREE::VM::FuncOp op) const; + + // Returns the ordinal for a vm.export op. + int64_t getOrdinal(IREE::VM::ExportOp op) const; + + // Returns the ordinal for a vm.import op. 
+ int64_t getOrdinal(IREE::VM::ImportOp op) const; + + // Returns the ordinal for a vm.rodata op. + int64_t getOrdinal(IREE::VM::RodataOp op) const; + + // Returns the byte offset ordinal for a primitive global. + // Returns -1 if the global is a ref type. + int64_t getGlobalOrdinal(IREE::Util::GlobalOpInterface op) const; + + // Generic ordinal lookup for any operation with an ordinal. + int64_t getOrdinal(Operation *op) const; + +private: + OrdinalCounts counts_; + llvm::DenseMap ordinals_; +}; + +} // namespace mlir::iree_compiler::IREE::VM + +#endif // IREE_COMPILER_DIALECT_VM_ANALYSIS_ORDINALANALYSIS_H_ diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp index 27f9a80987fc..13b2ba2687e5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp @@ -103,6 +103,22 @@ struct CmpNEOpConversion : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// util.optimization_barrier +//===----------------------------------------------------------------------===// + +struct OptimizationBarrierOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::Util::OptimizationBarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands()); + return success(); + } +}; + } // namespace void populateUtilToVMPatterns(MLIRContext *context, @@ -113,6 +129,7 @@ void populateUtilToVMPatterns(MLIRContext *context, patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); populateUtilAlignmentToVMPatterns(context, conversionTarget, typeConverter, patterns); 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 4a1c0b6f3dea..109b0ef21e1b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -5262,6 +5262,12 @@ class ConvertVMToEmitCPass void runOnOperation() override { IREE::VM::ModuleOp moduleOp = getOperation(); + // Erase vm.discard.refs ops before analysis. These are inserted by + // MaterializeRefDiscardsPass for the bytecode backend but are not used + // by EmitC. Erasing them here avoids inflating register pressure during + // the register allocation analysis. + moduleOp.walk([](IREE::VM::DiscardRefsOp op) { op.erase(); }); + ConversionTarget target(getContext()); EmitCTypeConverter typeConverter(moduleOp); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 9e466acb63f0..c6798a85fdc5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -1742,6 +1742,18 @@ SuccessorOperands CondBreakOp::getSuccessorOperands(unsigned index) { return SuccessorOperands(getDestOperandsMutable()); } +//===----------------------------------------------------------------------===// +// vm.optimization_barrier +//===----------------------------------------------------------------------===// + +void OptimizationBarrierOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, + ArrayRef attributes) { + state.addOperands(operands); + state.addTypes(llvm::to_vector(operands.getTypes())); + state.addAttributes(attributes); +} + } // namespace mlir::iree_compiler::IREE::VM //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td 
b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index bf36457bbdd6..2e6a99e5df73 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -5344,4 +5344,47 @@ def VM_CondBreakOp : VM_Op<"cond_break", [ } // OpGroupDebuggingOps +//===----------------------------------------------------------------------===// +// Compiler hints +//===----------------------------------------------------------------------===// + +def OpGroupCompilerHintOps : OpDocGroup { + let summary = "Compiler hint ops"; + let description = ""; +} + +let opDocGroup = OpGroupCompilerHintOps in { + +def VM_OptimizationBarrierOp : VM_Op<"optimization_barrier", [ + VM_PseudoOp, + AllTypesMatch<["operands", "results"]>, + ]> { + let summary = [{Prevents compiler optimizations across a value.}]; + let description = [{ + Wraps any operands in an unoptimizable identity to prevent its results from + being folded. It will be dropped during the final step in compilation and + has no effect at runtime. + }]; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs + Variadic:$results + ); + + let assemblyFormat = [{ + attr-dict ($operands^ `:` type($operands))? 
+ }]; + + let builders = [ + OpBuilder<(ins + "ValueRange":$operands, + CArg<"ArrayRef", "{}">:$attributes + )>, + ]; +} + +} // OpGroupCompilerHintOps + #endif // IREE_DIALECT_VM_OPS diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel index 4bf8864c5e00..30df7a9c5403 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BUILD.bazel @@ -29,19 +29,17 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/Util/IR", - "//compiler/src/iree/compiler/Dialect/Util/Transforms", "//compiler/src/iree/compiler/Dialect/VM/Analysis", + "//compiler/src/iree/compiler/Dialect/VM/Analysis:OrdinalAnalysis", "//compiler/src/iree/compiler/Dialect/VM/Analysis:ValueLiveness", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Dialect/VM/IR", - "//compiler/src/iree/compiler/Dialect/VM/Transforms", "//compiler/src/iree/compiler/Dialect/VM/Utils:CallingConvention", "//compiler/src/iree/compiler/Dialect/VM/Utils:TypeTable", "//compiler/src/iree/compiler/Utils", "//runtime/src/iree/schemas:bytecode_module_def_c_fbs", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index ad2958f544f4..070329e4c377 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp @@ -24,8 +24,10 @@ namespace { class V0BytecodeEncoder : public BytecodeEncoder { public: V0BytecodeEncoder(llvm::DenseMap *typeTable, - RegisterAllocation *registerAllocation) - : typeTable_(typeTable), 
registerAllocation_(registerAllocation) {} + RegisterAllocation *registerAllocation, + const OrdinalAnalysis *ordinalAnalysis) + : typeTable_(typeTable), registerAllocation_(registerAllocation), + ordinalAnalysis_(ordinalAnalysis) {} ~V0BytecodeEncoder() = default; LogicalResult beginBlock(Block *block) override { @@ -59,11 +61,7 @@ class V0BytecodeEncoder : public BytecodeEncoder { if (!symbolOp) { return currentOp_->emitOpError() << "target symbol not found: " << name; } - auto ordinalAttr = symbolOp->getAttrOfType("ordinal"); - if (!ordinalAttr) { - return symbolOp->emitOpError() << "missing ordinal"; - } - int32_t ordinal = ordinalAttr.getInt(); + int32_t ordinal = ordinalAnalysis_->getOrdinal(symbolOp); if (isa(symbolOp)) { // Imported functions have their MSB set. ordinal |= 0x80000000u; @@ -387,6 +385,7 @@ class V0BytecodeEncoder : public BytecodeEncoder { llvm::DenseMap *typeTable_; RegisterAllocation *registerAllocation_; + const OrdinalAnalysis *ordinalAnalysis_; Operation *currentOp_ = nullptr; @@ -400,7 +399,8 @@ class V0BytecodeEncoder : public BytecodeEncoder { // static std::optional BytecodeEncoder::encodeFunction( IREE::VM::FuncOp funcOp, llvm::DenseMap &typeTable, - SymbolTable &symbolTable, DebugDatabaseBuilder &debugDatabase) { + SymbolTable &symbolTable, const OrdinalAnalysis &ordinalAnalysis, + DebugDatabaseBuilder &debugDatabase) { EncodedBytecodeFunction result; // Perform register allocation first so that we can quickly lookup values as @@ -414,7 +414,7 @@ std::optional BytecodeEncoder::encodeFunction( FunctionSourceMap sourceMap; sourceMap.localName = funcOp.getName().str(); - V0BytecodeEncoder encoder(&typeTable, ®isterAllocation); + V0BytecodeEncoder encoder(&typeTable, ®isterAllocation, &ordinalAnalysis); for (auto &block : funcOp.getBlocks()) { size_t blockStart = encoder.getOffset(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h 
b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h index 79ba38f22fa9..e29f1fbf36c5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_BYTECODEENCODER_H_ #define IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_BYTECODEENCODER_H_ +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" #include "iree/compiler/Dialect/VM/IR/VMFuncEncoder.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.h" @@ -43,7 +44,9 @@ class BytecodeEncoder : public VMFuncEncoder { // Returns None on failure. static std::optional encodeFunction(IREE::VM::FuncOp funcOp, llvm::DenseMap &typeTable, - SymbolTable &symbolTable, DebugDatabaseBuilder &debugDatabase); + SymbolTable &symbolTable, + const OrdinalAnalysis &ordinalAnalysis, + DebugDatabaseBuilder &debugDatabase); BytecodeEncoder() = default; ~BytecodeEncoder() = default; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index e1bbef46d064..8bba90c5146d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -11,14 +11,13 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" -#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.h" #include "iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h" #include "iree/compiler/Dialect/VM/Analysis/ValueLiveness.h" #include "iree/compiler/Dialect/VM/IR/VMDialect.h" #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include 
"iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.h" #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h" -#include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "iree/compiler/Dialect/VM/Utils/CallingConvention.h" #include "iree/compiler/Dialect/VM/Utils/TypeTable.h" #include "iree/compiler/Utils/FlatbufferUtils.h" @@ -33,12 +32,9 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LocationSnapshot.h" -#include "mlir/Transforms/Passes.h" IREE_DEFINE_COMPILER_OPTION_FLAGS( mlir::iree_compiler::IREE::VM::BytecodeTargetOptions); @@ -125,15 +121,16 @@ serializeEmbeddedData(Location loc, Attribute valueAttr, uint64_t alignment, } // Canonicalizes the module to its final form prior to emission. -// This verifies that we only have ops we can serialize and performs any of the -// required transformations (such as debug op stripping). +// This verifies that we only have ops we can serialize and removes any +// pseudo-ops and debug ops (when stripping is enabled). +// All transformation passes should have run in the main VM transformation +// pipeline before this is called. static LogicalResult canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, IREE::VM::ModuleOp moduleOp) { RewritePatternSet patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); - target.addLegalOp(); // Add all VM canonicalization patterns and mark pseudo-ops illegal. auto *context = moduleOp.getContext(); @@ -145,7 +142,6 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, } // Debug ops must not be present when stripping. - // TODO(benvanik): add RemoveDisabledDebugOp pattern. 
if (op.hasTrait() && bytecodeOptions.stripDebugOps) { target.setOpAction(op, ConversionTarget::LegalizationAction::Illegal); @@ -156,48 +152,6 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, return moduleOp.emitError() << "unable to fully apply conversion to module"; } - PassManager passManager(context); - // TODO(12938): Handle or investigate failure result. - auto logicalRes = mlir::applyPassManagerCLOptions(passManager); - (void)logicalRes; - mlir::applyDefaultTimingPassManagerCLOptions(passManager); - passManager.addInstrumentation(std::make_unique()); - auto &modulePasses = passManager.nest(); - - // TODO(benvanik): these ideally happen beforehand but when performing - // serialization the input IR often has some of these low-level VM ops. In - // real workflows these have already run earlier and are no-ops. - modulePasses.addPass(IREE::VM::createGlobalInitializationPass()); - modulePasses.addPass(IREE::VM::createDropEmptyModuleInitializersPass()); - - if (bytecodeOptions.optimize) { - // TODO(benvanik): run this as part of a fixed-point iteration. - modulePasses.addPass(mlir::createInlinerPass()); - modulePasses.addPass(mlir::createCSEPass()); - // TODO(benvanik): re-evaluate whether this canonicalizer pass should exist - // in the bytecode target. It may be removing ops (like vm.discard.refs) - // that were intentionally inserted by earlier passes. - modulePasses.addPass(mlir::createCanonicalizerPass()); - } - - modulePasses.addPass(IREE::Util::createDropCompilerHintsPass()); - - // Insert explicit discard ops for ref values at their last use points. - // Uses edge-based placement: refs dying on control flow edges get discards - // inserted on those edges, refs dying mid-block get discards after last use. - modulePasses.addPass(IREE::VM::createMaterializeRefDiscardsPass()); - - // Mark up the module with ordinals for each top-level op (func, etc). 
- // This will make it easier to correlate the MLIR textual output to the - // binary output. - // We don't want any more modifications after this point as they could - // invalidate the ordinals. - modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); - - if (failed(passManager.run(moduleOp->getParentOfType()))) { - return moduleOp.emitError() << "failed during transform passes"; - } - return success(); } @@ -317,12 +271,11 @@ static iree_vm_FeatureBits_enum_t findRequiredFeatures(Operation *rootOp) { // has been packed into the top-level table. This results in a messier function // here during serialization but a much more trivial (and cache-friendly) // representation at runtime. -static LogicalResult -buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, - IREE::VM::BytecodeTargetOptions bytecodeOptions, - IREE::VM::ModuleOp moduleOp, - MutableArrayRef rodataRefs, - FlatbufferBuilder &fbb) { +static LogicalResult buildFlatBufferModule( + IREE::VM::TargetOptions vmOptions, + IREE::VM::BytecodeTargetOptions bytecodeOptions, + IREE::VM::ModuleOp moduleOp, const OrdinalAnalysis &ordinalAnalysis, + MutableArrayRef rodataRefs, FlatbufferBuilder &fbb) { // Start the buffer so that we can begin recording data prior to the root // table (which we do at the very end). This does not change the layout of the // file and is only used to prime the flatcc builder. @@ -334,26 +287,22 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, DebugDatabaseBuilder debugDatabase; SymbolTable symbolTable(moduleOp); - OrdinalCountsAttr ordinalCounts = moduleOp.getOrdinalCountsAttr(); - if (!ordinalCounts) { - return moduleOp.emitError() << "ordinal_counts attribute not found. The " - "OrdinalAllocationPass must be run before."; - } + const auto &ordinalCounts = ordinalAnalysis.getCounts(); // Find all structural ops in the module. 
std::vector importFuncOps; std::vector exportFuncOps; std::vector internalFuncOps; - importFuncOps.resize(ordinalCounts.getImportFuncs()); - exportFuncOps.resize(ordinalCounts.getExportFuncs()); - internalFuncOps.resize(ordinalCounts.getInternalFuncs()); + importFuncOps.resize(ordinalCounts.importFuncs); + exportFuncOps.resize(ordinalCounts.exportFuncs); + internalFuncOps.resize(ordinalCounts.internalFuncs); for (auto &op : moduleOp.getBlock().getOperations()) { if (auto funcOp = dyn_cast(op)) { - internalFuncOps[funcOp.getOrdinal()->getLimitedValue()] = funcOp; + internalFuncOps[ordinalAnalysis.getOrdinal(funcOp)] = funcOp; } else if (auto exportOp = dyn_cast(op)) { - exportFuncOps[exportOp.getOrdinal()->getLimitedValue()] = exportOp; + exportFuncOps[ordinalAnalysis.getOrdinal(exportOp)] = exportOp; } else if (auto importOp = dyn_cast(op)) { - importFuncOps[importOp.getOrdinal()->getLimitedValue()] = importOp; + importFuncOps[ordinalAnalysis.getOrdinal(importOp)] = importOp; if (!importOp.getName().contains('.')) { return importOp.emitOpError("must reference a function in a module " "(@module_name.func_name); got unscoped `@") @@ -380,7 +329,7 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, size_t totalBytecodeLength = 0; for (auto [i, funcOp] : llvm::enumerate(internalFuncOps)) { auto encodedFunction = BytecodeEncoder::encodeFunction( - funcOp, typeOrdinalMap, symbolTable, debugDatabase); + funcOp, typeOrdinalMap, symbolTable, ordinalAnalysis, debugDatabase); if (!encodedFunction) { return funcOp.emitError() << "failed to encode function bytecode"; } @@ -466,7 +415,7 @@ buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, iree_vm_ExportFunctionDef_start(fbb); iree_vm_ExportFunctionDef_local_name_add(fbb, localNameRef); iree_vm_ExportFunctionDef_internal_ordinal_add( - fbb, funcOp.getOrdinal()->getLimitedValue()); + fbb, ordinalAnalysis.getOrdinal(funcOp)); return iree_vm_ExportFunctionDef_end(fbb); }); @@ -530,8 +479,8 @@ 
buildFlatBufferModule(IREE::VM::TargetOptions vmOptions, auto dependenciesRef = fbb.createOffsetVecDestructive(dependencyRefs); auto typesRef = fbb.createOffsetVecDestructive(typeRefs); - int32_t globalRefs = ordinalCounts.getGlobalRefs(); - int32_t globalBytes = ordinalCounts.getGlobalBytes(); + int32_t globalRefs = ordinalCounts.globalRefs; + int32_t globalBytes = ordinalCounts.globalBytes; iree_vm_ModuleStateDef_ref_t moduleStateDef = 0; if (globalBytes || globalRefs) { @@ -656,15 +605,18 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, assert(false && "unhandled output format combination"); } + // Compute ordinals for all module-level symbols. + OrdinalAnalysis ordinalAnalysis(moduleOp); + // Declare all rodata entries we want to end up as external data first. This // allows us to compute offsets if needed without having had to perform // serialization yet. Note that not all rodata ends up as external data: if // it's small (like strings) we can avoid the extra seeks and keep it more // local by embedding it in the FlatBuffer. std::vector rodataOps; - rodataOps.resize(moduleOp.getOrdinalCountsAttr().getRodatas()); + rodataOps.resize(ordinalAnalysis.getCounts().rodatas); for (auto rodataOp : moduleOp.getOps()) { - rodataOps[rodataOp.getOrdinal()->getLimitedValue()] = rodataOp; + rodataOps[ordinalAnalysis.getOrdinal(rodataOp)] = rodataOp; } SmallVector rodataRefs; rodataRefs.resize(rodataOps.size()); @@ -699,7 +651,7 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, llvm::endianness::little, os); }); } - rodataRefs[rodataOp.getOrdinal()->getLimitedValue()] = rodataRef; + rodataRefs[ordinalAnalysis.getOrdinal(rodataOp)] = rodataRef; } // NOTE: we order things so that all of the metadata is close to the start of @@ -708,7 +660,7 @@ translateModuleToBytecode(IREE::VM::ModuleOp moduleOp, // can be large bulk data. 
FlatbufferBuilder fbb; if (failed(buildFlatBufferModule(vmOptions, bytecodeOptions, moduleOp, - rodataRefs, fbb))) { + ordinalAnalysis, rodataRefs, fbb))) { return failure(); } if (failed(archiveWriter->flush(fbb))) { @@ -751,11 +703,6 @@ void BytecodeTargetOptions::bindOptions(OptionsBinder &binder) { clEnumValN(BytecodeOutputFormat::kAnnotatedMlirText, "annotated-mlir-text", "MLIR module file in the VM dialect with annotations"))); - binder.opt( - "iree-vm-bytecode-module-optimize", optimize, - llvm::cl::cat(vmBytecodeOptionsCategory), - llvm::cl::desc("Optimizes the VM module with CSE/inlining/etc prior to " - "serialization")); binder.opt( "iree-vm-bytecode-source-listing", sourceListing, llvm::cl::cat(vmBytecodeOptionsCategory), diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h index 7f7e8788ed93..e2137ad49143 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h @@ -35,9 +35,6 @@ struct BytecodeTargetOptions { // Format of the module written to the output stream. BytecodeOutputFormat outputFormat = BytecodeOutputFormat::kFlatBufferBinary; - // Run basic CSE/inlining/etc passes prior to serialization. - bool optimize = true; - // Dump a VM MLIR file and annotate source locations with it. // This allows for the runtime to serve stack traces referencing both the // original source locations and the VM IR. 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt index c2c94c4a327a..cabe350656e9 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt @@ -27,18 +27,16 @@ iree_cc_library( DEPS LLVMSupport MLIRIR - MLIRPass MLIRSupport MLIRTransformUtils MLIRTransforms MLIRTranslateLib iree::compiler::Dialect::Util::IR - iree::compiler::Dialect::Util::Transforms iree::compiler::Dialect::VM::Analysis + iree::compiler::Dialect::VM::Analysis::OrdinalAnalysis iree::compiler::Dialect::VM::Analysis::ValueLiveness iree::compiler::Dialect::VM::Conversion iree::compiler::Dialect::VM::IR - iree::compiler::Dialect::VM::Transforms iree::compiler::Dialect::VM::Utils::CallingConvention iree::compiler::Dialect::VM::Utils::TypeTable iree::compiler::Utils diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir index 52f160c3c124..a110a46dad4c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir @@ -4,11 +4,6 @@ // CHECK: "name": "constants" vm.module @constants { - vm.export @func - vm.func @func() { - vm.return - } - // CHECK: "rodata_segments": [{ // Tests that we densely pack i2 values. Note that the final element (3) is @@ -18,7 +13,7 @@ vm.module @constants { // CHECK-NEXT: 26, // CHECK-NEXT: 3 // CHECK-NEXT: ] - vm.rodata private @dense_i2 dense<[0, 1, 2, 3, 2, 2, 1, 0, 3]> : tensor<9xi2> + vm.rodata public @dense_i2 dense<[0, 1, 2, 3, 2, 2, 1, 0, 3]> : tensor<9xi2> // Tests that we densely pack i3 values and insert the wasted 2-bits of // padding in each byte. 
Smarter implementations would pack to 16- or 64-bit @@ -29,7 +24,7 @@ vm.module @constants { // CHECK-NEXT: 44, // CHECK-NEXT: 62 // CHECK-NEXT: ] - vm.rodata private @dense_i3 dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi3> + vm.rodata public @dense_i3 dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi3> // Tests that we densely pack i4 values and handle partial values (14). // CHECK: "embedded_data": [ @@ -43,7 +38,7 @@ vm.module @constants { // CHECK-NEXT: 254, // CHECK-NEXT: 14 // CHECK-NEXT: ] - vm.rodata private @dense_i4 dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14]> : tensor<17xi4> + vm.rodata public @dense_i4 dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 14]> : tensor<17xi4> // CHECK: "embedded_data": [ // CHECK-NEXT: 98, @@ -51,14 +46,14 @@ vm.module @constants { // CHECK-NEXT: 197, // CHECK-NEXT: 28 // CHECK-NEXT: ] - vm.rodata private @dense_i5 dense<[2, 3, 4, 5, 6, 7]> : tensor<6xi5> + vm.rodata public @dense_i5 dense<[2, 3, 4, 5, 6, 7]> : tensor<6xi5> // CHECK: "embedded_data": [ // CHECK-NEXT: 1, // CHECK-NEXT: 2, // CHECK-NEXT: 3 // CHECK-NEXT: ] - vm.rodata private @dense_i8 dense<[1, 2, 3]> : tensor<3xi8> + vm.rodata public @dense_i8 dense<[1, 2, 3]> : tensor<3xi8> // CHECK: "embedded_data": [ // CHECK-NEXT: 1, @@ -70,7 +65,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @dense_i9 dense<[1, 2, 3, 4, 5]> : tensor<5xi9> + vm.rodata public @dense_i9 dense<[1, 2, 3, 4, 5]> : tensor<5xi9> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -80,7 +75,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 66 // CHECK-NEXT: ] - vm.rodata private @dense_f16 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16> + vm.rodata public @dense_f16 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -90,7 +85,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 60 // CHECK-NEXT: ] - vm.rodata private 
@splat_f16 dense<1.000000e+00> : tensor<3xf16> + vm.rodata public @splat_f16 dense<1.000000e+00> : tensor<3xf16> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -106,7 +101,7 @@ vm.module @constants { // CHECK-NEXT: 64, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_f32 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> + vm.rodata public @dense_f32 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> // CHECK: "embedded_data": [ @@ -128,7 +123,7 @@ vm.module @constants { // CHECK-NEXT: 128, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_resource_complex_f32 dense< + vm.rodata public @dense_resource_complex_f32 dense< "0x0000803F000000400000404000008040" > : tensor<2xcomplex> @@ -146,7 +141,7 @@ vm.module @constants { // CHECK-NEXT: 128, // CHECK-NEXT: 63 // CHECK-NEXT: ] - vm.rodata private @splat_f32 dense<1.000000e+00> : tensor<3xf32> + vm.rodata public @splat_f32 dense<1.000000e+00> : tensor<3xf32> // Tests that elided tensors of sub-byte types get filled with zeros when the // --iree-util-zero-fill-elided-attrs flag is passed. This is useful for @@ -157,7 +152,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @elided_i2 dense_resource<__elided__> : tensor<9xi2> + vm.rodata public @elided_i2 dense_resource<__elided__> : tensor<9xi2> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -173,7 +168,7 @@ vm.module @constants { // CHECK-NEXT: 0, // CHECK-NEXT: 0 // CHECK-NEXT: ] - vm.rodata private @elided_f32 dense_resource<__elided__> : tensor<3xf32> + vm.rodata public @elided_f32 dense_resource<__elided__> : tensor<3xf32> // Tests #util.byte_pattern on sub-byte types. 
// CHECK: "embedded_data": [ @@ -181,7 +176,7 @@ vm.module @constants { // CHECK-NEXT: 1, // CHECK-NEXT: 1 // CHECK-NEXT: ] - vm.rodata private @byte_pattern_i2 #util.byte_pattern<1> : tensor<9xi2> + vm.rodata public @byte_pattern_i2 #util.byte_pattern<1> : tensor<9xi2> // CHECK: "embedded_data": [ // CHECK-NEXT: 0, @@ -209,5 +204,5 @@ vm.module @constants { // CHECK-NEXT: 8, // CHECK-NEXT: 64 // CHECK-NEXT: ] - vm.rodata private @dense_f64 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64> + vm.rodata public @dense_f64 dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64> } diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir index 914ddae9e830..6b3dd3478905 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/dependencies.mlir @@ -32,6 +32,16 @@ vm.module @main_module attributes { version = 100 : i32 } { // CHECK: "flags": "OPTIONAL" vm.import private optional @optional.method1() attributes { minimum_version = 11 : i32 } + // Use the imports so they're not eliminated by DCE. 
+ vm.export @use_imports + vm.func private @use_imports() { + vm.call @required.method0() : () -> () + vm.call @required.method1() : () -> () + vm.call @required.method2() : () -> () + vm.call @optional.method0() : () -> () + vm.call @optional.method1() : () -> () + vm.return + } } // ----- diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp index 2f8e5a2fb3a8..87e2c29dbd7d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp @@ -28,7 +28,6 @@ canonicalizeModule(IREE::VM::ModuleOp moduleOp, RewritePatternSet patterns(moduleOp.getContext()); ConversionTarget target(*moduleOp.getContext()); target.addLegalDialect(); - target.addLegalOp(); // Add all VM canonicalization patterns and mark pseudo-ops illegal. auto *context = moduleOp.getContext(); @@ -86,15 +85,19 @@ canonicalizeModule(IREE::VM::ModuleOp moduleOp, // invalidate the ordinals. modulePasses.addPass(IREE::VM::createOrdinalAllocationPass()); - // C target specific pass - modulePasses.addPass(createConvertVMToEmitCPass()); + // Drop vm.optimization_barrier ops before EmitC conversion. The barriers + // prevent folding during VM-level optimizations above, but EmitC doesn't + // have conversion patterns for vm.optimization_barrier. + modulePasses.addPass(IREE::VM::createDropOptimizationBarriersPass()); - // Drop optimization barriers after EmitC conversion. Must be after conversion - // so barriers prevent folding during VM-level canonicalization, but the - // subsequent canonicalizer only sees EmitC ops (which don't fold VM - // constants). - modulePasses.addPass(IREE::Util::createDropCompilerHintsPass()); + // Clean up dead code created by dropping barriers. The barriers prevented + // constant folding, so after dropping them we need to eliminate unused + // constants to avoid generating unused variables in EmitC. 
modulePasses.addPass(mlir::createCanonicalizerPass()); + modulePasses.addPass(mlir::createCSEPass()); + + // C target specific pass + modulePasses.addPass(createConvertVMToEmitCPass()); if (failed(passManager.run(moduleOp->getParentOfType()))) { return moduleOp.emitError() << "failed during transform passes"; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir index 2e746bcc51a6..6c7bea18179a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Target/C/test/control_flow.mlir @@ -18,12 +18,10 @@ vm.module @control_flow_module { // CHECK-NEXT: int32_t [[V0:[^ ]*]]; // CHECK-NEXT: iree_status_t [[STATUS:[^ ]*]]; // CHECK-NEXT: int32_t [[C:[^ ]*]]; - // CHECK-NEXT: int32_t [[D:[^ ]*]]; // CHECK-NEXT: [[COND_NZ]] = vm_cmp_nz_i32([[COND]]); // CHECK-NEXT: [[COND_BOOL]] = (bool) [[COND_NZ]]; // CHECK-NEXT: if ([[COND_BOOL]]) { // CHECK-NEXT: [[C]] = [[A]]; - // CHECK-NEXT: [[D]] = [[A]]; // CHECK-NEXT: goto [[BB2:[^ ]*]]; // CHECK-NEXT: } else { // CHECK-NEXT: goto [[BB1:[^ ]*]]; @@ -31,10 +29,9 @@ vm.module @control_flow_module { // CHECK-NEXT: [[BB1]]: // CHECK-NEXT: [[B]] = vm_add_i32([[A]], [[A]]); // CHECK-NEXT: [[C]] = [[B]]; - // CHECK-NEXT: [[D]] = [[A]]; // CHECK-NEXT: goto [[BB2:[^ ]*]]; // CHECK-NEXT: [[BB2]]: - // CHECK-NEXT: [[V0]] = vm_add_i32([[C]], [[D]]); + // CHECK-NEXT: [[V0]] = vm_add_i32([[C]], [[A]]); // CHECK-NEXT: EMITC_DEREF_ASSIGN_VALUE([[RESULT]], [[V0]]); // CHECK-NEXT: [[STATUS]] = iree_ok_status(); // CHECK-NEXT: return [[STATUS]]; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel index a666cc18bb5f..926aba932870 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel @@ -25,6 +25,7 @@ 
iree_compiler_cc_library( "ConvertToYieldableCalls.cpp", "DeduplicateRodata.cpp", "DropEmptyModuleInitializers.cpp", + "DropOptimizationBarriers.cpp", "DropUnusedCalls.cpp", "GlobalInitialization.cpp", "HoistInlinedRodata.cpp", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt index 14917efaaa8e..3e107cc2820c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt @@ -22,6 +22,7 @@ iree_cc_library( "ConvertToYieldableCalls.cpp" "DeduplicateRodata.cpp" "DropEmptyModuleInitializers.cpp" + "DropOptimizationBarriers.cpp" "DropUnusedCalls.cpp" "GlobalInitialization.cpp" "HoistInlinedRodata.cpp" diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 9fa70b26b3fb..6f08662d394f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -144,6 +144,8 @@ class ConversionPass // legalization when types need conversion (e.g., index -> i32). conversionTarget.addIllegalOp(); patterns.add(typeConverter, context); + // Convert util.optimization_barrier to vm.optimization_barrier. + conversionTarget.addIllegalOp(); populateUtilToVMPatterns(context, conversionTarget, typeConverter, importTable, patterns); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp new file mode 100644 index 000000000000..6536a0a6663c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropOptimizationBarriers.cpp @@ -0,0 +1,28 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "iree/compiler/Dialect/VM/Transforms/Passes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::VM { + +#define GEN_PASS_DEF_DROPOPTIMIZATIONBARRIERSPASS +#include "iree/compiler/Dialect/VM/Transforms/Passes.h.inc" + +class DropOptimizationBarriersPass + : public IREE::VM::impl::DropOptimizationBarriersPassBase< + DropOptimizationBarriersPass> { + void runOnOperation() override { + getOperation()->walk([&](IREE::VM::OptimizationBarrierOp op) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + }); + } +}; + +} // namespace mlir::iree_compiler::IREE::VM diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp index d0093e1cce13..3fdc6f1f1ebd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp @@ -15,6 +15,7 @@ #include "iree/compiler/Dialect/VM/IR/VMOps.h" #include "iree/compiler/Dialect/VM/Transforms/Passes.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -183,12 +184,22 @@ class MaterializeRefDiscardsPass // Get the operands being passed to succ on this edge. SuccessorOperands succOperands = branchOp.getSuccessorOperands(succIndex); SmallVector operandValues(succOperands.getForwardedOperands()); + unsigned producedCount = succOperands.getProducedOperandCount(); - // Add block arguments to newBlock to receive the operands. + // Add block arguments to newBlock to receive the forwarded operands. 
for (Value operand : operandValues) { newBlock->addArgument(operand.getType(), operand.getLoc()); } + // Add block arguments for produced operands (e.g., vm.call.yieldable + // results). These are created by the terminator at runtime and must be + // forwarded through the new block to the original successor. + for (unsigned i = 0; i < producedCount; ++i) { + // Produced operands come after forwarded operands in succ's arguments. + Type type = succ->getArgument(operandValues.size() + i).getType(); + newBlock->addArgument(type, loc); + } + // Update predecessor's terminator to go to new block instead of succ. // The operands stay the same - they'll now be passed to newBlock. terminator->setSuccessor(newBlock, succIndex); @@ -224,8 +235,10 @@ class MaterializeRefDiscardsPass OpBuilder builder(funcOp.getContext()); - // Collect all refs in the function. - llvm::DenseSet allRefs; + // Collect all refs in the function in deterministic order. + // Walk blocks and operations in order and insert into SetVector, which + // maintains insertion order for deterministic iteration. + llvm::SetVector allRefs; for (Block &block : funcOp.getBlocks()) { for (BlockArgument arg : block.getArguments()) { if (isa(arg.getType())) { @@ -352,6 +365,19 @@ class MaterializeRefDiscardsPass if (op.hasTrait()) { continue; } + + // Skip refs that are MOVE operands of RefMoveInterface + // operations. When an operand is movable and this is its last + // use, the MOVE bit will be set by the register allocator and + // ownership transfers to the operation (e.g., vm.call, + // vm.call.variadic). Inserting a discard would be incorrect as + // the ref is consumed by the operation. + if (auto refMoveOp = dyn_cast(&op)) { + if (refMoveOp.isRefOperandMovable(operand.getOperandNumber())) { + continue; + } + } + // Group by insertion point. 
auto it = opToIndex.find(&op); if (it == opToIndex.end()) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp index 8a3464da60e1..0098e03440ed 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp @@ -169,6 +169,16 @@ void buildVMTransformPassPipeline(OpPassManager &passManager, if (targetOptions.optimizeForStackSize) { passManager.addNestedPass(createSinkDefiningOpsPass()); } + + // Drop vm.optimization_barrier ops now that optimization is complete. + passManager.addNestedPass( + createDropOptimizationBarriersPass()); + + // Insert explicit discard ops for ref values at their last use points. + // Uses edge-based placement: refs dying on control flow edges get discards + // inserted on those edges, refs dying mid-block get discards after last use. + passManager.addNestedPass( + createMaterializeRefDiscardsPass()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td index 06e540f76673..2300a01fa645 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.td @@ -158,6 +158,16 @@ def DropUnusedCallsPass : let summary = "Drops vm.call ops that have no side effects and are unused."; } +def DropOptimizationBarriersPass : + Pass<"iree-vm-drop-optimization-barriers", "IREE::VM::ModuleOp"> { + let summary = "Drops vm.optimization_barrier ops."; + let description = [{ + Removes vm.optimization_barrier ops by replacing them with their operands. + This pass should run after all optimization passes that could fold through + the barriers. 
+ }]; +} + def SinkDefiningOpsPass : Pass<"iree-vm-sink-defining-ops", "IREE::VM::ModuleOp"> { let summary = "Sinks defining ops with few uses to their use-sites."; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir index 06de0d74f865..695fc4dccfbd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/materialize_ref_discards.mlir @@ -2,13 +2,15 @@ // RUN: --pass-pipeline="builtin.module(vm.module(iree-vm-materialize-ref-discards))" \ // RUN: %s | FileCheck %s -// Single ref, single use - discard after use. +// Single ref, single use - NO discard (vm.call has MOVE semantics). // CHECK-LABEL: @single_ref_single_use // CHECK-SAME: (%[[BUF:.*]]: !vm.buffer) vm.module @my_module { vm.func @single_ref_single_use(%buf: !vm.buffer) { + // vm.call supports MOVE, so ref is consumed by call - no discard needed. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume(%buf) : (!vm.buffer) -> () vm.return } @@ -17,16 +19,19 @@ vm.module @my_module { // ----- -// Multiple uses - discard after LAST use only. +// Multiple uses - NO discard (both calls have MOVE semantics, only last matters). // CHECK-LABEL: @multiple_uses // CHECK-SAME: (%[[BUF:.*]]: !vm.buffer) vm.module @my_module { vm.func @multiple_uses(%buf: !vm.buffer) { + // First call: not last use, no discard. // CHECK: vm.call @consume(%[[BUF]]) // CHECK-NOT: vm.discard.refs vm.call @consume(%buf) : (!vm.buffer) -> () + // Second call: last use with MOVE semantics - ref consumed by call. 
// CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume(%buf) : (!vm.buffer) -> () vm.return } @@ -58,8 +63,10 @@ vm.module @my_module { vm.cond_br %cond, ^bb1(%buf : !vm.buffer), ^bb2 // CHECK: ^[[BB1]](%[[ARG:.*]]: !vm.buffer): ^bb1(%arg: !vm.buffer): + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[ARG]]) - // CHECK-NEXT: vm.discard.refs %[[ARG]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%arg) : (!vm.buffer) -> () vm.br ^exit ^bb2: @@ -72,13 +79,15 @@ vm.module @my_module { // ----- -// Multiple refs dying at same point - batched into single discard. +// Multiple refs passed to call - NO discards (MOVE semantics). // CHECK-LABEL: @multiple_refs_same_death_point // CHECK-SAME: (%[[A:.*]]: !vm.buffer, %[[B:.*]]: !vm.buffer) vm.module @my_module { vm.func @multiple_refs_same_death_point(%a: !vm.buffer, %b: !vm.buffer) { + // Both refs consumed by call with MOVE semantics. // CHECK: vm.call @consume2(%[[A]], %[[B]]) - // CHECK-NEXT: vm.discard.refs %[[A]], %[[B]] : !vm.buffer, !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return vm.call @consume2(%a, %b) : (!vm.buffer, !vm.buffer) -> () vm.return } @@ -130,8 +139,10 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: @@ -157,14 +168,18 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. 
// CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: ^else: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[EXIT]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^exit ^exit: @@ -189,8 +204,10 @@ vm.module @my_module { vm.cond_br %cond, ^then, ^else // CHECK: ^[[THEN]]: ^then: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[USED]]) - // CHECK-NEXT: vm.discard.refs %[[USED]] : !vm.buffer + // CHECK-NOT: vm.discard.refs %[[USED]] + // CHECK-NEXT: vm.br ^[[EXIT:.*]] vm.call @consume(%used) : (!vm.buffer) -> () vm.br ^exit // CHECK: ^[[ELSE]]: @@ -235,14 +252,18 @@ vm.module @my_module { vm.cond_br %cond, ^left, ^right // CHECK: ^[[LEFT]]: ^left: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[MERGE:.*]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^merge // CHECK: ^[[RIGHT]]: ^right: + // vm.call has MOVE semantics - ref consumed by call. // CHECK: vm.call @consume(%[[BUF]]) - // CHECK-NEXT: vm.discard.refs %[[BUF]] : !vm.buffer + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.br ^[[MERGE]] vm.call @consume(%buf) : (!vm.buffer) -> () vm.br ^merge ^merge: @@ -314,8 +335,8 @@ vm.module @my_module { %ref = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref // CHECK: %[[CAST:.*]] = vm.cast.any.ref %[[REF]] %cast = vm.cast.any.ref %ref : !vm.ref -> !vm.buffer - // %ref is discarded after its last use (vm.cast.any.ref) - // CHECK: vm.discard.refs %[[REF]] + // vm.cast has MOVE semantics - %ref consumed by cast, no discard. 
+ // CHECK-NOT: vm.discard.refs %[[REF]] // CHECK: vm.cmp.eq.ref %[[BUFFER]], %[[CAST]] %eq = vm.cmp.eq.ref %buffer, %cast : !vm.buffer // Both %buffer and %cast die at same point - batched discard @@ -327,7 +348,8 @@ vm.module @my_module { // ----- // Each cast produces a new ref with independent lifetime. -// Refs are discarded after their last use. +// vm.cast has MOVE semantics - refs consumed by casts, not discarded. +// Only refs passed to vm.call operations (which also have MOVE) are consumed. // CHECK-LABEL: @chained_casts_independent vm.module @my_module { vm.func @chained_casts_independent() { @@ -339,18 +361,21 @@ vm.module @my_module { %ref1 = vm.cast.ref.any %buf : !vm.buffer -> !vm.ref // CHECK: %[[BUF2:.*]] = vm.cast.any.ref %[[REF1]] %buf2 = vm.cast.any.ref %ref1 : !vm.ref -> !vm.buffer - // ref1's last use is vm.cast.any.ref, discard it now - // CHECK: vm.discard.refs %[[REF1]] + // vm.cast has MOVE semantics - ref1 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF1]] // CHECK: %[[REF2:.*]] = vm.cast.ref.any %[[BUF2]] %ref2 = vm.cast.ref.any %buf2 : !vm.buffer -> !vm.ref - // buf2's last use is vm.cast.ref.any, discard it now - // CHECK: vm.discard.refs %[[BUF2]] + // vm.cast has MOVE semantics - buf2 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF2]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] // CHECK: vm.call @use_ref(%[[REF2]]) vm.call @use_ref(%ref2) : (!vm.ref) -> () - // CHECK: vm.discard.refs %[[REF2]] + // vm.call has MOVE semantics - ref2 consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF2]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -359,7 +384,7 @@ vm.module @my_module { // ----- -// Each ref discarded after its last use, no aliasing. 
+// Each ref consumed by vm.call (MOVE semantics). // CHECK-LABEL: @ref_used_then_original_used vm.module @my_module { vm.func @ref_used_then_original_used() { @@ -371,11 +396,13 @@ vm.module @my_module { %ref = vm.cast.ref.any %buf : !vm.buffer -> !vm.ref // CHECK: vm.call @use_ref(%[[REF]]) vm.call @use_ref(%ref) : (!vm.ref) -> () - // ref's last use is use_ref, discard it - // CHECK: vm.discard.refs %[[REF]] + // vm.call has MOVE semantics - ref consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -384,7 +411,7 @@ vm.module @my_module { // ----- -// Ref used in branch, original used after merge - independent lifetimes. +// Ref used in branch (vm.call has MOVE), original used after merge. // CHECK-LABEL: @ref_in_branch_original_after_merge vm.module @my_module { vm.func @ref_in_branch_original_after_merge(%cond: i32) { @@ -399,8 +426,9 @@ vm.module @my_module { ^left: // CHECK: vm.call @use_ref(%[[REF]]) vm.call @use_ref(%ref) : (!vm.ref) -> () - // ref's last use on this path - // CHECK: vm.discard.refs %[[REF]] + // vm.call has MOVE semantics - ref consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[REF]] + // CHECK-NEXT: vm.br ^[[MERGE:.*]] vm.br ^merge ^right: // ref not used on this path - edge discard @@ -409,7 +437,9 @@ vm.module @my_module { ^merge: // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. 
+ // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -418,7 +448,7 @@ vm.module @my_module { // ----- -// Ref used in loop, original used after loop exit - independent lifetimes. +// Ref used in loop (vm.call has MOVE), original used after loop exit. // CHECK-LABEL: @ref_in_loop_original_after vm.module @my_module { vm.func @ref_in_loop_original_after(%n: i32) { @@ -438,11 +468,14 @@ vm.module @my_module { %cmp = vm.cmp.lt.i32.s %next, %n : i32 vm.cond_br %cmp, ^loop(%next : i32), ^exit ^exit: - // ref is live throughout loop, dies at exit + // ref is NOT live at exit - it's consumed by vm.call in the loop. + // The last iteration's vm.call has MOVE semantics - ref consumed. // CHECK: vm.discard.refs %[[REF]] // CHECK: vm.call @use_buffer(%[[BUF]]) vm.call @use_buffer(%buf) : (!vm.buffer) -> () - // CHECK: vm.discard.refs %[[BUF]] + // vm.call has MOVE semantics - buf consumed, no discard. + // CHECK-NOT: vm.discard.refs %[[BUF]] + // CHECK-NEXT: vm.return vm.return } vm.import private @use_buffer(%buf: !vm.buffer) @@ -975,3 +1008,175 @@ vm.module @my_module { vm.return } } + +// ----- + +//===----------------------------------------------------------------------===// +// MOVE semantics for regular vm.call and vm.call.variadic +// These are the key tests for the bug fix: non-terminator calls that support +// MOVE semantics must NOT have discards inserted for their ref operands. +//===----------------------------------------------------------------------===// + +// vm.call with ref operand at last use - MOVE semantics, no discard. +// This was the original bug: mid-block discard logic would insert a discard +// after the call, but the call already consumed the ref with MOVE. 
+// CHECK-LABEL: @call_ref_move_last_use +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_move_last_use(%buf: !vm.buffer) { + // Ref passed to call with MOVE semantics - no discard should be inserted. + // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// vm.call.variadic with ref operands at last use - MOVE semantics, no discard. +// This is the specific case from the smoketest.mlir bug report. +// CHECK-LABEL: @call_variadic_ref_move_last_use +vm.module @my_module { + vm.import private @hal.command_buffer.dispatch(!vm.buffer, !vm.ref, i32, i32, i32, i32) + vm.func @call_variadic_ref_move_last_use(%cmd: !vm.buffer, %exec: !vm.ref) { + %c0 = vm.const.i32 0 + %c1 = vm.const.i32 1 + // Ref operands passed with MOVE semantics - no discards should be inserted. + // CHECK: vm.call.variadic @hal.command_buffer.dispatch(%[[CMD:.*]], %[[EXEC:.*]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call.variadic @hal.command_buffer.dispatch(%cmd, %exec, %c0, %c1, %c1, %c1) : (!vm.buffer, !vm.ref, i32, i32, i32, i32) -> () + vm.return + } +} + +// ----- + +// Multiple refs passed to vm.call - all with MOVE semantics. +// CHECK-LABEL: @call_multiple_ref_operands_move +vm.module @my_module { + vm.import private @multi(!vm.buffer, !vm.buffer, !vm.ref) + vm.func @call_multiple_ref_operands_move(%buf1: !vm.buffer, %buf2: !vm.buffer, %ref: !vm.ref) { + // Use refs before call to ensure they're live. + // CHECK-DAG: vm.cmp.nz.ref %[[BUF1:[^ ]+]] + %nz1 = vm.cmp.nz.ref %buf1 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[BUF2:[^ ]+]] + %nz2 = vm.cmp.nz.ref %buf2 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[REF:[^ ]+]] + %nz3 = vm.cmp.nz.ref %ref : !vm.ref + // All refs passed to call with MOVE - no discards. 
+ // CHECK-NOT: vm.discard.refs + // CHECK: vm.call @multi(%[[BUF1]], %[[BUF2]], %[[REF]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @multi(%buf1, %buf2, %ref) : (!vm.buffer, !vm.buffer, !vm.ref) -> () + vm.return + } +} + +// ----- + +// Ref used, then NOT passed to call - still needs discard. +// This verifies the fix is precise: only refs actually passed to MOVE calls +// are exempted from mid-block discards. +// CHECK-LABEL: @call_ref_not_passed +vm.module @my_module { + vm.import private @compute(i32) + vm.func @call_ref_not_passed(%buf: !vm.buffer, %x: i32) { + // CHECK: vm.cmp.nz.ref %[[BUF:[^ ]+]] + %nz = vm.cmp.nz.ref %buf : !vm.buffer + // Ref NOT passed to call, so it needs a discard after its last use. + // CHECK-NEXT: vm.discard.refs %[[BUF]] + // CHECK: vm.call @compute + vm.call @compute(%x) : (i32) -> () + vm.return + } +} + +// ----- + +// Mixed scenario: one ref passed to call (MOVE), another not passed (discard). +// CHECK-LABEL: @call_mixed_ref_operands +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_mixed_ref_operands(%buf1: !vm.buffer, %buf2: !vm.buffer) { + // CHECK-DAG: vm.cmp.nz.ref %[[BUF1:[^ ]+]] + %nz1 = vm.cmp.nz.ref %buf1 : !vm.buffer + // CHECK-DAG: vm.cmp.nz.ref %[[BUF2:[^ ]+]] + %nz2 = vm.cmp.nz.ref %buf2 : !vm.buffer + // buf2 NOT passed to call, discarded after its last use. + // CHECK: vm.discard.refs %[[BUF2]] + // buf1 passed to call with MOVE, not discarded. + // CHECK: vm.call @consume(%[[BUF1]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf1) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// Ref passed to multiple calls - only last call gets MOVE, earlier uses need ref retained. +// CHECK-LABEL: @call_ref_multiple_calls +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_multiple_calls(%buf: !vm.buffer) { + // First call: not last use, no discard. 
+ // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + vm.call @consume(%buf) : (!vm.buffer) -> () + // Second call: last use, MOVE semantics, no discard. + // CHECK: vm.call @consume(%[[BUF]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} + +// ----- + +// vm.call.variadic with mixed ref and non-ref operands. +// CHECK-LABEL: @call_variadic_mixed_operands +vm.module @my_module { + vm.import private @mixed(!vm.buffer, i32, i32, !vm.ref, i32) + vm.func @call_variadic_mixed_operands(%buf: !vm.buffer, %ref: !vm.ref) { + %c1 = vm.const.i32 1 + %c2 = vm.const.i32 2 + %c3 = vm.const.i32 3 + // Refs passed with MOVE, integers are just values. + // CHECK: vm.call.variadic @mixed(%[[BUF:.*]], %{{.*}}, %{{.*}}, %[[REF:.*]], %{{.*}}) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call.variadic @mixed(%buf, %c1, %c2, %ref, %c3) : (!vm.buffer, i32, i32, !vm.ref, i32) -> () + vm.return + } +} + +// ----- + +// Control flow with vm.call: ref used in one branch, passed to call in another. +// CHECK-LABEL: @call_ref_control_flow +vm.module @my_module { + vm.import private @consume(!vm.buffer) + vm.func @call_ref_control_flow(%buf: !vm.buffer, %cond: i32) { + // CHECK: vm.cond_br %{{.*}}, ^[[USE:.*]], ^[[CALL:.*]] + vm.cond_br %cond, ^use, ^call + ^use: + // Ref used here, then discarded. + // CHECK: vm.cmp.nz.ref %[[BUF:.*]] + // CHECK-NEXT: vm.discard.refs %[[BUF]] + %nz = vm.cmp.nz.ref %buf : !vm.buffer + vm.return + ^call: + // Ref passed to call with MOVE here, not discarded. 
+ // CHECK: vm.call @consume(%[[BUF:.*]]) + // CHECK-NOT: vm.discard.refs + // CHECK-NEXT: vm.return + vm.call @consume(%buf) : (!vm.buffer) -> () + vm.return + } +} diff --git a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc index 378fd373d26c..0f4586f53528 100644 --- a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc +++ b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc @@ -270,24 +270,30 @@ int mlir::iree_compiler::runIreecMain(int argc, char **argv) { // Switch on compileMode to choose a pipeline to run. switch (compileMode) { case CompileMode::std: - if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_STD)) + if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_STD)) { return false; + } break; case CompileMode::vm: + if (!ireeCompilerInvocationPipeline(r.inv, IREE_COMPILER_PIPELINE_VM)) { + return false; + } break; case CompileMode::hal_executable: { // Compiling a HAL executable, it is only valid to output in that form. 
outputFormat = OutputFormat::hal_executable; if (!ireeCompilerInvocationPipeline( - r.inv, IREE_COMPILER_PIPELINE_HAL_EXECUTABLE)) + r.inv, IREE_COMPILER_PIPELINE_HAL_EXECUTABLE)) { return false; + } break; } case CompileMode::precompile: { outputFormat = OutputFormat::precompile; if (!ireeCompilerInvocationPipeline(r.inv, - IREE_COMPILER_PIPELINE_PRECOMPILE)) + IREE_COMPILER_PIPELINE_PRECOMPILE)) { return false; + } break; } default: diff --git a/runtime/src/iree/vm/test/arithmetic_ops.mlir b/runtime/src/iree/vm/test/arithmetic_ops.mlir index 4ec12e1e6266..294180e934f2 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops { vm.export @test_add_i32 vm.func @test_add_i32() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.add.i32 %c1dno, %c1dno : i32 %c2 = vm.const.i32 2 vm.check.eq %v, %c2, "1+1=2" : i32 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops { vm.export @test_sub_i32 vm.func @test_sub_i32() { %c1 = vm.const.i32 3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.sub.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "3-2=1" : i32 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops { vm.export @test_mul_i32 vm.func @test_mul_i32() { %c1 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.mul.i32 %c1dno, %c1dno : i32 %c2 = vm.const.i32 4 vm.check.eq %v, %c2, "2*2=4" : i32 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops { vm.export @test_div_i32s vm.func @test_div_i32s() { %c1 = vm.const.i32 4 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 -2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = 
vm.optimization_barrier %c2 : i32 %v = vm.div.i32.s %c1dno, %c2dno : i32 %c3 = vm.const.i32 -2 vm.check.eq %v, %c3, "4/-2=-2" : i32 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops { vm.export @test_div_i32u vm.func @test_div_i32u() { %c1 = vm.const.i32 4 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.div.i32.u %c1dno, %c2dno : i32 %c3 = vm.const.i32 2 vm.check.eq %v, %c3, "4/2=2" : i32 @@ -63,9 +63,9 @@ vm.module @arithmetic_ops { vm.export @test_rem_i32s vm.func @test_rem_i32s() { %c1 = vm.const.i32 -3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 -2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.rem.i32.s %c1dno, %c2dno : i32 %c3 = vm.const.i32 -1 vm.check.eq %v, %c3, "-3%-2=-1" : i32 @@ -75,9 +75,9 @@ vm.module @arithmetic_ops { vm.export @test_rem_i32u vm.func @test_rem_i32u() { %c1 = vm.const.i32 3 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.rem.i32.u %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "3%2=1" : i32 @@ -87,11 +87,11 @@ vm.module @arithmetic_ops { vm.export @test_fma_i32 vm.func @test_fma_i32() { %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %c3 = vm.const.i32 3 - %c3dno = util.optimization_barrier %c3 : i32 + %c3dno = vm.optimization_barrier %c3 : i32 %c5 = vm.const.i32 5 - %c5dno = util.optimization_barrier %c5 : i32 + %c5dno = vm.optimization_barrier %c5 : i32 %v = vm.fma.i32 %c2dno, %c3dno, %c5dno : i32 %c11 = vm.const.i32 11 vm.check.eq %v, %c11, "2*3+5=11" : i32 @@ -101,7 +101,7 @@ vm.module @arithmetic_ops { vm.export 
@test_abs_i32 vm.func @test_abs_i32() { %cn1 = vm.const.i32 -1 - %cn1dno = util.optimization_barrier %cn1 : i32 + %cn1dno = vm.optimization_barrier %cn1 : i32 %v = vm.abs.i32 %cn1dno : i32 %c1 = vm.const.i32 1 vm.check.eq %v, %c1, "abs(-1)=1" : i32 @@ -111,9 +111,9 @@ vm.module @arithmetic_ops { vm.export @test_min_i32s vm.func @test_min_i32s() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.min.i32.s %cn3dno, %c2dno : i32 vm.check.eq %v, %cn3, "smin(-3,2)=-3" : i32 vm.return @@ -122,9 +122,9 @@ vm.module @arithmetic_ops { vm.export @test_min_i32u vm.func @test_min_i32u() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.min.i32.u %cn3dno, %c2dno : i32 vm.check.eq %v, %c2, "umin(-3,2)=2" : i32 vm.return @@ -133,9 +133,9 @@ vm.module @arithmetic_ops { vm.export @test_max_i32s vm.func @test_max_i32s() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.max.i32.s %cn3dno, %c2dno : i32 vm.check.eq %v, %c2, "smax(-3,2)=2" : i32 vm.return @@ -144,9 +144,9 @@ vm.module @arithmetic_ops { vm.export @test_max_i32u vm.func @test_max_i32u() { %cn3 = vm.const.i32 -3 - %cn3dno = util.optimization_barrier %cn3 : i32 + %cn3dno = vm.optimization_barrier %cn3 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.max.i32.u %cn3dno, %c2dno : i32 vm.check.eq %v, %cn3, "umax(-3,2)=-3" : i32 vm.return @@ -155,7 +155,7 @@ vm.module @arithmetic_ops { 
vm.export @test_not_i32 vm.func @test_not_i32() { %c1 = vm.const.i32 0 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.not.i32 %c1dno : i32 %c2 = vm.const.i32 -1 vm.check.eq %v, %c2, "~0=-1" : i32 @@ -165,9 +165,9 @@ vm.module @arithmetic_ops { vm.export @test_and_i32 vm.func @test_and_i32() { %c1 = vm.const.i32 5 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.and.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 1 vm.check.eq %v, %c3, "5&3=1" : i32 @@ -177,9 +177,9 @@ vm.module @arithmetic_ops { vm.export @test_or_i32 vm.func @test_or_i32() { %c1 = vm.const.i32 5 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.or.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 7 vm.check.eq %v, %c3, "5|3=7" : i32 @@ -189,9 +189,9 @@ vm.module @arithmetic_ops { vm.export @test_xor_i32 vm.func @test_xor_i32() { %c1 = vm.const.i32 5 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 3 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 %v = vm.xor.i32 %c1dno, %c2dno : i32 %c3 = vm.const.i32 6 vm.check.eq %v, %c3, "5^3=6" : i32 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_zero vm.func @test_ctlz_i32_const_zero() { %c = vm.const.i32 0 - %cdno = util.optimization_barrier %c : i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 32 vm.check.eq %actual, %expected, "ctlz(0)=32" : i32 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_1 vm.func @test_ctlz_i32_const_1() { %c = vm.const.i32 1 - %cdno = 
util.optimization_barrier %c : i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 31 vm.check.eq %actual, %expected, "ctlz(1)=31" : i32 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops { vm.export @test_ctlz_i32_const_ffffffff vm.func @test_ctlz_i32_const_ffffffff() { %c = vm.const.i32 0xFFFFFFFF - %cdno = util.optimization_barrier %c : i32 + %cdno = vm.optimization_barrier %c : i32 %actual = vm.ctlz.i32 %cdno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "ctlz(0xFFFFFFFF)=0" : i32 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir b/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir index 2d3fd2ecaf4e..17fefb772796 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_f32.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_add_f32 vm.func @test_add_f32() { %c1 = vm.const.f32 1.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.add.f32 %c1dno, %c1dno : f32 %c2 = vm.const.f32 3.0 vm.check.eq %v, %c2, "1.5+1.5=3" : f32 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sub_f32 vm.func @test_sub_f32() { %c1 = vm.const.f32 3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 2.5 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.sub.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 0.5 vm.check.eq %v, %c3, "3.0-2.5=0.5" : f32 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_mul_f32 vm.func @test_mul_f32() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.mul.f32 %c1dno, %c1dno : f32 %c2 = vm.const.f32 6.25 vm.check.eq %v, %c2, "2.5*2.5=6.25" : f32 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_div_f32 vm.func @test_div_f32() { %c1 = 
vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 -2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.div.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 -2.0 vm.check.eq %v, %c3, "4.0/-2.0=-2.0" : f32 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_rem_f32 vm.func @test_rem_f32() { %c1 = vm.const.f32 -3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 -2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.rem.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 1.0 vm.check.eq %v, %c3, "-3.0%-2.0=1.0" : f32 @@ -63,11 +63,11 @@ vm.module @arithmetic_ops_f32 { vm.export @test_fma_f32 vm.func @test_fma_f32() { %c2 = vm.const.f32 2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %c3 = vm.const.f32 3.0 - %c3dno = util.optimization_barrier %c3 : f32 + %c3dno = vm.optimization_barrier %c3 : f32 %c5 = vm.const.f32 5.0 - %c5dno = util.optimization_barrier %c5 : f32 + %c5dno = vm.optimization_barrier %c5 : f32 %v = vm.fma.f32 %c2dno, %c3dno, %c5dno : f32 %c11 = vm.const.f32 11.0 vm.check.eq %v, %c11, "2.0*3.0+5.0=11.0" : f32 @@ -77,7 +77,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_abs_f32 vm.func @test_abs_f32() { %c1 = vm.const.f32 -1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.abs.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "abs(-1.0)=1.0" : f32 @@ -87,7 +87,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_neg_f32 vm.func @test_neg_f32() { %c1 = vm.const.f32 -1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.neg.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "neg(-1.0)=1.0" : f32 @@ -97,7 +97,7 @@ vm.module @arithmetic_ops_f32 { 
vm.export @test_ceil_f32 vm.func @test_ceil_f32() { %c1 = vm.const.f32 1.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.ceil.f32 %c1dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "ceil(1.5)=2.0" : f32 @@ -107,7 +107,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_floor_f32 vm.func @test_floor_f32() { %c15 = vm.const.f32 1.5 - %c15dno = util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.floor.f32 %c15dno : f32 %c1 = vm.const.f32 1.0 vm.check.eq %v, %c1, "floor(1.5)=1.0" : f32 @@ -117,7 +117,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_round_f32 vm.func @test_round_f32() { %c15 = vm.const.f32 1.5 - %c15dno = util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.round.f32 %c15dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "round(1.5)=2.0" : f32 @@ -127,7 +127,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_round_f32_even vm.func @test_round_f32_even() { %c15 = vm.const.f32 1.5 - %c15dno = util.optimization_barrier %c15 : f32 + %c15dno = vm.optimization_barrier %c15 : f32 %v = vm.round.f32.even %c15dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "roundeven(1.5)=2.0" : f32 @@ -137,9 +137,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_min_f32 vm.func @test_min_f32() { %cn3 = vm.const.f32 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f32 + %cn3dno = vm.optimization_barrier %cn3 : f32 %cn2 = vm.const.f32 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f32 + %cn2dno = vm.optimization_barrier %cn2 : f32 %v = vm.min.f32 %cn3dno, %cn2dno : f32 vm.check.eq %v, %cn3, "min(-3.0,-2.0)=-3.0" : f32 vm.return @@ -148,9 +148,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_max_f32 vm.func @test_max_f32() { %cn3 = vm.const.f32 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f32 + %cn3dno = vm.optimization_barrier %cn3 : f32 %cn2 = vm.const.f32 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f32 + 
%cn2dno = vm.optimization_barrier %cn2 : f32 %v = vm.max.f32 %cn3dno, %cn2dno : f32 vm.check.eq %v, %cn2, "max(-3.0,-2.0)=-2.0" : f32 vm.return @@ -159,7 +159,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_atan_f32 vm.func @test_atan_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.atan.f32 %c1dno : f32 %c2 = vm.const.f32 0.7853981633974483 vm.check.eq %v, %c2, "atan(1.0)=0.7853981633974483" : f32 @@ -169,9 +169,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_atan2_f32 vm.func @test_atan2_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 0.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.atan2.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 1.5707963267948966 vm.check.eq %v, %c3, "atan2(1.0,0.0)=1.5707963267948966" : f32 @@ -181,7 +181,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_cos_f32 vm.func @test_cos_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cos.f32 %c1dno : f32 %c2 = vm.const.f32 0.8775825618903728 vm.check.eq %v, %c2, "cos(0.5)=0.8775825618903728" : f32 @@ -191,7 +191,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sin_f32 vm.func @test_sin_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.sin.f32 %c1dno : f32 %c2 = vm.const.f32 0.479425538604203 vm.check.eq %v, %c2, "sin(0.5)=0.479425538604203" : f32 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_exp_f32 vm.func @test_exp_f32() { %c1 = vm.const.f32 1.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.exp.f32 %c1dno : f32 %c2 = vm.const.f32 2.718281828459045 vm.check.eq %v, %c2, "exp(1.0)=2.718281828459045" : f32 @@ -211,7 
+211,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_exp2_f32 vm.func @test_exp2_f32() { %c1 = vm.const.f32 2.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.exp2.f32 %c1dno : f32 %c2 = vm.const.f32 4.0 vm.check.eq %v, %c2, "exp(2.0)=4.0" : f32 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_expm1_f32 vm.func @test_expm1_f32() { %c1 = vm.const.f32 2.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.expm1.f32 %c1dno : f32 %c2 = vm.const.f32 6.38905609893065 vm.check.eq %v, %c2, "expm1(2.0)=6.38905609893065" : f32 @@ -231,7 +231,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log_f32 vm.func @test_log_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log.f32 %c1dno : f32 %c2 = vm.const.f32 2.302585092994046 vm.check.eq %v, %c2, "log(10.0)=2.302585092994046" : f32 @@ -241,7 +241,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log10_f32 vm.func @test_log10_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log10.f32 %c1dno : f32 %c2 = vm.const.f32 1.0 vm.check.eq %v, %c2, "log10(10.0)=1.0" : f32 @@ -251,7 +251,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log1p_f32 vm.func @test_log1p_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log1p.f32 %c1dno : f32 %c2 = vm.const.f32 2.3978952727983707 vm.check.eq %v, %c2, "log1p(10.0)=2.3978952727983707" : f32 @@ -261,7 +261,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_log2_f32 vm.func @test_log2_f32() { %c1 = vm.const.f32 10.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.log2.f32 %c1dno : f32 %c2 = vm.const.f32 3.321928094887362 vm.check.eq %v, %c2, 
"log2(10.0)=3.321928094887362" : f32 @@ -271,9 +271,9 @@ vm.module @arithmetic_ops_f32 { vm.export @test_pow_f32 vm.func @test_pow_f32() { %c1 = vm.const.f32 3.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %c2 = vm.const.f32 2.0 - %c2dno = util.optimization_barrier %c2 : f32 + %c2dno = vm.optimization_barrier %c2 : f32 %v = vm.pow.f32 %c1dno, %c2dno : f32 %c3 = vm.const.f32 9.0 vm.check.eq %v, %c3, "pow(3.0,2.0)=9.0" : f32 @@ -283,7 +283,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_rsqrt_f32 vm.func @test_rsqrt_f32() { %c1 = vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.rsqrt.f32 %c1dno : f32 %c2 = vm.const.f32 0.5 vm.check.eq %v, %c2, "rsqrt(4.0)=0.5" : f32 @@ -293,7 +293,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_sqrt_f32 vm.func @test_sqrt_f32() { %c1 = vm.const.f32 4.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.sqrt.f32 %c1dno : f32 %c2 = vm.const.f32 2.0 vm.check.eq %v, %c2, "sqrt(4.0)=2.0" : f32 @@ -303,7 +303,7 @@ vm.module @arithmetic_ops_f32 { vm.export @test_tanh_f32 vm.func @test_tanh_f32() { %c1 = vm.const.f32 0.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.tanh.f32 %c1dno : f32 %c2 = vm.const.f32 0.46211715726000974 vm.check.eq %v, %c2, "tanh(0.5)=0.46211715726000974" : f32 @@ -314,7 +314,7 @@ vm.module @arithmetic_ops_f32 { // vm.export @test_erf_f32 // vm.func @test_erf_f32() { // %c1 = vm.const.f32 0.5 - // %c1dno = util.optimization_barrier %c1 : f32 + // %c1dno = vm.optimization_barrier %c1 : f32 // %v = vm.erf.f32 %c1dno : f32 // %c2 = vm.const.f32 0.520499945 // vm.check.eq %v, %c2, "erf(0.5)=0.520499945" : f32 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir b/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir index 78c4df9b7086..91384cba6c4f 100644 --- 
a/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_f64.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_add_f64 vm.func @test_add_f64() { %c1 = vm.const.f64 1.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.add.f64 %c1dno, %c1dno : f64 %c2 = vm.const.f64 3.0 vm.check.eq %v, %c2, "1.5+1.5=3" : f64 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_sub_f64 vm.func @test_sub_f64() { %c1 = vm.const.f64 3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 2.5 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.sub.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 0.5 vm.check.eq %v, %c3, "3.0-2.5=0.5" : f64 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_mul_f64 vm.func @test_mul_f64() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.mul.f64 %c1dno, %c1dno : f64 %c2 = vm.const.f64 6.25 vm.check.eq %v, %c2, "2.5*2.5=6.25" : f64 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_div_f64 vm.func @test_div_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 -2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.div.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 -2.0 vm.check.eq %v, %c3, "4.0/-2.0=-2.0" : f64 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_rem_f64 vm.func @test_rem_f64() { %c1 = vm.const.f64 -3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 -2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.rem.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 1.0 
vm.check.eq %v, %c3, "-3.0%-2.0=1.0" : f64 @@ -63,11 +63,11 @@ vm.module @arithmetic_ops_f64 { vm.export @test_fma_f64 vm.func @test_fma_f64() { %c2 = vm.const.f64 2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %c3 = vm.const.f64 3.0 - %c3dno = util.optimization_barrier %c3 : f64 + %c3dno = vm.optimization_barrier %c3 : f64 %c5 = vm.const.f64 5.0 - %c5dno = util.optimization_barrier %c5 : f64 + %c5dno = vm.optimization_barrier %c5 : f64 %v = vm.fma.f64 %c2dno, %c3dno, %c5dno : f64 %c11 = vm.const.f64 11.0 vm.check.eq %v, %c11, "2.0*3.0+5.0=11.0" : f64 @@ -77,7 +77,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_abs_f64 vm.func @test_abs_f64() { %c1 = vm.const.f64 -1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.abs.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "abs(-1.0)=1.0" : f64 @@ -87,7 +87,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_neg_f64 vm.func @test_neg_f64() { %c1 = vm.const.f64 -1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.neg.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "neg(-1.0)=1.0" : f64 @@ -97,7 +97,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_ceil_f64 vm.func @test_ceil_f64() { %c1 = vm.const.f64 1.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.ceil.f64 %c1dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "ceil(1.5)=2.0" : f64 @@ -107,7 +107,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_floor_f64 vm.func @test_floor_f64() { %c15 = vm.const.f64 1.5 - %c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = vm.floor.f64 %c15dno : f64 %c1 = vm.const.f64 1.0 vm.check.eq %v, %c1, "floor(1.5)=1.0" : f64 @@ -117,7 +117,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_round_f64 vm.func @test_round_f64() { %c15 = vm.const.f64 1.5 - 
%c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = vm.round.f64 %c15dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "round(1.5)=2.0" : f64 @@ -127,7 +127,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_round_f64_even vm.func @test_round_f64_even() { %c15 = vm.const.f64 1.5 - %c15dno = util.optimization_barrier %c15 : f64 + %c15dno = vm.optimization_barrier %c15 : f64 %v = vm.round.f64.even %c15dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "roundeven(1.5)=2.0" : f64 @@ -137,9 +137,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_min_f64 vm.func @test_min_f64() { %cn3 = vm.const.f64 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f64 + %cn3dno = vm.optimization_barrier %cn3 : f64 %cn2 = vm.const.f64 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f64 + %cn2dno = vm.optimization_barrier %cn2 : f64 %v = vm.min.f64 %cn3dno, %cn2dno : f64 vm.check.eq %v, %cn3, "min(-3.0,-2.0)=-3.0" : f64 vm.return @@ -148,9 +148,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_max_f64 vm.func @test_max_f64() { %cn3 = vm.const.f64 -3.0 - %cn3dno = util.optimization_barrier %cn3 : f64 + %cn3dno = vm.optimization_barrier %cn3 : f64 %cn2 = vm.const.f64 -2.0 - %cn2dno = util.optimization_barrier %cn2 : f64 + %cn2dno = vm.optimization_barrier %cn2 : f64 %v = vm.max.f64 %cn3dno, %cn2dno : f64 vm.check.eq %v, %cn2, "max(-3.0,-2.0)=-2.0" : f64 vm.return @@ -159,7 +159,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_atan_f64 vm.func @test_atan_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.atan.f64 %c1dno : f64 %c2 = vm.const.f64 0.7853981633974483 vm.check.eq %v, %c2, "atan(1.0)=0.7853981633974483" : f64 @@ -169,9 +169,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_atan2_f64 vm.func @test_atan2_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = 
vm.const.f64 0.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.atan2.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 1.5707963267948966 vm.check.eq %v, %c3, "atan2(1.0,0.0)=1.5707963267948966" : f64 @@ -181,7 +181,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_cos_f64 vm.func @test_cos_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cos.f64 %c1dno : f64 %c2 = vm.const.f64 0.8775825618903728 vm.check.eq %v, %c2, "cos(0.5)=0.8775825618903728" : f64 @@ -191,7 +191,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_sin_f64 vm.func @test_sin_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.sin.f64 %c1dno : f64 %c2 = vm.const.f64 0.479425538604203 vm.check.eq %v, %c2, "sin(0.5)=0.479425538604203" : f64 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_exp_f64 vm.func @test_exp_f64() { %c1 = vm.const.f64 1.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.exp.f64 %c1dno : f64 %c2 = vm.const.f64 2.718281828459045 vm.check.eq %v, %c2, "exp(1.0)=2.718281828459045" : f64 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_exp2_f64 vm.func @test_exp2_f64() { %c1 = vm.const.f64 2.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.exp2.f64 %c1dno : f64 %c2 = vm.const.f64 4.0 vm.check.eq %v, %c2, "exp(2.0)=4.0" : f64 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_expm1_f64 vm.func @test_expm1_f64() { %c1 = vm.const.f64 2.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.expm1.f64 %c1dno : f64 %c2 = vm.const.f64 6.38905609893065 vm.check.eq %v, %c2, "expm1(2.0)=6.38905609893065" : f64 @@ -231,7 +231,7 @@ vm.module @arithmetic_ops_f64 { vm.export 
@test_log_f64 vm.func @test_log_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log.f64 %c1dno : f64 %c2 = vm.const.f64 2.302585092994046 vm.check.eq %v, %c2, "log(10.0)=2.302585092994046" : f64 @@ -241,7 +241,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log10_f64 vm.func @test_log10_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log10.f64 %c1dno : f64 %c2 = vm.const.f64 1.0 vm.check.eq %v, %c2, "log10(10.0)=1.0" : f64 @@ -251,7 +251,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log1p_f64 vm.func @test_log1p_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log1p.f64 %c1dno : f64 %c2 = vm.const.f64 2.3978952727983707 vm.check.eq %v, %c2, "log1p(10.0)=2.3978952727983707" : f64 @@ -261,7 +261,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_log2_f64 vm.func @test_log2_f64() { %c1 = vm.const.f64 10.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.log2.f64 %c1dno : f64 %c2 = vm.const.f64 3.321928094887362 vm.check.eq %v, %c2, "log2(10.0)=3.321928094887362" : f64 @@ -271,9 +271,9 @@ vm.module @arithmetic_ops_f64 { vm.export @test_pow_f64 vm.func @test_pow_f64() { %c1 = vm.const.f64 3.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %c2 = vm.const.f64 2.0 - %c2dno = util.optimization_barrier %c2 : f64 + %c2dno = vm.optimization_barrier %c2 : f64 %v = vm.pow.f64 %c1dno, %c2dno : f64 %c3 = vm.const.f64 9.0 vm.check.eq %v, %c3, "pow(3.0,2.0)=9.0" : f64 @@ -283,7 +283,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_rsqrt_f64 vm.func @test_rsqrt_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.rsqrt.f64 %c1dno : f64 %c2 = 
vm.const.f64 0.5 vm.check.eq %v, %c2, "rsqrt(4.0)=0.5" : f64 @@ -293,7 +293,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_sqrt_f64 vm.func @test_sqrt_f64() { %c1 = vm.const.f64 4.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.sqrt.f64 %c1dno : f64 %c2 = vm.const.f64 2.0 vm.check.eq %v, %c2, "sqrt(4.0)=2.0" : f64 @@ -303,7 +303,7 @@ vm.module @arithmetic_ops_f64 { vm.export @test_tanh_f64 vm.func @test_tanh_f64() { %c1 = vm.const.f64 0.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.tanh.f64 %c1dno : f64 %c2 = vm.const.f64 0.46211715726000974 vm.check.eq %v, %c2, "tanh(0.5)=0.46211715726000974" : f64 @@ -314,7 +314,7 @@ vm.module @arithmetic_ops_f64 { // vm.export @test_erf_f64 // vm.func @test_erf_f64() { // %c1 = vm.const.f64 0.5 - // %c1dno = util.optimization_barrier %c1 : f64 + // %c1dno = vm.optimization_barrier %c1 : f64 // %v = vm.erf.f64 %c1dno : f64 // %c2 = vm.const.f64 0.520499945 // vm.check.eq %v, %c2, "erf(0.5)=0.520499945" : f64 diff --git a/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir b/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir index b6cc8a2653c6..5658d9b158d8 100644 --- a/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir +++ b/runtime/src/iree/vm/test/arithmetic_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_add_i64 vm.func @test_add_i64() { %c1 = vm.const.i64 1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.add.i64 %c1dno, %c1dno : i64 %c2 = vm.const.i64 2 vm.check.eq %v, %c2, "1+1=2" : i64 @@ -17,9 +17,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_sub_i64 vm.func @test_sub_i64() { %c1 = vm.const.i64 3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.sub.i64 %c1dno, %c2dno : 
i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "3-2=1" : i64 @@ -29,7 +29,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_mul_i64 vm.func @test_mul_i64() { %c1 = vm.const.i64 2 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.mul.i64 %c1dno, %c1dno : i64 %c2 = vm.const.i64 4 vm.check.eq %v, %c2, "2*2=4" : i64 @@ -39,9 +39,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_div_i64s vm.func @test_div_i64s() { %c1 = vm.const.i64 4 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 -2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.div.i64.s %c1dno, %c2dno : i64 %c3 = vm.const.i64 -2 vm.check.eq %v, %c3, "4/-2=-2" : i64 @@ -51,9 +51,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_div_i64u vm.func @test_div_i64u() { %c1 = vm.const.i64 4 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.div.i64.u %c1dno, %c2dno : i64 %c3 = vm.const.i64 2 vm.check.eq %v, %c3, "4/2=2" : i64 @@ -63,9 +63,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_rem_i64s vm.func @test_rem_i64s() { %c1 = vm.const.i64 -3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 -2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.rem.i64.s %c1dno, %c2dno : i64 %c3 = vm.const.i64 -1 vm.check.eq %v, %c3, "-3%-2=-1" : i64 @@ -75,9 +75,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_rem_i64u vm.func @test_rem_i64u() { %c1 = vm.const.i64 3 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.rem.i64.u %c1dno, 
%c2dno : i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "3%2=1" : i64 @@ -87,11 +87,11 @@ vm.module @arithmetic_ops_i64 { vm.export @test_fma_i64 vm.func @test_fma_i64() { %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %c3 = vm.const.i64 3 - %c3dno = util.optimization_barrier %c3 : i64 + %c3dno = vm.optimization_barrier %c3 : i64 %c5 = vm.const.i64 5 - %c5dno = util.optimization_barrier %c5 : i64 + %c5dno = vm.optimization_barrier %c5 : i64 %v = vm.fma.i64 %c2dno, %c3dno, %c5dno : i64 %c11 = vm.const.i64 11 vm.check.eq %v, %c11, "2*3+5=11" : i64 @@ -101,7 +101,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_abs_i64 vm.func @test_abs_i64() { %cn1 = vm.const.i64 -1 - %cn1dno = util.optimization_barrier %cn1 : i64 + %cn1dno = vm.optimization_barrier %cn1 : i64 %v = vm.abs.i64 %cn1dno : i64 %c1 = vm.const.i64 1 vm.check.eq %v, %c1, "abs(-1)=1" : i64 @@ -111,9 +111,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_min_i64s vm.func @test_min_i64s() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.min.i64.s %cn3dno, %c2dno : i64 vm.check.eq %v, %cn3, "smin(-3,2)=-3" : i64 vm.return @@ -122,9 +122,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_min_i64u vm.func @test_min_i64u() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.min.i64.u %cn3dno, %c2dno : i64 vm.check.eq %v, %c2, "umin(-3,2)=2" : i64 vm.return @@ -133,9 +133,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_max_i64s vm.func @test_max_i64s() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier 
%cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.max.i64.s %cn3dno, %c2dno : i64 vm.check.eq %v, %c2, "smax(-3,2)=2" : i64 vm.return @@ -144,9 +144,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_max_i64u vm.func @test_max_i64u() { %cn3 = vm.const.i64 -3 - %cn3dno = util.optimization_barrier %cn3 : i64 + %cn3dno = vm.optimization_barrier %cn3 : i64 %c2 = vm.const.i64 2 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.max.i64.u %cn3dno, %c2dno : i64 vm.check.eq %v, %cn3, "umax(-3,2)=-3" : i64 vm.return @@ -155,7 +155,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_not_i64 vm.func @test_not_i64() { %c1 = vm.const.i64 0 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.not.i64 %c1dno : i64 %c2 = vm.const.i64 -1 vm.check.eq %v, %c2, "~0=-1" : i64 @@ -165,9 +165,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_and_i64 vm.func @test_and_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.and.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 1 vm.check.eq %v, %c3, "5&3=1" : i64 @@ -177,9 +177,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_or_i64 vm.func @test_or_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.or.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 7 vm.check.eq %v, %c3, "5|3=7" : i64 @@ -189,9 +189,9 @@ vm.module @arithmetic_ops_i64 { vm.export @test_xor_i64 vm.func @test_xor_i64() { %c1 = vm.const.i64 5 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : 
i64 %c2 = vm.const.i64 3 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %v = vm.xor.i64 %c1dno, %c2dno : i64 %c3 = vm.const.i64 6 vm.check.eq %v, %c3, "5^3=6" : i64 @@ -201,7 +201,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_zero vm.func @test_ctlz_i64_const_zero() { %c = vm.const.i64 0 - %cdno = util.optimization_barrier %c : i64 + %cdno = vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 64 vm.check.eq %actual, %expected, "ctlz(0)=64" : i64 @@ -211,7 +211,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_1 vm.func @test_ctlz_i64_const_1() { %c = vm.const.i64 1 - %cdno = util.optimization_barrier %c : i64 + %cdno = vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 63 vm.check.eq %actual, %expected, "ctlz(1)=63" : i64 @@ -221,7 +221,7 @@ vm.module @arithmetic_ops_i64 { vm.export @test_ctlz_i64_const_ffffffffffffffff vm.func @test_ctlz_i64_const_ffffffffffffffff() { %c = vm.const.i64 0xFFFFFFFFFFFFFFFF - %cdno = util.optimization_barrier %c : i64 + %cdno = vm.optimization_barrier %c : i64 %actual = vm.ctlz.i64 %cdno : i64 %expected = vm.const.i64 0 vm.check.eq %actual, %expected, "ctlz(0xFFFFFFFFFFFFFFFF)=0" : i64 diff --git a/runtime/src/iree/vm/test/assignment_ops.mlir b/runtime/src/iree/vm/test/assignment_ops.mlir index 891165da8bc3..c9fc08005d63 100644 --- a/runtime/src/iree/vm/test/assignment_ops.mlir +++ b/runtime/src/iree/vm/test/assignment_ops.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops { vm.export @test_select_i32 vm.func @test_select_i32() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v1 = vm.select.i32 %c0dno, %c0dno, %c1dno : i32 vm.check.eq %v1, %c1, "0 ? 
0 : 1 = 1" : i32 %v2 = vm.select.i32 %c1dno, %c0dno, %c1dno : i32 @@ -24,7 +24,7 @@ vm.module @assignment_ops { %c1 = vm.const.i32 1 %list1 = vm.list.alloc %c1 : (i32) -> !vm.list %cond = vm.const.i32 0 - %cond_dno = util.optimization_barrier %cond : i32 + %cond_dno = vm.optimization_barrier %cond : i32 %list = vm.select.ref %cond_dno, %list0, %list1 : !vm.list vm.check.eq %list, %list1, "0 ? list0 : list1 = list1" : !vm.list vm.return @@ -41,17 +41,17 @@ vm.module @assignment_ops { %c300 = vm.const.i32 300 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.i32 %i0_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v0, %c100, "index 0 is 100" : i32 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.i32 %i1_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v1, %c200, "index 1 is 200" : i32 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.i32 %i2_dno[%c100, %c200] else %c300 : i32 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : i32 diff --git a/runtime/src/iree/vm/test/assignment_ops_f32.mlir b/runtime/src/iree/vm/test/assignment_ops_f32.mlir index 6a0246c16f71..5a368da575fb 100644 --- a/runtime/src/iree/vm/test/assignment_ops_f32.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_f32.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops_f32 { vm.export @test_select_f32 vm.func @test_select_f32() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.f32 0.0 %c3 = vm.const.f32 1.0 %v1 = vm.select.f32 %c0dno, %c2, %c3 : f32 @@ -30,17 +30,17 @@ vm.module @assignment_ops_f32 { %c300 = vm.const.f32 300.0 %i0 = vm.const.i32 0 - %i0_dno = 
util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.f32 %i0_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v0, %c100, "index 0 is 100" : f32 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.f32 %i1_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v1, %c200, "index 1 is 200" : f32 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.f32 %i2_dno[%c100, %c200] else %c300 : f32 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : f32 diff --git a/runtime/src/iree/vm/test/assignment_ops_f64.mlir b/runtime/src/iree/vm/test/assignment_ops_f64.mlir index 7f9d6443f22b..13f6c3820607 100644 --- a/runtime/src/iree/vm/test/assignment_ops_f64.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_f64.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops_f64 { vm.export @test_select_f64 vm.func @test_select_f64() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.f64 0.0 %c3 = vm.const.f64 1.0 %v1 = vm.select.f64 %c0dno, %c2, %c3 : f64 @@ -30,17 +30,17 @@ vm.module @assignment_ops_f64 { %c300 = vm.const.f64 300.0 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.f64 %i0_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v0, %c100, "index 0 is 100" : f64 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.f64 %i1_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v1, %c200, "index 1 is 200" : f64 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.f64 
%i2_dno[%c100, %c200] else %c300 : f64 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : f64 diff --git a/runtime/src/iree/vm/test/assignment_ops_i64.mlir b/runtime/src/iree/vm/test/assignment_ops_i64.mlir index a0d9bc18f03f..c2bd579ed7e9 100644 --- a/runtime/src/iree/vm/test/assignment_ops_i64.mlir +++ b/runtime/src/iree/vm/test/assignment_ops_i64.mlir @@ -7,9 +7,9 @@ vm.module @assignment_ops_i64 { vm.export @test_select_i64 vm.func @test_select_i64() { %c0 = vm.const.i32 0 - %c0dno = util.optimization_barrier %c0 : i32 + %c0dno = vm.optimization_barrier %c0 : i32 %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i64 0 %c3 = vm.const.i64 1 %v1 = vm.select.i64 %c0dno, %c2, %c3 : i64 @@ -30,17 +30,17 @@ vm.module @assignment_ops_i64 { %c300 = vm.const.i64 300 %i0 = vm.const.i32 0 - %i0_dno = util.optimization_barrier %i0 : i32 + %i0_dno = vm.optimization_barrier %i0 : i32 %v0 = vm.switch.i64 %i0_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v0, %c100, "index 0 is 100" : i64 %i1 = vm.const.i32 1 - %i1_dno = util.optimization_barrier %i1 : i32 + %i1_dno = vm.optimization_barrier %i1 : i32 %v1 = vm.switch.i64 %i1_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v1, %c200, "index 1 is 200" : i64 %i2 = vm.const.i32 2 - %i2_dno = util.optimization_barrier %i2 : i32 + %i2_dno = vm.optimization_barrier %i2 : i32 %v2 = vm.switch.i64 %i2_dno[%c100, %c200] else %c300 : i64 vm.check.eq %v2, %c300, "index 2 (out of bounds) is default 300" : i64 diff --git a/runtime/src/iree/vm/test/async_ops.mlir b/runtime/src/iree/vm/test/async_ops.mlir index 7a21e63a3d0e..ef75abe7e549 100644 --- a/runtime/src/iree/vm/test/async_ops.mlir +++ b/runtime/src/iree/vm/test/async_ops.mlir @@ -9,17 +9,17 @@ vm.module @async_ops { vm.func @test_yield_sequence() { %c1 = vm.const.i32 1 %c100 = vm.const.i32 100 - %c100_dno = util.optimization_barrier %c100 : i32 + %c100_dno = vm.optimization_barrier %c100 : 
i32 %y0 = vm.add.i32 %c100_dno, %c1 : i32 - %y0_dno = util.optimization_barrier %y0 : i32 + %y0_dno = vm.optimization_barrier %y0 : i32 vm.yield ^bb1 ^bb1: %y1 = vm.add.i32 %y0_dno, %c1 : i32 - %y1_dno = util.optimization_barrier %y1 : i32 + %y1_dno = vm.optimization_barrier %y1 : i32 vm.yield ^bb2 ^bb2: %y2 = vm.add.i32 %y1_dno, %c1 : i32 - %y2_dno = util.optimization_barrier %y2 : i32 + %y2_dno = vm.optimization_barrier %y2 : i32 vm.yield ^bb3 ^bb3: %c103 = vm.const.i32 103 @@ -36,10 +36,10 @@ vm.module @async_ops { %cond = vm.cmp.nz.i32 %c1 : i32 vm.cond_br %cond, ^true, ^false ^true: - %v_true = util.optimization_barrier %c100 : i32 + %v_true = vm.optimization_barrier %c100 : i32 vm.yield ^check(%v_true : i32) ^false: - %v_false = util.optimization_barrier %c200 : i32 + %v_false = vm.optimization_barrier %c200 : i32 vm.yield ^check(%v_false : i32) ^check(%result : i32): vm.check.eq %result, %c100, "cond=1 selects true branch" : i32 @@ -55,10 +55,10 @@ vm.module @async_ops { %cond = vm.cmp.nz.i32 %c0 : i32 vm.cond_br %cond, ^true, ^false ^true: - %v_true = util.optimization_barrier %c100 : i32 + %v_true = vm.optimization_barrier %c100 : i32 vm.yield ^check(%v_true : i32) ^false: - %v_false = util.optimization_barrier %c200 : i32 + %v_false = vm.optimization_barrier %c200 : i32 vm.yield ^check(%v_false : i32) ^check(%result : i32): vm.check.eq %result, %c200, "cond=0 selects false branch" : i32 @@ -74,15 +74,15 @@ vm.module @async_ops { vm.func private @yield_counter(%start : i32) -> i32 { %c1 = vm.const.i32 1 %v0 = vm.add.i32 %start, %c1 : i32 - %v0_dno = util.optimization_barrier %v0 : i32 + %v0_dno = vm.optimization_barrier %v0 : i32 vm.yield ^y1 ^y1: %v1 = vm.add.i32 %v0_dno, %c1 : i32 - %v1_dno = util.optimization_barrier %v1 : i32 + %v1_dno = vm.optimization_barrier %v1 : i32 vm.yield ^y2 ^y2: %v2 = vm.add.i32 %v1_dno, %c1 : i32 - %v2_dno = util.optimization_barrier %v2 : i32 + %v2_dno = vm.optimization_barrier %v2 : i32 vm.yield ^y3 ^y3: %v3 = vm.add.i32 
%v2_dno, %c1 : i32 @@ -105,7 +105,7 @@ vm.module @async_ops { vm.func private @yield_add_one(%arg0: i32) -> i32 { %c1 = vm.const.i32 1 %result = vm.add.i32 %arg0, %c1 : i32 - %result_dno = util.optimization_barrier %result : i32 + %result_dno = vm.optimization_barrier %result : i32 vm.yield ^done ^done: vm.return %result_dno : i32 @@ -213,7 +213,7 @@ vm.module @async_ops { %c2 = vm.const.i32 2 // Add 1 before yield %v0 = vm.add.i32 %arg0, %c1 : i32 - %v0_dno = util.optimization_barrier %v0 : i32 + %v0_dno = vm.optimization_barrier %v0 : i32 vm.yield ^after_first_yield ^after_first_yield: // Call yieldable import (yields 2 times) @@ -221,7 +221,7 @@ vm.module @async_ops { ^after_import(%v1 : i32): // Add 1 after import %v2 = vm.add.i32 %v1, %c1 : i32 - %v2_dno = util.optimization_barrier %v2 : i32 + %v2_dno = vm.optimization_barrier %v2 : i32 vm.yield ^final ^final: vm.return %v2_dno : i32 diff --git a/runtime/src/iree/vm/test/buffer_ops.mlir b/runtime/src/iree/vm/test/buffer_ops.mlir index 74eebabcade9..b0114d78d1cb 100644 --- a/runtime/src/iree/vm/test/buffer_ops.mlir +++ b/runtime/src/iree/vm/test/buffer_ops.mlir @@ -16,8 +16,8 @@ vm.module @buffer_ops { vm.func @test_compare() { %rodata_a = vm.const.ref.rodata @rodata_cmp_3xi32_a : !vm.buffer %rodata_b = vm.const.ref.rodata @rodata_cmp_3xi32_b : !vm.buffer - %rodata_a_dno = util.optimization_barrier %rodata_a : !vm.buffer - %rodata_b_dno = util.optimization_barrier %rodata_b : !vm.buffer + %rodata_a_dno = vm.optimization_barrier %rodata_a : !vm.buffer + %rodata_b_dno = vm.optimization_barrier %rodata_b : !vm.buffer %c0 = vm.const.i64 0 %length = vm.buffer.length %rodata_a_dno : !vm.buffer -> i64 @@ -37,8 +37,8 @@ vm.module @buffer_ops { vm.func @test_compare_empty() { %rodata_a = vm.const.ref.rodata @rodata_cmp_3xi32_a : !vm.buffer %rodata_b = vm.const.ref.rodata @rodata_cmp_3xi32_b : !vm.buffer - %rodata_a_dno = util.optimization_barrier %rodata_a : !vm.buffer - %rodata_b_dno = util.optimization_barrier 
%rodata_b : !vm.buffer + %rodata_a_dno = vm.optimization_barrier %rodata_a : !vm.buffer + %rodata_b_dno = vm.optimization_barrier %rodata_b : !vm.buffer %c0 = vm.const.i64 0 %c2 = vm.const.i64 2 @@ -59,7 +59,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer %buf_length = vm.buffer.length %buf_dno : !vm.buffer -> i64 @@ -74,7 +74,7 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c0, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer %buf_length = vm.buffer.length %buf_dno : !vm.buffer -> i64 @@ -98,7 +98,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.clone %rodata, %c4, %c8, %alignment : !vm.buffer -> !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Compare the cloned range to the original. @@ -116,14 +116,14 @@ vm.module @buffer_ops { %c0 = vm.const.i64 0 %alignment = vm.const.i32 16 %buf0 = vm.buffer.alloc %c0, %alignment : !vm.buffer - %buf0_dno = util.optimization_barrier %buf0 : !vm.buffer + %buf0_dno = vm.optimization_barrier %buf0 : !vm.buffer vm.check.nz %buf0_dno, "!null" : !vm.buffer %buf0_length = vm.buffer.length %buf0_dno : !vm.buffer -> i64 vm.check.eq %c0, %buf0_length, "buffer length == 0" : i64 // Clone it all (or, clone nothing?). 
%buf1 = vm.buffer.clone %buf0_dno, %c0, %c0, %alignment : !vm.buffer -> !vm.buffer - %buf1_dno = util.optimization_barrier %buf1 : !vm.buffer + %buf1_dno = vm.optimization_barrier %buf1 : !vm.buffer vm.check.nz %buf1_dno, "!null" : !vm.buffer %buf1_length = vm.buffer.length %buf1_dno : !vm.buffer -> i64 vm.check.eq %c0, %buf1_length, "buffer length == 0" : i64 @@ -136,7 +136,7 @@ vm.module @buffer_ops { vm.func @fail_clone_out_of_range() { // Fetch source .rodata blob. %rodata = vm.const.ref.rodata @rodata_3xi32 : !vm.buffer - %rodata_dno = util.optimization_barrier %rodata : !vm.buffer + %rodata_dno = vm.optimization_barrier %rodata : !vm.buffer vm.check.nz %rodata_dno, "!null" : !vm.buffer // Try to clone off the end of the buffer. @@ -162,7 +162,7 @@ vm.module @buffer_ops { // Allocate target buffer. %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %rodata_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Copy the entire contents. @@ -185,7 +185,7 @@ vm.module @buffer_ops { %c4 = vm.const.i64 4 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c4, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Copy the middle 4-byte element. @@ -208,7 +208,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the source buffer. 
@@ -225,7 +225,7 @@ vm.module @buffer_ops { %c128 = vm.const.i64 128 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c128, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the source buffer. @@ -244,7 +244,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c8, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the target buffer. @@ -261,7 +261,7 @@ vm.module @buffer_ops { %c8 = vm.const.i64 8 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %c8, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Try to clone off the end of the target buffer. @@ -286,7 +286,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -315,7 +315,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. 
@@ -344,7 +344,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -373,7 +373,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. @@ -402,7 +402,7 @@ vm.module @buffer_ops { %buffer_size = vm.mul.i64 %num_elements, %element_size : i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %buffer_size, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer vm.check.nz %buf_dno, "!null" : !vm.buffer // Fill the middle two elements. 
@@ -583,12 +583,12 @@ vm.module @buffer_ops { vm.export @test_store_i8 vm.func @test_store_i8() { %ref = vm.const.ref.rodata @test_store_i8_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 @@ -617,12 +617,12 @@ vm.module @buffer_ops { vm.export @test_store_i16 vm.func @test_store_i16() { %ref = vm.const.ref.rodata @test_store_i16_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 @@ -651,12 +651,12 @@ vm.module @buffer_ops { vm.export @test_store_i32 vm.func @test_store_i32() { %ref = vm.const.ref.rodata @test_store_i32_ref : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %ref_length = vm.buffer.length %ref_dno : !vm.buffer -> i64 %alignment = vm.const.i32 16 %buf = vm.buffer.alloc %ref_length, %alignment : !vm.buffer - %buf_dno = util.optimization_barrier %buf : !vm.buffer + %buf_dno = vm.optimization_barrier %buf : !vm.buffer %c0 = vm.const.i64 0 %e0 = vm.const.i32 0 diff --git a/runtime/src/iree/vm/test/call_ops.mlir b/runtime/src/iree/vm/test/call_ops.mlir index d79bafa258dd..d103c72b050f 100644 --- a/runtime/src/iree/vm/test/call_ops.mlir +++ b/runtime/src/iree/vm/test/call_ops.mlir @@ -42,10 +42,10 @@ vm.module @call_ops { vm.func 
private @test_call_r_v_preserve_ref() { %ref = vm.const.ref.zero : !vm.buffer %unused = vm.const.ref.rodata @buffer : !vm.buffer - %unusued_dno_1 = util.optimization_barrier %unused : !vm.buffer + %unusued_dno_1 = vm.optimization_barrier %unused : !vm.buffer vm.check.nz %unused : !vm.buffer vm.call @_r_v_preserve_reg(%ref, %unused) : (!vm.buffer, !vm.buffer) -> () - %unusued_dno_2 = util.optimization_barrier %unused : !vm.buffer + %unusued_dno_2 = vm.optimization_barrier %unused : !vm.buffer vm.check.nz %unusued_dno_2 : !vm.buffer vm.return } @@ -61,7 +61,7 @@ vm.module @call_ops { vm.export @test_call_v_r vm.func @test_call_v_r() { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref %res = vm.call @_v_r() : () -> (!vm.ref) vm.check.eq %ref_dno, %res, "_v_r()=NULL" : !vm.ref vm.return @@ -91,21 +91,21 @@ vm.module @call_ops { vm.func @_r_v(%arg : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg, %ref_dno, "Expected %arg to be NULL" : !vm.ref vm.return } vm.func @_r_v_reuse_reg(%arg : !vm.ref, %unused : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg, %ref_dno, "Expected %arg to be NULL" : !vm.ref vm.return } vm.func @_r_v_preserve_reg(%arg1 : !vm.ref, %arg2 : !vm.ref) attributes {inlining_policy = #util.inline.never} { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %arg1, %ref_dno, "Expected %arg1 to be NULL" : !vm.ref vm.check.nz %arg2, "Expected %arg2 to be not NULL" : !vm.ref vm.return diff --git 
a/runtime/src/iree/vm/test/comparison_ops.mlir b/runtime/src/iree/vm/test/comparison_ops.mlir index f0095452e806..a25c92207c9c 100644 --- a/runtime/src/iree/vm/test/comparison_ops.mlir +++ b/runtime/src/iree/vm/test/comparison_ops.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_0 vm.func @test_cmp_lt_s_0() { %lhs = vm.const.i32 2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 -2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "2 < -2" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_1 vm.func @test_cmp_lt_s_1() { %lhs = vm.const.i32 -2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-2 < 2" : i32 @@ -32,9 +32,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_s_2 vm.func @test_cmp_lt_s_2() { %lhs = vm.const.i32 4294967295 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.s %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < 2" : i32 @@ -48,9 +48,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_0 vm.func @test_cmp_lt_u_0() { %lhs = vm.const.i32 2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 -2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier 
%rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "2 < -2 (as unsigned)" : i32 @@ -60,9 +60,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_1 vm.func @test_cmp_lt_u_1() { %lhs = vm.const.i32 -2 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "-2 < 2 (as unsigned)" : i32 @@ -72,9 +72,9 @@ vm.module @comparison_ops { vm.export @test_cmp_lt_u_2 vm.func @test_cmp_lt_u_2() { %lhs = vm.const.i32 4294967295 - %lhs_dno = util.optimization_barrier %lhs : i32 + %lhs_dno = vm.optimization_barrier %lhs : i32 %rhs = vm.const.i32 2 - %rhs_dno = util.optimization_barrier %rhs : i32 + %rhs_dno = vm.optimization_barrier %rhs : i32 %actual = vm.cmp.lt.i32.u %lhs_dno, %rhs_dno : i32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < 2 (as unsigned)" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops { %false = vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.lte.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -121,9 +121,9 @@ vm.module @comparison_ops { %false = vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.gt.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -148,9 +148,9 @@ vm.module @comparison_ops { %false = 
vm.const.i32 0 %cn2 = vm.const.i32 -2 - %cn2_dno = util.optimization_barrier %cn2 : i32 + %cn2_dno = vm.optimization_barrier %cn2 : i32 %c2 = vm.const.i32 2 - %c2_dno = util.optimization_barrier %c2 : i32 + %c2_dno = vm.optimization_barrier %c2 : i32 %cmp_0 = vm.cmp.gte.i32.s %cn2_dno, %c2_dno : i32 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_f32.mlir b/runtime/src/iree/vm/test/comparison_ops_f32.mlir index 363a02e50638..3d074eddf1ac 100644 --- a/runtime/src/iree/vm/test/comparison_ops_f32.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_f32.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_f32 { vm.export @test_cmp_lt_0_f32 vm.func @test_cmp_lt_0_f32() { %lhs = vm.const.f32 4.0 - %lhs_dno = util.optimization_barrier %lhs : f32 + %lhs_dno = vm.optimization_barrier %lhs : f32 %rhs = vm.const.f32 -4.0 - %rhs_dno = util.optimization_barrier %rhs : f32 + %rhs_dno = vm.optimization_barrier %rhs : f32 %actual = vm.cmp.lt.f32.o %lhs_dno, %rhs_dno : f32 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4.0 < -4.0" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_f32 { vm.export @test_cmp_lt_1_f32 vm.func @test_cmp_lt_1_f32() { %lhs = vm.const.f32 -4.0 - %lhs_dno = util.optimization_barrier %lhs : f32 + %lhs_dno = vm.optimization_barrier %lhs : f32 %rhs = vm.const.f32 4.0 - %rhs_dno = util.optimization_barrier %rhs : f32 + %rhs_dno = vm.optimization_barrier %rhs : f32 %actual = vm.cmp.lt.f32.o %lhs_dno, %rhs_dno : f32 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4.0 < 4.0" : i32 @@ -41,9 +41,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.eq.f32.near %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 !~ 2" : i32 @@ -56,9 +56,9 
@@ vm.module @comparison_ops_f32 { // off by 84 ULPs, arbitrary threshold sets these as "near enough" %c1a = vm.const.f32 1.00002 - %c1a_dno = util.optimization_barrier %c1a : f32 + %c1a_dno = vm.optimization_barrier %c1a : f32 %c1b = vm.const.f32 1.00003 - %c1b_dno = util.optimization_barrier %c1b : f32 + %c1b_dno = vm.optimization_barrier %c1b : f32 %cmp_4 = vm.cmp.eq.f32.near %c1a_dno, %c1b_dno : f32 vm.check.eq %cmp_4, %true, "1.00002 ~ 1.00003" : i32 @@ -74,9 +74,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.lte.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.gt.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -114,9 +114,9 @@ vm.module @comparison_ops_f32 { %false = vm.const.i32 0 %cn2 = vm.const.f32 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f32 + %cn2_dno = vm.optimization_barrier %cn2 : f32 %c2 = vm.const.f32 2.0 - %c2_dno = util.optimization_barrier %c2 : f32 + %c2_dno = vm.optimization_barrier %c2 : f32 %cmp_0 = vm.cmp.gte.f32.o %cn2_dno, %c2_dno : f32 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_f64.mlir b/runtime/src/iree/vm/test/comparison_ops_f64.mlir index fb7a67f95332..01b451774fa0 100644 --- a/runtime/src/iree/vm/test/comparison_ops_f64.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_f64.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_f64 { 
vm.export @test_cmp_lt_0_f64 vm.func @test_cmp_lt_0_f64() { %lhs = vm.const.f64 4.0 - %lhs_dno = util.optimization_barrier %lhs : f64 + %lhs_dno = vm.optimization_barrier %lhs : f64 %rhs = vm.const.f64 -4.0 - %rhs_dno = util.optimization_barrier %rhs : f64 + %rhs_dno = vm.optimization_barrier %rhs : f64 %actual = vm.cmp.lt.f64.o %lhs_dno, %rhs_dno : f64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4.0 < -4.0" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_f64 { vm.export @test_cmp_lt_1_f64 vm.func @test_cmp_lt_1_f64() { %lhs = vm.const.f64 -4.0 - %lhs_dno = util.optimization_barrier %lhs : f64 + %lhs_dno = vm.optimization_barrier %lhs : f64 %rhs = vm.const.f64 4.0 - %rhs_dno = util.optimization_barrier %rhs : f64 + %rhs_dno = vm.optimization_barrier %rhs : f64 %actual = vm.cmp.lt.f64.o %lhs_dno, %rhs_dno : f64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4.0 < 4.0" : i32 @@ -41,9 +41,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.eq.f64.near %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 !~ 2" : i32 @@ -56,9 +56,9 @@ vm.module @comparison_ops_f64 { // off by 84 ULPs, arbitrary threshold sets these as "near enough" %c1a = vm.const.f64 1.00002 - %c1a_dno = util.optimization_barrier %c1a : f64 + %c1a_dno = vm.optimization_barrier %c1a : f64 %c1b = vm.const.f64 1.00003 - %c1b_dno = util.optimization_barrier %c1b : f64 + %c1b_dno = vm.optimization_barrier %c1b : f64 %cmp_4 = vm.cmp.eq.f64.near %c1a_dno, %c1b_dno : f64 vm.check.eq %cmp_4, %true, "1.00002 ~ 1.00003" : i32 @@ -74,9 +74,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 
%c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.lte.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.gt.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -114,9 +114,9 @@ vm.module @comparison_ops_f64 { %false = vm.const.i32 0 %cn2 = vm.const.f64 -2.0 - %cn2_dno = util.optimization_barrier %cn2 : f64 + %cn2_dno = vm.optimization_barrier %cn2 : f64 %c2 = vm.const.f64 2.0 - %c2_dno = util.optimization_barrier %c2 : f64 + %c2_dno = vm.optimization_barrier %c2 : f64 %cmp_0 = vm.cmp.gte.f64.o %cn2_dno, %c2_dno : f64 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/comparison_ops_i64.mlir b/runtime/src/iree/vm/test/comparison_ops_i64.mlir index 3c10ef8e0c11..a8a44be0f7ed 100644 --- a/runtime/src/iree/vm/test/comparison_ops_i64.mlir +++ b/runtime/src/iree/vm/test/comparison_ops_i64.mlir @@ -7,9 +7,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_0_i64 vm.func @test_cmp_lt_s_0_i64() { %lhs = vm.const.i64 4294967295 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 -4294967295 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "4294967295 (UINT_MAX) < -4294967295 (UINT_MAX)" : i32 @@ -19,9 +19,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_1_i64 vm.func @test_cmp_lt_s_1_i64() { %lhs = vm.const.i64 -4294967295 - %lhs_dno = 
util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 4294967295 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "-4294967295 (UINT_MAX) < 4294967295 (UINT_MAX)" : i32 @@ -32,9 +32,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_s_2_i64 vm.func @test_cmp_lt_s_2_i64() { %lhs = vm.const.i64 18446744073709551615 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.s %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "18446744073709551615 (ULONG_MAX) < 2" : i32 @@ -48,9 +48,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_0_i64 vm.func @test_cmp_lt_u_0_i64() { %lhs = vm.const.i64 2 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 -2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 1 vm.check.eq %actual, %expected, "2 < -2 (as unsigned)" : i32 @@ -60,9 +60,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_1_i64 vm.func @test_cmp_lt_u_1_i64() { %lhs = vm.const.i64 -2 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "-2 < 2 (as unsigned)" : i32 @@ -72,9 +72,9 @@ vm.module @comparison_ops_i64 { vm.export @test_cmp_lt_u_2_i64 vm.func 
@test_cmp_lt_u_2_i64() { %lhs = vm.const.i64 18446744073709551615 - %lhs_dno = util.optimization_barrier %lhs : i64 + %lhs_dno = vm.optimization_barrier %lhs : i64 %rhs = vm.const.i64 2 - %rhs_dno = util.optimization_barrier %rhs : i64 + %rhs_dno = vm.optimization_barrier %rhs : i64 %actual = vm.cmp.lt.i64.u %lhs_dno, %rhs_dno : i64 %expected = vm.const.i32 0 vm.check.eq %actual, %expected, "18446744073709551615 (ULONG_MAX) < 2 (as unsigned)" : i32 @@ -94,9 +94,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.lte.i64.s %cn2_dno, %c2_dno : i64 vm.check.eq %cmp_0, %true, "-2 <= 2" : i32 @@ -121,9 +121,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.gt.i64.s %cn2_dno, %c2_dno : i64 vm.check.eq %cmp_0, %false, "-2 > 2" : i32 @@ -148,9 +148,9 @@ vm.module @comparison_ops_i64 { %false = vm.const.i32 0 %cn2 = vm.const.i64 -2 - %cn2_dno = util.optimization_barrier %cn2 : i64 + %cn2_dno = vm.optimization_barrier %cn2 : i64 %c2 = vm.const.i64 2 - %c2_dno = util.optimization_barrier %c2 : i64 + %c2_dno = vm.optimization_barrier %c2 : i64 %cmp_0 = vm.cmp.gte.i64.s %cn2_dno, %c2_dno : i64 vm.check.eq %cmp_0, %false, "-2 >= 2" : i32 diff --git a/runtime/src/iree/vm/test/control_flow_ops.mlir b/runtime/src/iree/vm/test/control_flow_ops.mlir index a091f942b7c1..902d838c965d 100644 --- a/runtime/src/iree/vm/test/control_flow_ops.mlir +++ b/runtime/src/iree/vm/test/control_flow_ops.mlir @@ -26,7 +26,7 @@ vm.module @control_flow_ops { vm.export 
@test_check_eq_always vm.func @test_check_eq_always() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.check.eq %c1, %c1dno, "error!" : i32 vm.return } @@ -35,8 +35,8 @@ vm.module @control_flow_ops { vm.func @fail_check_eq_never() { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 - %c2dno = util.optimization_barrier %c2 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 vm.check.eq %c1dno, %c2dno, "error!" : i32 vm.return } @@ -72,7 +72,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br vm.func @test_cond_br() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1dno, ^bb1, ^bb2 ^bb1: vm.check.eq %c1dno, %c1dno, "error!" : i32 @@ -85,7 +85,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br_int_arg vm.func @test_cond_br_int_arg() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1dno, ^bb1(%c1dno : i32), ^bb2(%c1dno : i32) ^bb1(%arg1 : i32): vm.check.eq %arg1, %c1dno, "error!" 
: i32 @@ -98,7 +98,7 @@ vm.module @control_flow_ops { vm.export @test_cond_br_ref_arg vm.func @test_cond_br_ref_arg() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %ref = vm.const.ref.zero : !vm.ref vm.cond_br %c1dno, ^bb1(%ref : !vm.ref), ^bb2(%ref : !vm.ref) ^bb1(%arg1 : !vm.ref): @@ -115,9 +115,9 @@ vm.module @control_flow_ops { vm.export @test_cond_br_same_successor attributes {emitc.exclude} vm.func private @test_cond_br_same_successor() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 - %c2dno = util.optimization_barrier %c2 : i32 + %c2dno = vm.optimization_barrier %c2 : i32 vm.cond_br %c1dno, ^bb1(%c1dno : i32), ^bb1(%c2dno : i32) ^bb1(%arg1 : i32): vm.check.eq %arg1, %c1dno, "error!" : i32 @@ -129,7 +129,7 @@ vm.module @control_flow_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 vm.br_table %c1dno { default: ^bb1(%c2 : i32), 0: ^bb2(%c0 : i32), @@ -148,7 +148,7 @@ vm.module @control_flow_ops { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 %c-1 = vm.const.i32 -1 - %c-1dno = util.optimization_barrier %c-1 : i32 + %c-1dno = vm.optimization_barrier %c-1 : i32 vm.br_table %c-1dno { default: ^bb1(%c0 : i32), 0: ^bb2(%c1 : i32), diff --git a/runtime/src/iree/vm/test/conversion_ops.mlir b/runtime/src/iree/vm/test/conversion_ops.mlir index 22374a8af34f..d6bdb11cc666 100644 --- a/runtime/src/iree/vm/test/conversion_ops.mlir +++ b/runtime/src/iree/vm/test/conversion_ops.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops { vm.export @test_trunc_i32_i8 vm.func private @test_trunc_i32_i8() { %c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.trunc.i32.i8 %c1dno : i32 -> i32 %c2 = vm.const.i32 255 vm.check.eq %v, %c2, "truncate unsigned 
i32 to unsigned i8" : i32 @@ -17,7 +17,7 @@ vm.module @conversion_ops { vm.export @test_trunc_i32_i16 vm.func private @test_trunc_i32_i16() { %c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.trunc.i32.i16 %c1dno : i32 -> i32 %c2 = vm.const.i32 65535 vm.check.eq %v, %c2, "truncate unsigned i32 to unsigned i16" : i32 @@ -30,7 +30,7 @@ vm.module @conversion_ops { %alignment = vm.const.i32 16 %buffer = vm.buffer.alloc %c128, %alignment : !vm.buffer %any = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref %cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.buffer vm.check.eq %buffer, %cast, "cast should succeed" : !vm.buffer vm.return @@ -40,7 +40,7 @@ vm.module @conversion_ops { vm.func private @test_cast_any_ref_null() { %null = vm.const.ref.zero : !vm.buffer %any = vm.cast.ref.any %null : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref %cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.buffer vm.check.eq %null, %cast, "cast should succeed on nulls" : !vm.buffer vm.return @@ -52,10 +52,10 @@ vm.module @conversion_ops { %alignment = vm.const.i32 16 %buffer = vm.buffer.alloc %c128, %alignment : !vm.buffer %any = vm.cast.ref.any %buffer : !vm.buffer -> !vm.ref - %any_dno = util.optimization_barrier %any : !vm.ref + %any_dno = vm.optimization_barrier %any : !vm.ref // Should fail at runtime because of the type mismatch. 
%cast = vm.cast.any.ref %any_dno : !vm.ref -> !vm.list - util.optimization_barrier %cast : !vm.list + vm.optimization_barrier %cast : !vm.list vm.return } diff --git a/runtime/src/iree/vm/test/conversion_ops_f32.mlir b/runtime/src/iree/vm/test/conversion_ops_f32.mlir index bb893f77ddbf..dbc6b55b6d0f 100644 --- a/runtime/src/iree/vm/test/conversion_ops_f32.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_f32.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_f32 { vm.export @test_bitcast_i32_f32 vm.func @test_bitcast_i32_f32() { %c1 = vm.const.i32 0x40B00000 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.bitcast.i32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 5.5 vm.check.eq %v, %c2, "bitcast i32 to f32" : f32 @@ -17,7 +17,7 @@ vm.module @conversion_ops_f32 { vm.export @test_bitcast_f32_i32 vm.func @test_bitcast_f32_i32() { %c1 = vm.const.f32 5.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.bitcast.f32.i32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0x40B00000 vm.check.eq %v, %c2, "bitcast f32 to i32" : i32 @@ -27,7 +27,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_si32_f32_int_max vm.func @test_cast_si32_f32_int_max() { %c1 = vm.const.i32 2147483647 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.si32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 2147483647.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f32 @@ -37,7 +37,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_si32_f32_int_min vm.func @test_cast_si32_f32_int_min() { %c1 = vm.const.i32 -2147483648 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.si32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 -2147483648.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f32 @@ -47,7 +47,7 @@ vm.module @conversion_ops_f32 { vm.export 
@test_cast_ui32_f32_int_max vm.func @test_cast_ui32_f32_int_max() { %c1 = vm.const.i32 4294967295 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %v = vm.cast.ui32.f32 %c1dno : i32 -> f32 %c2 = vm.const.f32 4294967295.0 vm.check.eq %v, %c2, "cast unsigned integer to a floating-point value" : f32 @@ -59,7 +59,7 @@ vm.module @conversion_ops_f32 { // This is the maximum value that is representable precisely as both i32 // and f32. An exponent of 30 with all mantissa bits set. %c1 = vm.const.f32 0x4effffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0x7FFFFF80 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -69,7 +69,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_int_min vm.func @test_cast_f32_si32_int_min() { %c1 = vm.const.f32 -2147483648.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 -2147483648 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -79,7 +79,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_away_from_zero_pos vm.func @test_cast_f32_si32_away_from_zero_pos() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 @@ -89,7 +89,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si32_away_from_zero_neg vm.func @test_cast_f32_si32_away_from_zero_neg() { %c1 = vm.const.f32 -2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si32 %c1dno : f32 -> i32 %c2 = vm.const.i32 -3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 
@@ -101,7 +101,7 @@ vm.module @conversion_ops_f32 { // This is the maximum value that is representable precisely as both i64 // and f32. An exponent of 62 with all mantissa bits set. %c1 = vm.const.f32 0x5effffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 0x7FFFFF8000000000 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -111,13 +111,13 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_int_min vm.func @test_cast_f32_si64_int_min() { %c1 = vm.const.f32 -9223372036854775808.0 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 // Directly providing the true INT64_MIN of -9223372036854775808 // gives an error so we do -(INT64_MAX) - 1 // See: https://stackoverflow.com/a/65008288 %c2 = vm.const.i64 -9223372036854775807 - %c2dno = util.optimization_barrier %c2 : i64 + %c2dno = vm.optimization_barrier %c2 : i64 %c3 = vm.const.i64 1 %c4 = vm.sub.i64 %c2dno, %c3 : i64 vm.check.eq %v, %c4, "cast floating-point value to a signed integer" : i64 @@ -127,7 +127,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_away_from_zero_pos vm.func @test_cast_f32_si64_away_from_zero_pos() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -137,19 +137,22 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_si64_away_from_zero_neg vm.func @test_cast_f32_si64_away_from_zero_neg() { %c1 = vm.const.f32 -2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.si64 %c1dno : f32 -> i64 %c2 = vm.const.i64 -3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 
vm.return } - vm.export @test_cast_f32_ui32_int_big + // EmitC constant folding breaks through vm.optimization_barrier, causing + // this test to be folded to unconditional error. Excluded until EmitC is + // removed or barrier handling is fixed. + vm.export @test_cast_f32_ui32_int_big attributes {emitc.exclude} vm.func @test_cast_f32_ui32_int_big() { // This is the maximum value that is representable precisely as both ui32 // and f32. An exponent of 31 with all mantissa bits set. %c1 = vm.const.f32 0x4f7fffff - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui32 %c1dno : f32 -> i32 %c2 = vm.const.i32 0xFFFFFF00 vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i32 @@ -159,19 +162,22 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_ui32_away_from_zero vm.func @test_cast_f32_ui32_away_from_zero() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui32 %c1dno : f32 -> i32 %c2 = vm.const.i32 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i32 vm.return } - vm.export @test_cast_f32_ui64_int_big + // EmitC constant folding breaks through vm.optimization_barrier, causing + // this test to be folded to unconditional error. Excluded until EmitC is + // removed or barrier handling is fixed. + vm.export @test_cast_f32_ui64_int_big attributes {emitc.exclude} vm.func @test_cast_f32_ui64_int_big() { // This is the maximum value that is representable precisely as both ui64 // and f32. An exponent of 63 with all mantissa bits set. 
%c1 = vm.const.f32 0x5F7FFFFF - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 %c2 = vm.const.i64 0xFFFFFF0000000000 vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i64 @@ -181,7 +187,7 @@ vm.module @conversion_ops_f32 { vm.export @test_cast_f32_ui64_away_from_zero vm.func @test_cast_f32_ui64_away_from_zero() { %c1 = vm.const.f32 2.5 - %c1dno = util.optimization_barrier %c1 : f32 + %c1dno = vm.optimization_barrier %c1 : f32 %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 diff --git a/runtime/src/iree/vm/test/conversion_ops_f64.mlir b/runtime/src/iree/vm/test/conversion_ops_f64.mlir index 850983425f38..1de7f205e3e1 100644 --- a/runtime/src/iree/vm/test/conversion_ops_f64.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_f64.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_f64 { vm.export @test_bitcast_i64_f64 vm.func @test_bitcast_i64_f64() { %c1 = vm.const.i64 0x4016000000000000 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.bitcast.i64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 5.5 vm.check.eq %v, %c2, "bitcast i64 to f64" : f64 @@ -17,7 +17,7 @@ vm.module @conversion_ops_f64 { vm.export @test_bitcast_f64_i64 vm.func @test_bitcast_f64_i64() { %c1 = vm.const.f64 5.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.bitcast.f64.i64 %c1dno : f64 -> i64 %c2 = vm.const.i64 0x4016000000000000 vm.check.eq %v, %c2, "bitcast f64 to i64" : i64 @@ -27,7 +27,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_si64_f64_int_max vm.func @test_cast_si64_f64_int_max() { %c1 = vm.const.i64 2147483647 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.si64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 2147483647.0 vm.check.eq %v, 
%c2, "cast signed integer to a floating-point value" : f64 @@ -37,7 +37,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_si64_f64_int_min vm.func @test_cast_si64_f64_int_min() { %c1 = vm.const.i64 -2147483648 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.si64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 -2147483648.0 vm.check.eq %v, %c2, "cast signed integer to a floating-point value" : f64 @@ -47,7 +47,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_ui64_f64_int_max vm.func @test_cast_ui64_f64_int_max() { %c1 = vm.const.i64 4294967295 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.cast.ui64.f64 %c1dno : i64 -> f64 %c2 = vm.const.f64 4294967295.0 vm.check.eq %v, %c2, "cast unsigned integer to a floating-point value" : f64 @@ -57,7 +57,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_int_min vm.func @test_cast_f64_si64_int_min() { %c1 = vm.const.f64 -2147483648.0 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 -2147483648 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -67,7 +67,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_away_from_zero_pos vm.func @test_cast_f64_si64_away_from_zero_pos() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -77,7 +77,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_si64_away_from_zero_neg vm.func @test_cast_f64_si64_away_from_zero_neg() { %c1 = vm.const.f64 -2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.si64 %c1dno : f64 -> i64 %c2 = vm.const.i64 -3 
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 @@ -87,7 +87,7 @@ vm.module @conversion_ops_f64 { vm.export @test_cast_f64_ui64_away_from_zero vm.func @test_cast_f64_ui64_away_from_zero() { %c1 = vm.const.f64 2.5 - %c1dno = util.optimization_barrier %c1 : f64 + %c1dno = vm.optimization_barrier %c1 : f64 %v = vm.cast.f64.ui64 %c1dno : f64 -> i64 %c2 = vm.const.i64 3 vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 diff --git a/runtime/src/iree/vm/test/conversion_ops_i64.mlir b/runtime/src/iree/vm/test/conversion_ops_i64.mlir index 4ab99fa5e1fd..dc17376d9af2 100644 --- a/runtime/src/iree/vm/test/conversion_ops_i64.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @conversion_ops_i64 { vm.export @test_trunc_i64_i32 vm.func @test_trunc_i64_i32() { %c1 = vm.const.i64 9223372036854775807 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %v = vm.trunc.i64.i32 %c1dno : i64 -> i32 %c2 = vm.const.i32 4294967295 vm.check.eq %v, %c2, "truncate unsigned i64 to unsigned i32" : i32 diff --git a/runtime/src/iree/vm/test/global_ops.mlir b/runtime/src/iree/vm/test/global_ops.mlir index 263e7b5028a5..fc6aab40905f 100644 --- a/runtime/src/iree/vm/test/global_ops.mlir +++ b/runtime/src/iree/vm/test/global_ops.mlir @@ -22,7 +22,7 @@ vm.module @global_ops { vm.func @test_global_load_ref() { %actual = vm.global.load.ref @g0 : !vm.buffer %expected = vm.const.ref.zero : !vm.buffer - %expecteddno = util.optimization_barrier %expected : !vm.buffer + %expecteddno = vm.optimization_barrier %expected : !vm.buffer vm.check.eq %actual, %expecteddno : !vm.buffer vm.return } diff --git a/runtime/src/iree/vm/test/list_ops.mlir b/runtime/src/iree/vm/test/list_ops.mlir index 696be360616e..4947a675beec 100644 --- a/runtime/src/iree/vm/test/list_ops.mlir +++ b/runtime/src/iree/vm/test/list_ops.mlir @@ -12,7 +12,7 @@ vm.module @list_ops { %list = vm.list.alloc %c42 : (i32) -> 
!vm.list vm.list.reserve %list, %c100 : (!vm.list, i32) %sz = vm.list.size %list : (!vm.list) -> i32 - %sz_dno = util.optimization_barrier %sz : i32 + %sz_dno = vm.optimization_barrier %sz : i32 vm.check.eq %sz_dno, %c0, "list.empty.size()=0" : i32 vm.return } @@ -107,7 +107,7 @@ vm.module @list_ops { %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) %v = vm.list.get.i32 %list, %c1 : (!vm.list, i32) -> i32 - %v_dno = util.optimization_barrier %v : i32 + %v_dno = vm.optimization_barrier %v : i32 // Add a dummy use of %v_dno to please recent versions of clang for the C target vm.list.set.i32 %list, %c1, %v_dno : (!vm.list, i32, i32) vm.return diff --git a/runtime/src/iree/vm/test/list_variant_ops.mlir b/runtime/src/iree/vm/test/list_variant_ops.mlir index 202c92ececbd..10cfb400fd32 100644 --- a/runtime/src/iree/vm/test/list_variant_ops.mlir +++ b/runtime/src/iree/vm/test/list_variant_ops.mlir @@ -113,7 +113,7 @@ vm.module @list_variant_ops { vm.list.resize %list, %c1 : (!vm.list, i32) %ref = vm.list.get.ref %list, %c1 : (!vm.list, i32) -> !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.return } diff --git a/runtime/src/iree/vm/test/ref_ops.mlir b/runtime/src/iree/vm/test/ref_ops.mlir index 019cde3f83fb..ca83e92a9b65 100644 --- a/runtime/src/iree/vm/test/ref_ops.mlir +++ b/runtime/src/iree/vm/test/ref_ops.mlir @@ -17,7 +17,7 @@ vm.module @ref_ops { vm.export @test_zero_ref_eq vm.func @test_zero_ref_eq() { %ref = vm.const.ref.zero : !vm.ref - %ref_dno = util.optimization_barrier %ref : !vm.ref + %ref_dno = vm.optimization_barrier %ref : !vm.ref vm.check.eq %ref_dno, %ref_dno : !vm.ref vm.return } @@ -30,9 +30,9 @@ vm.module @ref_ops { vm.export @test_ref_eq attributes {emitc.exclude} vm.func @test_ref_eq() { %ref_1 = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_1_dno = util.optimization_barrier %ref_1 : !vm.buffer + %ref_1_dno = 
vm.optimization_barrier %ref_1 : !vm.buffer %ref_2 = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_2_dno = util.optimization_barrier %ref_2 : !vm.buffer + %ref_2_dno = vm.optimization_barrier %ref_2 : !vm.buffer vm.check.eq %ref_1_dno, %ref_2_dno : !vm.buffer vm.return } @@ -40,9 +40,9 @@ vm.module @ref_ops { vm.export @test_ref_ne vm.func @test_ref_ne() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer vm.check.ne %ref_a_dno, %ref_b_dno : !vm.buffer vm.return } @@ -50,7 +50,7 @@ vm.module @ref_ops { vm.export @test_ref_nz vm.func @test_ref_nz() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno : !vm.buffer vm.return } @@ -64,7 +64,7 @@ vm.module @ref_ops { vm.export @test_ref_survives_call attributes {emitc.exclude} vm.func @test_ref_survives_call() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before call" : !vm.buffer vm.call @_consume_ref(%ref_dno) : (!vm.buffer) -> () // Ref should still be valid after the call. 
@@ -74,7 +74,7 @@ vm.module @ref_ops { vm.func private @_consume_ref(%arg : !vm.buffer) attributes {inlining_policy = #util.inline.never} { - %arg_dno = util.optimization_barrier %arg : !vm.buffer + %arg_dno = vm.optimization_barrier %arg : !vm.buffer vm.check.nz %arg_dno, "ref valid in callee" : !vm.buffer vm.return } @@ -83,7 +83,7 @@ vm.module @ref_ops { vm.export @test_same_ref_multiple_args attributes {emitc.exclude} vm.func @test_same_ref_multiple_args() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.call @_consume_two_refs(%ref_dno, %ref_dno) : (!vm.buffer, !vm.buffer) -> () // Ref should still be valid after the call. vm.check.nz %ref_dno, "ref valid after call with same ref twice" : !vm.buffer @@ -92,8 +92,8 @@ vm.module @ref_ops { vm.func private @_consume_two_refs(%arg0 : !vm.buffer, %arg1 : !vm.buffer) attributes {inlining_policy = #util.inline.never} { - %arg0_dno = util.optimization_barrier %arg0 : !vm.buffer - %arg1_dno = util.optimization_barrier %arg1 : !vm.buffer + %arg0_dno = vm.optimization_barrier %arg0 : !vm.buffer + %arg1_dno = vm.optimization_barrier %arg1 : !vm.buffer vm.check.nz %arg0_dno, "first arg valid" : !vm.buffer vm.check.nz %arg1_dno, "second arg valid" : !vm.buffer vm.check.eq %arg0_dno, %arg1_dno, "both args are same ref" : !vm.buffer @@ -104,7 +104,7 @@ vm.module @ref_ops { vm.export @test_ref_returned_from_call attributes {emitc.exclude} vm.func @test_ref_returned_from_call() { %ref = vm.call @_return_ref() : () -> !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "returned ref is valid" : !vm.buffer vm.return } @@ -119,9 +119,9 @@ vm.module @ref_ops { vm.export @test_ref_passthrough attributes {emitc.exclude} vm.func @test_ref_passthrough() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = 
util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %returned = vm.call @_passthrough_ref(%ref_dno) : (!vm.buffer) -> !vm.buffer - %returned_dno = util.optimization_barrier %returned : !vm.buffer + %returned_dno = vm.optimization_barrier %returned : !vm.buffer vm.check.eq %ref_dno, %returned_dno, "passthrough returns same ref" : !vm.buffer vm.return } @@ -139,9 +139,9 @@ vm.module @ref_ops { vm.export @test_ref_cond_br_both_paths attributes {emitc.exclude} vm.func @test_ref_cond_br_both_paths() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1(%ref_dno : !vm.buffer), ^bb2(%ref_dno : !vm.buffer) ^bb1(%arg1 : !vm.buffer): vm.check.nz %arg1, "ref valid in bb1" : !vm.buffer @@ -157,9 +157,9 @@ vm.module @ref_ops { vm.export @test_ref_cond_br_one_path attributes {emitc.exclude} vm.func @test_ref_cond_br_one_path() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1(%ref_dno : !vm.buffer), ^bb2 ^bb1(%arg1 : !vm.buffer): vm.check.nz %arg1, "ref valid in bb1" : !vm.buffer @@ -172,7 +172,7 @@ vm.module @ref_ops { vm.export @test_ref_in_loop attributes {emitc.exclude} vm.func @test_ref_in_loop() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -191,9 +191,9 @@ vm.module @ref_ops { vm.export @test_multiple_refs_in_loop attributes {emitc.exclude} vm.func 
@test_multiple_refs_in_loop() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -217,10 +217,10 @@ vm.module @ref_ops { vm.export @test_global_store_load_ref attributes {emitc.exclude} vm.func @test_global_store_load_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.global.store.ref %ref_dno, @global_ref : !vm.buffer %loaded = vm.global.load.ref @global_ref : !vm.buffer - %loaded_dno = util.optimization_barrier %loaded : !vm.buffer + %loaded_dno = vm.optimization_barrier %loaded : !vm.buffer vm.check.eq %ref_dno, %loaded_dno, "loaded ref equals stored ref" : !vm.buffer vm.return } @@ -229,7 +229,7 @@ vm.module @ref_ops { vm.export @test_ref_valid_after_global_store attributes {emitc.exclude} vm.func @test_ref_valid_after_global_store() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before store" : !vm.buffer vm.global.store.ref %ref_dno, @global_ref : !vm.buffer // Original ref should still be valid after storing to global. 
@@ -248,12 +248,12 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_dno : (!vm.list, i32, !vm.buffer) %retrieved = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_dno = util.optimization_barrier %retrieved : !vm.buffer + %retrieved_dno = vm.optimization_barrier %retrieved : !vm.buffer vm.check.eq %ref_dno, %retrieved_dno, "retrieved ref equals set ref" : !vm.buffer vm.return } @@ -265,17 +265,17 @@ vm.module @ref_ops { %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %list = vm.list.alloc %c2 : (i32) -> !vm.list vm.list.resize %list, %c2 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_a_dno : (!vm.list, i32, !vm.buffer) vm.list.set.ref %list, %c1, %ref_b_dno : (!vm.list, i32, !vm.buffer) %retrieved_a = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_a_dno = util.optimization_barrier %retrieved_a : !vm.buffer + %retrieved_a_dno = vm.optimization_barrier %retrieved_a : !vm.buffer %retrieved_b = vm.list.get.ref %list, %c1 : (!vm.list, i32) -> !vm.buffer - %retrieved_b_dno = util.optimization_barrier %retrieved_b : !vm.buffer + %retrieved_b_dno = vm.optimization_barrier %retrieved_b : !vm.buffer vm.check.eq %ref_a_dno, %retrieved_a_dno, "retrieved ref_a equals set ref_a" : !vm.buffer vm.check.eq %ref_b_dno, %retrieved_b_dno, "retrieved ref_b equals set ref_b" : !vm.buffer vm.check.ne 
%retrieved_a_dno, %retrieved_b_dno, "refs are different" : !vm.buffer @@ -288,12 +288,12 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) vm.list.set.ref %list, %c0, %ref_dno : (!vm.list, i32, !vm.buffer) %retrieved = vm.list.get.ref %list, %c0 : (!vm.list, i32) -> !vm.buffer - %retrieved_dno = util.optimization_barrier %retrieved : !vm.buffer + %retrieved_dno = vm.optimization_barrier %retrieved : !vm.buffer // Use retrieved ref multiple times. vm.check.nz %retrieved_dno, "retrieved ref valid (use 1)" : !vm.buffer vm.check.nz %retrieved_dno, "retrieved ref valid (use 2)" : !vm.buffer @@ -307,7 +307,7 @@ vm.module @ref_ops { %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before list set" : !vm.buffer %list = vm.list.alloc %c1 : (i32) -> !vm.list vm.list.resize %list, %c1 : (!vm.list, i32) @@ -325,13 +325,13 @@ vm.module @ref_ops { vm.export @test_select_ref_true attributes {emitc.exclude} vm.func @test_select_ref_true() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 %result = vm.select.ref %c1_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno 
= vm.optimization_barrier %result : !vm.buffer vm.check.eq %result_dno, %ref_a_dno, "select true returns first ref" : !vm.buffer vm.return } @@ -339,13 +339,13 @@ vm.module @ref_ops { vm.export @test_select_ref_false attributes {emitc.exclude} vm.func @test_select_ref_false() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 - %c0_dno = util.optimization_barrier %c0 : i32 + %c0_dno = vm.optimization_barrier %c0 : i32 %result = vm.select.ref %c0_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer vm.check.eq %result_dno, %ref_b_dno, "select false returns second ref" : !vm.buffer vm.return } @@ -354,13 +354,13 @@ vm.module @ref_ops { vm.export @test_select_ref_input_survives attributes {emitc.exclude} vm.func @test_select_ref_input_survives() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 %result = vm.select.ref %c1_dno, %ref_a_dno, %ref_b_dno : !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer // Both input refs should still be valid after select. 
vm.check.nz %ref_a_dno, "ref_a valid after select" : !vm.buffer vm.check.nz %ref_b_dno, "ref_b valid after select" : !vm.buffer @@ -376,7 +376,7 @@ vm.module @ref_ops { vm.export @test_ref_multiple_sequential_uses attributes {emitc.exclude} vm.func @test_ref_multiple_sequential_uses() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer // Use 1: check nz vm.check.nz %ref_dno, "use 1" : !vm.buffer // Use 2: pass to call @@ -394,9 +394,9 @@ vm.module @ref_ops { vm.export @test_ref_call_chain attributes {emitc.exclude} vm.func @test_ref_call_chain() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %result = vm.call @_call_chain_a(%ref_dno) : (!vm.buffer) -> !vm.buffer - %result_dno = util.optimization_barrier %result : !vm.buffer + %result_dno = vm.optimization_barrier %result : !vm.buffer vm.check.eq %ref_dno, %result_dno, "chain returns same ref" : !vm.buffer vm.return } @@ -416,9 +416,9 @@ vm.module @ref_ops { vm.export @test_return_multiple_refs attributes {emitc.exclude} vm.func @test_return_multiple_refs() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %results:2 = vm.call @_return_two_refs(%ref_a_dno, %ref_b_dno) : (!vm.buffer, !vm.buffer) -> (!vm.buffer, !vm.buffer) vm.check.eq %results#0, %ref_a_dno, "first result is ref_a" : !vm.buffer @@ -436,9 +436,9 @@ vm.module @ref_ops { vm.export @test_return_refs_swapped attributes {emitc.exclude} vm.func @test_return_refs_swapped() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = 
util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %results:2 = vm.call @_return_refs_swapped(%ref_a_dno, %ref_b_dno) : (!vm.buffer, !vm.buffer) -> (!vm.buffer, !vm.buffer) vm.check.eq %results#0, %ref_b_dno, "first result is ref_b (swapped)" : !vm.buffer @@ -467,7 +467,7 @@ vm.module @ref_ops { vm.export @test_discard_single_ref attributes {emitc.exclude} vm.func private @test_discard_single_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer vm.check.nz %ref_dno, "ref valid before discard" : !vm.buffer vm.discard.refs %ref_dno : !vm.buffer // Note: After discard, the ref is released. We shouldn't use it. @@ -478,9 +478,9 @@ vm.module @ref_ops { vm.export @test_discard_multiple_refs attributes {emitc.exclude} vm.func private @test_discard_multiple_refs() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer vm.check.nz %ref_a_dno, "ref_a valid before discard" : !vm.buffer vm.check.nz %ref_b_dno, "ref_b valid before discard" : !vm.buffer vm.discard.refs %ref_a_dno, %ref_b_dno : !vm.buffer, !vm.buffer @@ -491,9 +491,9 @@ vm.module @ref_ops { vm.export @test_discard_in_branch attributes {emitc.exclude} vm.func private @test_discard_in_branch() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = 
util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^bb1, ^bb2 ^bb1: vm.discard.refs %ref_dno : !vm.buffer @@ -513,7 +513,7 @@ vm.module @ref_ops { vm.export @test_nested_loop_outer_ref attributes {emitc.exclude} vm.func private @test_nested_loop_outer_ref() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c2 = vm.const.i32 2 @@ -543,9 +543,9 @@ vm.module @ref_ops { vm.export @test_ping_pong_swap attributes {emitc.exclude} vm.func private @test_ping_pong_swap() { %ref_a = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_a_dno = util.optimization_barrier %ref_a : !vm.buffer + %ref_a_dno = vm.optimization_barrier %ref_a : !vm.buffer %ref_b = vm.const.ref.rodata @buffer_b : !vm.buffer - %ref_b_dno = util.optimization_barrier %ref_b : !vm.buffer + %ref_b_dno = vm.optimization_barrier %ref_b : !vm.buffer %c0 = vm.const.i32 0 %c1 = vm.const.i32 1 %c3 = vm.const.i32 3 @@ -571,9 +571,9 @@ vm.module @ref_ops { vm.export @test_diamond_asymmetric_use attributes {emitc.exclude} vm.func private @test_diamond_asymmetric_use() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : !vm.buffer %c1 = vm.const.i32 1 - %c1_dno = util.optimization_barrier %c1 : i32 + %c1_dno = vm.optimization_barrier %c1 : i32 vm.cond_br %c1_dno, ^use_path(%ref_dno : !vm.buffer), ^nouse_path(%ref_dno : !vm.buffer) ^use_path(%r1 : !vm.buffer): vm.check.nz %r1, "ref valid in use_path" : !vm.buffer @@ -589,9 +589,9 @@ vm.module @ref_ops { vm.export @test_diamond_asymmetric_nouse attributes {emitc.exclude} vm.func private @test_diamond_asymmetric_nouse() { %ref = vm.const.ref.rodata @buffer_a : !vm.buffer - %ref_dno = util.optimization_barrier %ref : !vm.buffer + %ref_dno = vm.optimization_barrier %ref : 
!vm.buffer %c0 = vm.const.i32 0 - %c0_dno = util.optimization_barrier %c0 : i32 + %c0_dno = vm.optimization_barrier %c0 : i32 vm.cond_br %c0_dno, ^use_path(%ref_dno : !vm.buffer), ^nouse_path(%ref_dno : !vm.buffer) ^use_path(%r1 : !vm.buffer): vm.check.nz %r1, "ref valid in use_path" : !vm.buffer diff --git a/runtime/src/iree/vm/test/shift_ops.mlir b/runtime/src/iree/vm/test/shift_ops.mlir index b1e618d6a310..d6b258cf4436 100644 --- a/runtime/src/iree/vm/test/shift_ops.mlir +++ b/runtime/src/iree/vm/test/shift_ops.mlir @@ -7,7 +7,7 @@ vm.module @shift_ops { vm.export @test_shl_i32 vm.func @test_shl_i32() { %c1 = vm.const.i32 1 - %c1dno = util.optimization_barrier %c1 : i32 + %c1dno = vm.optimization_barrier %c1 : i32 %c2 = vm.const.i32 2 %v = vm.shl.i32 %c1dno, %c2 : i32 %c4 = vm.const.i32 4 @@ -18,7 +18,7 @@ vm.module @shift_ops { vm.export @test_shr_i32s vm.func @test_shr_i32s() { %cn1 = vm.const.i32 -1 - %cn1dno = util.optimization_barrier %cn1 : i32 + %cn1dno = vm.optimization_barrier %cn1 : i32 %c2 = vm.const.i32 2 %v = vm.shr.i32.s %cn1dno, %c2 : i32 vm.check.eq %v, %cn1dno, "-1>>2=-1" : i32 @@ -28,7 +28,7 @@ vm.module @shift_ops { vm.export @test_shr_i32u vm.func @test_shr_i32u() { %c4 = vm.const.i32 4 - %c4dno = util.optimization_barrier %c4 : i32 + %c4dno = vm.optimization_barrier %c4 : i32 %c2 = vm.const.i32 2 %v = vm.shr.i32.u %c4dno, %c2 : i32 %c1 = vm.const.i32 1 diff --git a/runtime/src/iree/vm/test/shift_ops_i64.mlir b/runtime/src/iree/vm/test/shift_ops_i64.mlir index 00c072423595..6a10d14d4a8e 100644 --- a/runtime/src/iree/vm/test/shift_ops_i64.mlir +++ b/runtime/src/iree/vm/test/shift_ops_i64.mlir @@ -7,7 +7,7 @@ vm.module @shift_ops_i64 { vm.export @test_shl_i64 vm.func @test_shl_i64() { %c1 = vm.const.i64 1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %shamt = vm.const.i32 2 %v = vm.shl.i64 %c1dno, %shamt : i64 %c4 = vm.const.i64 4 @@ -18,7 +18,7 @@ vm.module @shift_ops_i64 { vm.export @test_shr_i64s 
vm.func @test_shr_i64s() { %c1 = vm.const.i64 -1 - %c1dno = util.optimization_barrier %c1 : i64 + %c1dno = vm.optimization_barrier %c1 : i64 %shamt = vm.const.i32 2 %v = vm.shr.i64.s %c1dno, %shamt : i64 %cn1 = vm.const.i64 -1 @@ -29,7 +29,7 @@ vm.module @shift_ops_i64 { vm.export @test_shr_i64u vm.func @test_shr_i64u() { %c4 = vm.const.i64 4 - %c4dno = util.optimization_barrier %c4 : i64 + %c4dno = vm.optimization_barrier %c4 : i64 %shamt = vm.const.i32 2 %v = vm.shr.i64.u %c4dno, %shamt : i64 %c1 = vm.const.i64 1 diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir index c7d7df7fe9c0..1181afff67f9 100644 --- a/tests/compiler_driver/streams.mlir +++ b/tests/compiler_driver/streams.mlir @@ -1,9 +1,7 @@ // RUN: iree-compile --split-input-file \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --output-format=vm-bytecode \ -// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s \ -// RUN: --mlir-print-ir-after=iree-vm-ordinal-allocation 2>&1 | FileCheck %s +// RUN: --compile-to=vm %s | FileCheck %s // This file has a few test programs that show how to mix `flow` dispatches into // those created by the `linalg` dispatch region formation: the idea is to use diff --git a/tools/iree-dump-module-main.c b/tools/iree-dump-module-main.c index a6a596da1cb0..af384c196621 100644 --- a/tools/iree-dump-module-main.c +++ b/tools/iree-dump-module-main.c @@ -336,6 +336,23 @@ static void iree_tooling_print_rwdata_segment_defs( } } +// Returns the first export name for the given internal ordinal, or empty if not +// exported. 
+static iree_string_view_t iree_tooling_lookup_export_name( + iree_host_size_t internal_ordinal, + iree_vm_ExportFunctionDef_vec_t export_defs) { + for (size_t j = 0; j < iree_vm_ExportFunctionDef_vec_len(export_defs); ++j) { + iree_vm_ExportFunctionDef_table_t export_def = + iree_vm_ExportFunctionDef_vec_at(export_defs, j); + if (iree_vm_ExportFunctionDef_internal_ordinal(export_def) == + internal_ordinal) { + const char* name = iree_vm_ExportFunctionDef_local_name(export_def); + return iree_make_string_view(name, strlen(name)); + } + } + return iree_string_view_empty(); +} + static void iree_tooling_print_function_descriptors( iree_vm_FunctionDescriptor_vec_t descriptors, iree_vm_ExportFunctionDef_vec_t export_defs) { @@ -550,11 +567,23 @@ static iree_status_t iree_tooling_dump_module_disassembly( iree_status_t status = iree_vm_bytecode_module_create( instance, IREE_VM_BYTECODE_MODULE_FLAG_ALLOW_PLACEHOLDER_TYPES, archive_contents, iree_allocator_null(), host_allocator, &module); + iree_const_byte_span_t flatbuffer_contents = iree_const_byte_span_empty(); + iree_host_size_t rodata_offset = 0; + if (iree_status_is_ok(status)) { + status = iree_vm_bytecode_archive_parse_header( + archive_contents, &flatbuffer_contents, &rodata_offset); + } if (iree_status_is_ok(status)) { iree_string_builder_t builder; iree_string_builder_initialize(host_allocator, &builder); - // Iterate over exported functions and build the disassembly output. + // Extract export names from the flatbuffer module definition. + iree_vm_BytecodeModuleDef_table_t module_def = + iree_vm_BytecodeModuleDef_as_root(flatbuffer_contents.data); + iree_vm_ExportFunctionDef_vec_t export_defs = + iree_vm_BytecodeModuleDef_exported_functions(module_def); + + // Iterate over internal functions and build the disassembly output. 
iree_vm_module_signature_t signature = iree_vm_module_signature(module); for (iree_host_size_t i = 0; i < signature.internal_function_count; ++i) { iree_vm_function_t function; @@ -562,8 +591,15 @@ static iree_status_t iree_tooling_dump_module_disassembly( module, IREE_VM_FUNCTION_LINKAGE_INTERNAL, i, &function); if (!iree_status_is_ok(status)) break; + // Get function name from exports if available, otherwise use internal + // name. + iree_string_view_t export_name = + iree_tooling_lookup_export_name(i, export_defs); + iree_string_view_t function_name = iree_string_view_is_empty(export_name) + ? iree_vm_function_name(&function) + : export_name; + // Apply filter (ordinal or name) if provided. - iree_string_view_t function_name = iree_vm_function_name(&function); if (!iree_string_view_is_empty(function_filter)) { uint32_t filter_ordinal = -1; if (iree_string_view_atoi_uint32(function_filter, &filter_ordinal)) { diff --git a/tools/test/iree-dump-module.mlir b/tools/test/iree-dump-module.mlir index db95f5c1f604..2609a754e14e 100644 --- a/tools/test/iree-dump-module.mlir +++ b/tools/test/iree-dump-module.mlir @@ -8,20 +8,20 @@ // RUN: %t.vmfb | \ // RUN: FileCheck %s -// CHECK-LABEL: @module : version 0 +// CHECK: @module : version 0 -// CHECK-LABEL: module.fn0 +// CHECK: fn0 func.func @fn0(%input : tensor) -> (tensor) { - // CHECK: [{{[0-9]+}}] + // CHECK: [{{[0-9]+}}]{{.*}} %result = math.absf %input : tensor return %result : tensor } -// CHECK-LABEL: module.fn1 +// CHECK: fn1 func.func @fn1(%input : tensor) -> (tensor) { - // CHECK: [{{[0-9]+}}] + // CHECK: [{{[0-9]+}}]{{.*}} %result = arith.mulf %input, %input : tensor return %result : tensor } -// CHECK-LABEL: module.__init +// CHECK: __init From 4fa4196f321374f7fba8579e16d4ef837fb58184 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 14 Jan 2026 08:16:01 -0800 Subject: [PATCH 35/71] Adding SplitParameterEncoderPass and iree-encode-parameters. 
(#22814) Key components: - SplitParameterEncoder.cpp: Main pass implementation - Analyzes parameter usage across the module - Hoists hoistable operations (constants, splats, dispatches) to initializers - Generates parameter encoding module - ParameterOptions: New pipeline options for parameter encoder configuration - --iree-parameter-encoder-mode=consolidate|overlay - --iree-parameter-encoder-target= (TBD) - --iree-parameter-encoder-output-file=encoder.mlirbc - --iree-parameter-encoder-output-scope=encoded (recommended) This pass produces some of the more sophisticated programs the compiler has encountered and due to its nature of streaming through large volumes of parameters was very sensitive to all the bail-outs that were present. The good news is that this upgrades the compiler quite a bit, at least! Paired with this, the iree-encode-parameters tool is designed to run the parameter encoder modules produced from the SplitParameterEncoderPass. Though the modules are normal IREE programs they have a bunch of metadata to aid in a more ergonomic tool, automation (as we get into multi-targeting/retargeting of encodings), and handling the two-stage initialization. API usage is possible to allow runtimes/frameworks/etc to integrate on-the-fly encoding if desired. The tool currently bails on multi-targeting and double creates contexts. Future improvements can finish the compiler->tool->runtime target detection and selection logic. The parameters module will need to be reworked to allow new parameter providers to be registered (here the input/output files) after the context has been created in order to avoid the double creates. See tests/e2e/parameters/encode_parameters.mlir for a full example: ``` // Compile main module with encoder MLIR output and splat parameter export. 
// RUN: iree-compile %s \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ // RUN: --iree-parameter-encoder-output-file=%t_encoder.mlir \ // RUN: --iree-parameter-splat=%t_input.irpa \ // RUN: -o %t_main.vmfb // // Compile the encoder module separately. // RUN: iree-compile %t_encoder.mlir \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ // RUN: -o %t_encoder.vmfb // // Run the encoder to transform parameters. // RUN: iree-encode-parameters \ // RUN: --module=%t_encoder.vmfb \ // RUN: --parameters=model=%t_input.irpa \ // RUN: --output=encoded=%t_output.irpa \ // RUN: --quiet // // Run the main module with both input and encoded parameters. // The encoded parameters contain the pre-computed transformed values. // RUN: iree-run-module \ // RUN: --device=local-sync \ // RUN: --module=%t_main.vmfb \ // RUN: --function=main \ // RUN: --parameters=model=%t_input.irpa \ // RUN: --parameters=encoded=%t_output.irpa | \ // RUN: FileCheck %s ``` --- .../compiler/API/Internal/CompilerDriver.cpp | 20 +- .../iree/compiler/ConstEval/JitGlobals.cpp | 3 +- .../Flow/Transforms/ConvertShardToFlow.cpp | 15 +- .../Dialect/Stream/IR/StreamInterfaces.td | 12 + .../compiler/Dialect/Stream/IR/StreamOps.cpp | 26 +- .../compiler/Dialect/Stream/IR/StreamOps.td | 19 + .../Dialect/Stream/Transforms/BUILD.bazel | 2 + .../Dialect/Stream/Transforms/CMakeLists.txt | 2 + .../Stream/Transforms/LayoutSlices.cpp | 20 +- .../Transforms/PackDispatchOperands.cpp | 2 +- .../Dialect/Stream/Transforms/Passes.cpp | 69 +- .../Dialect/Stream/Transforms/Passes.h | 44 + .../Dialect/Stream/Transforms/Passes.td | 55 +- .../Dialect/Stream/Transforms/RefineUsage.cpp | 1 + .../Transforms/SplitParameterEncoder.cpp | 2216 +++++++++++++++++ .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/CMakeLists.txt | 1 + .../test/split_parameter_encoder.mlir | 1770 +++++++++++++ 
.../iree/compiler/Dialect/Util/IR/UtilOps.cpp | 52 + .../iree/compiler/Dialect/Util/IR/UtilOps.td | 14 + .../compiler/Dialect/Util/IR/UtilTypes.td | 3 + .../compiler/ExternalInterfaces/BUILD.bazel | 1 + .../ExternalInterfaces/CMakeLists.txt | 1 + .../StreamExternalModels.cpp | 18 + .../ExternalInterfaces/UtilExternalModels.cpp | 53 + .../compiler/GlobalOptimization/Passes.cpp | 4 +- .../iree/compiler/GlobalOptimization/Passes.h | 4 +- .../src/iree/compiler/Pipelines/Options.cpp | 112 +- .../src/iree/compiler/Pipelines/Options.h | 48 +- .../src/iree/compiler/Pipelines/Pipelines.cpp | 39 +- .../src/iree/compiler/Pipelines/Pipelines.h | 2 + compiler/src/iree/compiler/Utils/BUILD.bazel | 2 +- .../src/iree/compiler/Utils/CMakeLists.txt | 2 +- compiler/src/iree/compiler/Utils/Folding.h | 34 - .../src/iree/compiler/Utils/ModuleUtils.cpp | 34 + .../src/iree/compiler/Utils/ModuleUtils.h | 4 + .../src/iree/compiler/Utils/OptionUtils.h | 48 +- runtime/src/iree/hal/utils/file_transfer.c | 16 +- .../src/iree/io/formats/irpa/irpa_builder.c | 8 + .../src/iree/io/formats/irpa/irpa_builder.h | 7 + runtime/src/iree/tooling/context_util.c | 3 +- runtime/src/iree/tooling/parameter_util.c | 40 +- runtime/src/iree/tooling/parameter_util.h | 11 +- tests/e2e/parameters/BUILD.bazel | 2 + tests/e2e/parameters/CMakeLists.txt | 2 + tests/e2e/parameters/encode_parameters.mlir | 68 + tests/e2e/parameters/export_parameters.mlir | 4 +- .../parameters/generate_splat_archive.mlir | 2 +- tools/BUILD.bazel | 21 + tools/CMakeLists.txt | 24 + tools/iree-encode-parameters-main.c | 1116 +++++++++ 51 files changed, 5906 insertions(+), 171 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir delete mode 100644 compiler/src/iree/compiler/Utils/Folding.h create mode 100644 tests/e2e/parameters/encode_parameters.mlir create mode 100644 
tools/iree-encode-parameters-main.c diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index f4cea928b8c4..23770bcccdf0 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -247,6 +247,7 @@ struct GlobalInit { InputDialectOptions *clInputOptions = nullptr; PreprocessingOptions *clPreprocessingOptions = nullptr; GlobalOptimizationOptions *clGlobalOptimizationOptions = nullptr; + ParameterOptions *clParameterOptions = nullptr; DispatchCreationOptions *clDispatchCreationOptions = nullptr; SchedulingOptions *clSchedulingOptions = nullptr; IREE::HAL::TargetOptions *clHalTargetOptions = nullptr; @@ -293,6 +294,7 @@ void GlobalInit::registerCommandLineOptions() { clInputOptions = &InputDialectOptions::FromFlags::get(); clPreprocessingOptions = &PreprocessingOptions::FromFlags::get(); clGlobalOptimizationOptions = &GlobalOptimizationOptions::FromFlags::get(); + clParameterOptions = &ParameterOptions::FromFlags::get(); clDispatchCreationOptions = &DispatchCreationOptions::FromFlags::get(); clSchedulingOptions = &SchedulingOptions::FromFlags::get(); clHalTargetOptions = &IREE::HAL::TargetOptions::FromFlags::get(); @@ -403,6 +405,7 @@ struct Session { BindingOptions bindingOptions; InputDialectOptions inputOptions; PreprocessingOptions preprocessingOptions; + ParameterOptions parameterOptions; GlobalOptimizationOptions highLevelOptimizationOptions; DispatchCreationOptions dispatchCreationOptions; SchedulingOptions schedulingOptions; @@ -432,6 +435,7 @@ Session::Session(GlobalInit &globalInit) inputOptions = *globalInit.clInputOptions; preprocessingOptions = *globalInit.clPreprocessingOptions; highLevelOptimizationOptions = *globalInit.clGlobalOptimizationOptions; + parameterOptions = *globalInit.clParameterOptions; dispatchCreationOptions = *globalInit.clDispatchCreationOptions; schedulingOptions = 
*globalInit.clSchedulingOptions; halTargetOptions = *globalInit.clHalTargetOptions; @@ -453,6 +457,7 @@ Session::Session(GlobalInit &globalInit) preprocessingOptions.bindOptions(binder); inputOptions.bindOptions(binder); highLevelOptimizationOptions.bindOptions(binder); + parameterOptions.bindOptions(binder); dispatchCreationOptions.bindOptions(binder); schedulingOptions.bindOptions(binder); halTargetOptions.bindOptions(binder); @@ -1015,10 +1020,10 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { buildIREEVMTransformPassPipeline( session.targetRegistry, session.pipelineOptions, session.bindingOptions, session.inputOptions, session.preprocessingOptions, - session.highLevelOptimizationOptions, session.dispatchCreationOptions, - session.schedulingOptions, session.halTargetOptions, - session.vmTargetOptions, pipelineHooks, *passManager, compileFrom, - compileTo); + session.parameterOptions, session.highLevelOptimizationOptions, + session.dispatchCreationOptions, session.schedulingOptions, + session.halTargetOptions, session.vmTargetOptions, pipelineHooks, + *passManager, compileFrom, compileTo); break; } case IREE_COMPILER_PIPELINE_HAL_EXECUTABLE: { @@ -1049,9 +1054,10 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { buildIREEPrecompileTransformPassPipeline( session.targetRegistry, session.pipelineOptions, session.bindingOptions, session.inputOptions, session.preprocessingOptions, - session.highLevelOptimizationOptions, session.dispatchCreationOptions, - session.schedulingOptions, session.halTargetOptions, pipelineHooks, - *passManager, compileFrom, compileTo); + session.parameterOptions, session.highLevelOptimizationOptions, + session.dispatchCreationOptions, session.schedulingOptions, + session.halTargetOptions, pipelineHooks, *passManager, compileFrom, + compileTo); break; } case IREE_COMPILER_PIPELINE_VM: { diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp 
b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 46502f9f59e2..882a5451f7cf 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -73,6 +73,7 @@ struct CompileOptions { BindingOptions bindingOptions; InputDialectOptions inputOptions; PreprocessingOptions preprocessingOptions; + ParameterOptions parameterOptions; GlobalOptimizationOptions globalOptimizationOptions; DispatchCreationOptions dispatchCreationOptions; SchedulingOptions schedulingOptions; @@ -631,7 +632,7 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { buildIREEVMTransformPassPipeline( *targetRegistry.value, compileOptions->pipelineOptions, compileOptions->bindingOptions, compileOptions->inputOptions, - compileOptions->preprocessingOptions, + compileOptions->preprocessingOptions, compileOptions->parameterOptions, compileOptions->globalOptimizationOptions, compileOptions->dispatchCreationOptions, compileOptions->schedulingOptions, compileOptions->executableOptions, diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp index 3b224f821408..5e4b314b4ff0 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp @@ -6,7 +6,6 @@ #include "iree/compiler/Dialect/Flow/Conversion/ShardToFlow/Patterns.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" -#include "iree/compiler/Utils/Folding.h" #include "iree/compiler/Utils/Indexing.h" #include "iree/compiler/Utils/OpVisitor.h" #include "iree/compiler/Utils/Permutation.h" @@ -27,6 +26,20 @@ namespace mlir::iree_compiler::IREE::Flow { namespace { +// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`. 
+template +static void toOpFoldResults(Range &&range, OutIt outIt) { + llvm::transform(std::forward(range), outIt, + [](auto v) { return OpFoldResult(v); }); +} + +template +static SmallVector toOpFoldResults(Range &&range) { + SmallVector res; + toOpFoldResults(std::forward(range), std::back_inserter(res)); + return res; +} + static bool hasMoreThanOneShard(Operation *op) { int shardCount = 0; op->walk([&shardCount](shard::ShardOp shard) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td index e9977dde4470..d96ab6899f5a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td @@ -213,6 +213,18 @@ def Stream_AffinityOp : Stream_OpInterface<"AffinityOpInterface"> { return $_op.getAffinityAttr(); }] >, + InterfaceMethod< + /*desc=*/[{ + Removes all affinities specified on the op. + }], + /*retTy=*/"void", + /*methodName=*/"removeAffinityAttrs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + $_op.setAffinityAttr(nullptr); + }] + >, ]; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 1e8d8fb35ea4..6afb02f6d217 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -1712,17 +1712,10 @@ ResourceAllocOp::createSuballocations( // small enough workloads and our target devices are relatively lax on // things so long as we stay under UINT32_MAX boundaries. - // All slices are 0-0 (overlapping). - size_t sliceCount = locs.size(); - SmallVector lifetimeIntervals(sliceCount * 2, 0); - // Compute total size and the offsets of all suballocated resources via the // pack op. 
- auto indexType = builder.getIndexType(); - SmallVector packedOffsetTypes(sliceCount, indexType); auto packOp = IREE::Stream::ResourcePackOp::create( - builder, fusedLoc, indexType, packedOffsetTypes, /*offset=*/nullptr, - builder.getIndexArrayAttr(lifetimeIntervals), storageSizes, affinityAttr); + builder, fusedLoc, /*offset=*/nullptr, storageSizes, affinityAttr); // Create the new alloca based on the total required size. auto allocOp = IREE::Stream::ResourceAllocOp::create( @@ -1884,6 +1877,18 @@ void ResourcePackOp::getAsmResultNames( // } } +void ResourcePackOp::build(OpBuilder &builder, OperationState &state, + Value offset, ValueRange valueSizes, + IREE::Stream::AffinityAttr affinity) { + // All slices are 0-0 (overlapping). + size_t sliceCount = valueSizes.size(); + SmallVector lifetimeIntervals(sliceCount * 2, 0); + auto indexType = builder.getIndexType(); + SmallVector indexTypes(sliceCount, indexType); + build(builder, state, indexType, indexTypes, offset, + builder.getIndexArrayAttr(lifetimeIntervals), valueSizes, affinity); +} + LogicalResult ResourcePackOp::verify() { ResourcePackOp op = *this; size_t sliceCount = op.getPackedOffsets().size(); @@ -2909,6 +2914,11 @@ IREE::Stream::AffinityAttr AsyncTransferOp::getResultAffinityAttr() { return getTargetAffinityAttr(); } +void AsyncTransferOp::removeAffinityAttrs() { + removeSourceAffinityAttr(); + removeTargetAffinityAttr(); +} + void AsyncTransferOp::getAsyncAccessRanges( SmallVectorImpl &ranges) { ranges.push_back({ResourceAccessBitfield::Read, getSource(), Value{}, diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td index a56218f637db..8a8469478932 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td @@ -577,6 +577,14 @@ def Stream_ResourcePackOp : Stream_PureOp<"resource.pack", [ attr-dict-with-keyword }]; + let builders = [ + OpBuilder<(ins 
+ "Value":$offset, + "ValueRange":$valueSizes, + CArg<"AffinityAttr", "{}">:$affinityAttr + )>, + ]; + let hasVerifier = 1; let extraClassDeclaration = [{ @@ -1844,6 +1852,7 @@ def Stream_AsyncConstantOp : Stream_PureOp<"async.constant", [ DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, DeclareOpInterfaceMethods, ]> { @@ -1906,6 +1915,7 @@ def Stream_AsyncSplatOp : Stream_Op<"async.splat", [ "getAsyncAccessRanges", ]>, Util_SizeAwareOp, + Util_HoistableOpInterface, ]> { let summary = [{Splats a value into a resource.}]; let description = [{ @@ -1972,6 +1982,7 @@ def Stream_AsyncCloneOp : Stream_Op<"async.clone", [ "getAsyncAccessRanges", ]>, Util_SizeAwareOp, + Util_HoistableOpInterface, ]> { let summary = [{Clones the contents of a value.}]; let description = [{ @@ -2016,6 +2027,7 @@ def Stream_AsyncSliceOp : Stream_PureOp<"async.slice", [ DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, ]> { let summary = [{Slices out a cloned subview of a value.}]; @@ -2069,6 +2081,7 @@ def Stream_AsyncFillOp : Stream_Op<"async.fill", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Fills a subview of a stream resource with a value.}]; let description = [{ @@ -2124,6 +2137,7 @@ def Stream_AsyncUpdateOp : Stream_Op<"async.update", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Updates a slice of a subview of a resource in-place.}]; let description = [{ @@ -2181,6 +2195,7 @@ def Stream_AsyncCopyOp : Stream_Op<"async.copy", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let summary = [{Copies a subview of a stream resource to another.}]; let description = [{ @@ -2244,6 +2259,7 @@ def Stream_AsyncCollectiveOp : Stream_Op<"async.collective", [ "getTiedResultOperandIndex", "getTiedResultOperandIndices", ]>, + Util_HoistableOpInterface, ]> { let 
summary = [{Performs a collective operation.}]; let description = [{ @@ -2358,12 +2374,14 @@ def Stream_AsyncTransferOp : Stream_PureOp<"async.transfer", [ "getAffinityAttr", "setAffinityAttr", "getResultAffinityAttr", + "removeAffinityAttrs", ]>, Stream_AsyncPhaseOp, Stream_StreamableOp, DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, ]> { let summary = [{Transfers a resource from one location/state to another.}]; @@ -2504,6 +2522,7 @@ def Stream_AsyncDispatchOp : Stream_PureOp<"async.dispatch", [ DeclareOpInterfaceMethods, + Util_HoistableOpInterface, Util_SizeAwareOp, DeclareOpInterfaceMethods slices, IREE::Stream::ResourceConfigAttr resourceConfig, IndexSet &indexSet, OpBuilder &builder) { @@ -114,7 +114,7 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, } reservations.insert(insertionIt, reservation); slice.packedOffset.replaceAllUsesWith(builder.createOrFold( - packOp.getLoc(), baseOffset, indexSet.get(bestOffset))); + loc, baseOffset, indexSet.get(bestOffset))); // Update highwater mark indicating how much memory needs to be allocated // for the entire slab. @@ -122,7 +122,7 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, } highwaterMark = IREE::Util::align(highwaterMark, rangeAlignment); - return builder.createOrFold(packOp.getLoc(), baseOffset, + return builder.createOrFold(loc, baseOffset, indexSet.get(highwaterMark)); } @@ -145,11 +145,10 @@ packStaticSlicesGreedily(IREE::Stream::ResourcePackOp packOp, Value baseOffset, // |baseOffset|. Returns |baseOffset| + the total size of the allocation // aligned to the requirements of |resourceConfig|. 
static Value -packDynamicSlicesConservatively(IREE::Stream::ResourcePackOp packOp, - Value baseOffset, MutableArrayRef slices, +packDynamicSlicesConservatively(Location loc, Value baseOffset, + MutableArrayRef slices, IREE::Stream::ResourceConfigAttr resourceConfig, IndexSet &indexSet, OpBuilder &builder) { - auto loc = packOp.getLoc(); int64_t offsetAlignment = resourceConfig.getMinBufferOffsetAlignment(); int64_t rangeAlignment = resourceConfig.getMinBufferRangeAlignment(); @@ -255,9 +254,9 @@ struct LayoutSlicesPass // First pack all static slices as these are entirely knowable here at // compile time. - auto offset = packOp.getOffset() ? packOp.getOffset() : indexSet.get(0); + Value offset = packOp.getOffset() ? packOp.getOffset() : indexSet.get(0); if (!staticSlices.empty()) { - offset = packStaticSlicesGreedily(packOp, offset, staticSlices, + offset = packStaticSlicesGreedily(packOp.getLoc(), offset, staticSlices, resourceConfig, indexSet, builder); // TODO(benvanik): make this an option; it can be useful for debugging @@ -270,8 +269,9 @@ struct LayoutSlicesPass // available we could reuse static slices with non-overlapping lifetimes // in some cases. 
if (!dynamicSlices.empty()) { - offset = packDynamicSlicesConservatively( - packOp, offset, dynamicSlices, resourceConfig, indexSet, builder); + offset = packDynamicSlicesConservatively(packOp.getLoc(), offset, + dynamicSlices, resourceConfig, + indexSet, builder); } // Total packed length is the current offset after all slices are diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp index 26e58773884a..54cbb46de0a0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp @@ -298,7 +298,7 @@ static void updateExportFuncOp(mlir::FunctionOpInterface funcOp) { } //===----------------------------------------------------------------------===// -// --iree-hal-pack-dispatch-operands +// --iree-stream-pack-dispatch-operands //===----------------------------------------------------------------------===// struct PackDispatchOperandsPass diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 4f3879ae1414..d0ec1998266f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -184,26 +184,67 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, // Specialize the encodings before the lowering of stream tensor ops. passManager.addPass(IREE::Stream::createSpecializeEncodingsPass()); + // Lower stream.tensor.* ops to stream.async.* ops based on + // affinity/configuration assigned during placement. FunctionLikeNest(passManager) - // Run canonicalization after specializing to clean up any - // duplicate/redundant IR and fold any duplicate encoding chains before we - // perform the encoding materialization. 
- .addPass(mlir::createCanonicalizerPass) - .addPass(mlir::createCSEPass) - - // Lower stream.tensor.* ops to stream.async.* ops based on - // affinity/configuration assigned during placement. .addPass(IREE::Stream::createEncodeHostTensorsPass); passManager.addNestedPass( IREE::Stream::createEncodeDeviceTensorsPass()); passManager.addPass(IREE::Stream::createMaterializeEncodingsPass()); + // Layout packed slices (if any exist yet) to emit the arithmetic required for + // all resource offsets. We introduce more packing ops later on but do want + // to support using the layout utilities earlier if the encodings need them. + // Having the arithmetic baked out allows for better propagation (resource + // offsets and sizes can be detected as constant if statically packed, etc). + FunctionLikeNest(passManager).addPass(IREE::Stream::createLayoutSlicesPass); + buildStreamCleanupPassPipeline(passManager, transformOptions); // Everything must now be in stream.async.* form but we don't yet have - // lifetime assigned. + // lifetime assigned. We don't expect there to be any aliasing or other + // trickery yet as we haven't materialized copy-on-write handling and copy + // elision. passManager.addPass(IREE::Stream::createVerifyLoweringToAsyncResourcesPass()); + // If we want to split out a parameter encoder now is the best time: we have + // all of the encodings specialized but haven't yet started allocating memory + // (which will be entirely different in the split module) and if any are + // multi-targeting we haven't yet materialized their concrete forms. + // + // Once this pass runs the original parameters will (mostly) be removed and + // in place of globally initialized constants will be loads from the new + // encoded parameters. Any packing/layout is done now so that the parameter + // index has a common layout between both modules. 
+ if (transformOptions.parameterEncoderOutputFile.hasValue() && + !transformOptions.parameterEncoderOutputFile.empty()) { + IREE::Stream::SplitParameterEncoderPassOptions encoderPassOptions; + encoderPassOptions.mode = transformOptions.parameterEncoderMode; + encoderPassOptions.outputScope = + transformOptions.parameterEncoderOutputScope; + encoderPassOptions.outputFile = transformOptions.parameterEncoderOutputFile; + passManager.addPass( + IREE::Stream::createSplitParameterEncoderPass(encoderPassOptions)); + + // This is somewhat dangerous in that if there is any aliasing in the + // program this _may_ break it. But we don't allow aliasing at this point of + // the pipeline so that's a risk I'm willing to take. The splitting pass + // introduces resource subview ops that we need to propagate to consumers. + // + // TODO(benvanik): improve stream.async.slice handling in + // ElideAsyncCopiesPass. Today it is local only and it results in parameters + // sliced in initializers being treated as copies. If we fixed that we could + // use stream.async.slice as is appropriate at this phase of lowering and + // remove this pass. + passManager.addPass(IREE::Util::createPropagateSubrangesPass()); + + // DCE any executables no longer required just to make the IR cleaner. + // Often times we'll have quite a few hoisted initialization and encoding + // dispatches that are not used elsewhere in the program (though some may + // be due to deduplication!). + passManager.addPass(mlir::createSymbolDCEPass()); + } + // Materialize copy-on-write behavior with explicit stream.async.* ops. // This will insert a lot of copies, so follow it up with a pass that elides // ones that aren't needed. This is easier to verify than if there was one @@ -346,11 +387,6 @@ void buildStreamOptimizationPassPipeline( // cause duplication. Run CSE to collapse. buildStreamCleanupPassPipeline(passManager, transformOptions); - // If any scf ops crept in we get rid of them here. 
We should be able to - // support them all the way through the stream dialect but some passes are not - // currently set up to handle them (such as elide timepoints). - FunctionLikeNest(passManager).addPass(mlir::createSCFToControlFlowPass); - //---------------------------------------------------------------------------- // Whole-program scheduling optimization //---------------------------------------------------------------------------- @@ -374,6 +410,11 @@ void buildStreamOptimizationPassPipeline( FunctionLikeNest(passManager) .addPass(IREE::Stream::createReuseAllocationsPass); + // If any scf ops crept in we get rid of them here. We should be able to + // support them all the way through the stream dialect but some passes are + // not currently set up to handle them (such as elide timepoints). + FunctionLikeNest(passManager).addPass(mlir::createSCFToControlFlowPass); + // Elide timepoints in dependency chains where one is known to have been // reached by the time another is (A -> B -> A|C). ipoPipeline.addPass(IREE::Stream::createElideTimepointsPass()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h index da852c2a5055..4a22080f2f2a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.h @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "iree/compiler/Utils/OptionUtils.h" #include "llvm/ADT/StringMap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -55,6 +56,21 @@ enum class DumpOutputFormat { JSON = 4, }; +// Controls how the encoder manages parameters. +enum class ParameterEncoderMode { + // Merge all encoded parameters and original parameters into a single + // consolidated scope. 
+ Consolidate = 0, + // Only produce encoded parameters and leave original parameters untouched. + Overlay = 1, +}; + +// Options for the Stream transformation pipeline. +// +// These options are typically populated from top-level compiler options +// (ParameterOptions in Pipelines/Options.h) when building the full compiler +// pipeline. When constructing individual passes, relevant options are mapped +// to pass-specific option structs (e.g., SplitParameterEncoderPassOptions). struct TransformOptions : public PassPipelineOptions { Option initializationMode{ *this, @@ -72,6 +88,34 @@ struct TransformOptions : public PassPipelineOptions { "waiting for them to complete.")), }; + Option parameterEncoderMode{ + *this, + "parameter-encoder-mode", + llvm::cl::desc("Controls how the encoder manages parameters."), + llvm::cl::init(ParameterEncoderMode::Consolidate), + llvm::cl::values( + clEnumValN(ParameterEncoderMode::Consolidate, "consolidate", + "Merge all encoded parameters and original parameters " + "into a single consolidated scope."), + clEnumValN(ParameterEncoderMode::Overlay, "overlay", + "Only produce encoded parameters and leave original " + "parameters untouched.")), + }; + Option parameterEncoderOutputScope{ + *this, + "parameter-encoder-output-scope", + llvm::cl::desc( + "Parameter scope for the output parameters. 
Omit for global."), + llvm::cl::init("encoded"), + }; + Option parameterEncoderOutputFile{ + *this, + "parameter-encoder-output-file", + llvm::cl::desc(".mlir/.mlirbc file path to write the split parameter " + "encoder module to."), + llvm::cl::init(""), + }; + Option optimizeBindings{ *this, "optimize-bindings", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 8a175399a162..9c15ada3b4b5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -44,6 +44,56 @@ def ConvertToStreamPass : ]; } +def SplitParameterEncoderPass : + Pass<"iree-stream-split-parameter-encoder", "mlir::ModuleOp"> { + let summary = "Splits out a parameter encoder module for compatible hoisted expressions."; + let description = [{ + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "IREE::HAL::HALDialect", + "IREE::Stream::StreamDialect", + "IREE::Util::UtilDialect", + ]; + let options = [ + Option< + "mode", "mode", + "IREE::Stream::ParameterEncoderMode", + /*default=*/"IREE::Stream::ParameterEncoderMode::Consolidate", + "Controls how the encoder manages parameters.", + [{::llvm::cl::values( + clEnumValN(IREE::Stream::ParameterEncoderMode::Consolidate, "consolidate", "Merge all encoded parameters and original parameters into a single consolidated scope."), + clEnumValN(IREE::Stream::ParameterEncoderMode::Overlay, "overlay", "Only produce encoded parameters and leave original parameters untouched.") + )}] + >, + Option< + "outputScope", "output-scope", + "std::string", /*default=*/"\"encoded\"", + "Parameter scope for the output parameters. Omit for global." + >, + Option< + "outputFile", "output-file", + "std::string", /*default=*/"std::string()", + ".mlir/.mlirbc file path to write the split parameter encoder module to." 
+ >, + Option< + "hoistParameterExpressions", "hoist-parameter-expressions", + "bool", /*default=*/"true", + "Enable hoisting parameter transformation expressions." + >, + Option< + "hoistConstantExpressions", "hoist-constant-expressions", + "bool", /*default=*/"true", + "Enable hoisting pure constant expressions with transformations." + >, + Option< + "maxEncodingGrowthFactor", "max-encoding-growth-factor", + "float", /*default=*/"1.2f", + "Maximum ratio of output size to input parameter size." + >, + ]; +} + def EncodeHostTensorsPass : Pass<"iree-stream-encode-host-tensors", ""> { let summary = "Encodes tensors into storage formats based on affinity and target support."; @@ -707,7 +757,7 @@ def DumpStatisticsPass : Option< "outputFormat", "output-format", "IREE::Stream::DumpOutputFormat", - "IREE::Stream::DumpOutputFormat::Pretty", + /*default=*/"IREE::Stream::DumpOutputFormat::Pretty", "Specifies the output format to produce.", [{::llvm::cl::values( clEnumValN(IREE::Stream::DumpOutputFormat::Pretty, "pretty", "Human-readable pretty printed output."), @@ -717,8 +767,7 @@ def DumpStatisticsPass : >, Option< "outputFile", "output-file", - "std::string", - /*default=*/"std::string()", + "std::string", /*default=*/"std::string()", "File path to write to; or `` for stderr or `-` for stdout." 
>, ]; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index 3378748644f4..833ac102a220 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -499,6 +499,7 @@ static void insertUsageRefinementPatterns(MLIRContext *context, ApplyGenericOp, ApplyGenericOp, ApplyGenericOp, + ApplyGenericOp, ApplyGenericOp, ApplyGenericOp>(context, analysis); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp new file mode 100644 index 000000000000..de0c6fdb001f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SplitParameterEncoder.cpp @@ -0,0 +1,2216 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Utils/IntegerSet.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "iree/compiler/Utils/RegionOpUtils.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FileSystem.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/RegionUtils.h" + +#define DEBUG_TYPE "iree-stream-split-parameter-encoder" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +namespace mlir::iree_compiler::IREE::Stream { + +#define GEN_PASS_DEF_SPLITPARAMETERENCODERPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Collection of IndexSets for managing index memoization across functions. +class IndexSetCollection { +public: + // Returns an index set for the parent function of |op|. 
+ IndexSet *get(Operation *op) { + auto parentOp = op->getParentOfType(); + auto it = funcMap.find(parentOp); + if (it != funcMap.end()) { + return it->second.get(); + } + auto indexSet = std::make_unique( + op->getLoc(), OpBuilder::atBlockBegin(&parentOp.front())); + IndexSet *indexSetPtr = indexSet.get(); + funcMap.insert({parentOp, std::move(indexSet)}); + return indexSetPtr; + } + +private: + DenseMap> funcMap; +}; + +// Erases all ops in |leafOps| and all of their potentially newly-dead +// transitive producer dependencies. +// +// This custom DCE is required because MLIR's standard mlir::dce::removeDeadCode +// doesn't handle two cases we need: +// 1. Ops implementing HoistableOpInterface - control flow ops like scf.for/if +// whose bodies contain only hoistable (pure) operations can be deleted. +// 2. Ops with MemoryEffects::Read but no Write - these are "pure-ish" ops that +// can't be marked Pure (which would allow CSE) but are still safe to delete. +// +// TODO(benvanik): figure out how to move this to RegionOpUtils - it relies on +// some util op interfaces, though, so is hard to get out there now. +static void pruneDeadOps(ArrayRef leafOps) { + SmallVector deadOpWorklist{leafOps}; + + // Use a DenseSet to track already-processed operations to avoid duplicate + // processing when operations appear multiple times in the worklist. + DenseSet processedOps; + while (!deadOpWorklist.empty()) { + Operation *op = deadOpWorklist.pop_back_val(); + + // Skip if we've already processed this operation. + if (!processedOps.insert(op).second) { + continue; + } + + // Skip if the operation is no longer trivially dead (may have been + // deleted already or gained new uses). + // Also elide ops with no uses that have MemoryEffects::Effect::Read but no + // writes - these match the semantics of canonicalization ElideUnusedOp + // patterns (ops that are pure-ish but can't be marked Pure due to CSE). 
+    bool canDelete = mlir::isOpTriviallyDead(op);
+    if (!canDelete && op->use_empty()) {
+      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
+      if (memInterface) {
+        SmallVector<MemoryEffects::EffectInstance> effects;
+        memInterface.getEffects(effects);
+        // Safe to delete if it only has Allocate/Read effects (no Write).
+        canDelete =
+            llvm::none_of(effects, [](const MemoryEffects::EffectInstance &it) {
+              return isa<MemoryEffects::Write>(it.getEffect());
+            });
+      } else if (auto hoistableOp =
+                     dyn_cast<IREE::Util::HoistableOpInterface>(op)) {
+        // Operations with HoistableOpInterface can be deleted if unused and
+        // hoistable (pure). This handles control flow ops like scf.for/scf.if
+        // whose bodies contain only hoistable operations.
+        canDelete = hoistableOp.isHoistableOp();
+      }
+    }
+    if (!canDelete) {
+      continue;
+    }
+
+    // Collect defining operations before we delete this op.
+    SmallVector<Operation *> producerOps;
+    for (Value operand : op->getOperands()) {
+      if (Operation *producer = operand.getDefiningOp()) {
+        producerOps.push_back(producer);
+      }
+    }
+
+    // Erase the dead operation.
+    op->erase();
+
+    // Check if any of the producers now have no uses and add them to the
+    // worklist. The worklist loop will determine if they're safe to delete.
+    for (Operation *producer : producerOps) {
+      if (producer->use_empty()) {
+        deadOpWorklist.push_back(producer);
+      }
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// EncodingExpr
+//===----------------------------------------------------------------------===//
+
+// Configuration controlling which expressions are hoisted to the encoder
+// module. This policy determines hoisting eligibility based on expression type,
+// size growth limits, and parameter/constant handling preferences.
+struct EncodingPolicy {
+  // Pack multiple parameters into larger slabs to reduce overheads.
+  // This can dramatically improve startup time, reduces memory fragmentation,
+  // and reduces dispatch overheads.
+ bool packParameters = true; // false; + // Include direct parameter loads that have no modifications. + // When true the output parameter indices will have all required parameters + // and any original parameters will not be required by the base program at + // runtime. When false the user must provide the original parameters. + bool includeUnmodified = true; + // Any splat under this size will be serialized to the output parameter index + // as if it were data instead of being embedded as a splat. + // This increases the file size but allows for better parameter batching and + // can reduce runtime overhead. + int64_t serializeSplatSizeThreshold = 1024; + + // Enable hoisting parameter transformation expressions. + // When true, expressions that transform parameters (parameter → + // dispatch/encoding) will be extracted into the encoder module for offline + // evaluation. + bool hoistParameterExpressions = true; + + // Enable hoisting pure constant expressions with transformations. + // When true, expressions that transform pure constants (constant → + // dispatch/encoding) will be extracted into the encoder module for offline + // evaluation. + bool hoistConstantExpressions = true; + + // Maximum ratio of output size to input size before rejecting hoisting. + // This prevents expressions that significantly increase storage from being + // hoisted. Example: 1.2 allows 20% growth for padding/alignment. + float maxEncodingGrowthFactor = 1.2f; +}; + +// An encoding expression represents a subgraph of operations that transforms +// input parameters/constants into output values stored to globals. Each +// expression can have multiple inputs (parameter loads) and multiple outputs +// (global stores). The expression is hoisted to the encoder module where it +// can be evaluated offline, with the results stored as pre-encoded parameters. +struct EncodingExpr { + // Affinity of consumers of the expression in the original program. + // All outputs share the same affinity. 
+  IREE::Stream::AffinityAttr affinityAttr;
+
+  struct Input {
+    // Inlined constant resource or parameter load.
+    mutable IREE::Stream::AsyncConstantOp constantOp;
+
+    Location getLoc() const { return constantOp.getLoc(); }
+
+    // Returns true if the input is sourced from a parameter.
+    bool isParameter() const {
+      return isa<IREE::Stream::NamedParameterAttr>(constantOp.getValue());
+    }
+  };
+  SmallVector<Input> inputs;
+
+  struct Output {
+    // Size in bytes of the output resource.
+    int64_t size = 0;
+    // Constant pattern value if this is a splat.
+    TypedAttr splatPattern;
+    // Sink op storing the produced output into a global.
+    mutable IREE::Util::GlobalStoreOpInterface storeOp;
+    // Produced value feeding into the store.
+    // This may either be directly consumed by the store or an op earlier in
+    // the slice in cases where there are metadata ops we want to skip.
+    Value producedValue;
+
+    Location getLoc() const { return storeOp.getLoc(); }
+
+    // Returns true if the output is a constant splat that needs no execution.
+    // Only certain data types/widths are supported in the format and if not
+    // supported natively we'll need to splat the value into the file. It's
+    // rare for there to be splats that end up like this and it's unlikely the
+    // user wants a file full of splatted values but at this point in the
+    // pipeline we can only assume they asked for it.
+    bool isSupportedSplat() const {
+      if (!splatPattern || !splatPattern.getType().isIntOrFloat()) {
+        return false;
+      }
+      const unsigned bitWidth = splatPattern.getType().getIntOrFloatBitWidth();
+      return bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
+             bitWidth == 64;
+    }
+  };
+  SmallVector<Output> outputs;
+
+  // All operations (excluding outputs).
+  SetVector<Operation *> ops;
+
+  // Returns a fused location from all operations in the expression.
+  Location getLoc() const {
+    SetVector<Location> locs;
+    for (auto *op : ops) {
+      locs.insert(op->getLoc());
+    }
+    for (auto &output : outputs) {
+      locs.insert(output.getLoc());
+    }
+    return FusedLoc::get(ops.front()->getContext(), locs.getArrayRef());
+  }
+
+  // Returns the resource config for the expression by checking all outputs.
+  // If any outputs have differing configs the first output's config is used.
+  IREE::Stream::ResourceConfigAttr getResourceConfigAttr() const {
+    // Expressions should only be formed from outputs that share an affinity
+    // so we can look at the first output and assume they all match.
+    if (outputs.empty()) {
+      return {};
+    }
+    auto globalStoreOp = outputs.front().storeOp;
+    Value storedValue = globalStoreOp.getStoredGlobalValue();
+    auto *producingOp = storedValue.getDefiningOp();
+    return IREE::Stream::ResourceConfigAttr::lookup(
+        producingOp ? producingOp : globalStoreOp);
+  }
+
+  // Returns true if the expression has any parameter inputs.
+  bool hasParameterInputs() const {
+    return llvm::any_of(inputs,
+                        [](const Input &input) { return input.isParameter(); });
+  }
+
+  // Returns true if the expression has any constant inputs (non-parameter).
+  bool hasConstantInputs() const {
+    return llvm::any_of(
+        inputs, [](const Input &input) { return !input.isParameter(); });
+  }
+
+  // Estimates total input size from all inputs in bytes.
+  int64_t estimateInputSize() const {
+    int64_t total = 0;
+    for (const auto &input : inputs) {
+      if (input.constantOp) {
+        Value sizeValue = input.constantOp.getResultSize();
+        APInt size;
+        if (matchPattern(sizeValue, m_ConstantInt(&size))) {
+          total += size.getZExtValue();
+        }
+      }
+    }
+    return total;
+  }
+
+  // Estimates total output size from all outputs in bytes.
+ int64_t estimateOutputSize() const { + int64_t total = 0; + for (const auto &output : outputs) { + total += output.size; + } + return total; + } +}; + +struct EncodingExprSet { + // All expressions terminating in parameter outputs in the order they were + // originally present in the module (even if split across initializers). + SmallVector exprs; + + bool empty() const { return exprs.empty(); } +}; + +// Collects all external timepoint dependencies from the expression. This +// includes await timepoints from TimelineOpInterface ops in the expression that +// reference external values, and timepoints from external resource operands +// extracted via getResultTimepoint or by inserting a barrier. +static Value collectExternalTimepoints(const EncodingExpr &expr, + OpBuilder &builder) { + SetVector timepoints; + + // Build a set of ops that contribute RESOURCES (not just timepoints) to the + // expression. An op is a "resource contributor" if at least one of its + // non-timepoint results is used by another op in the expression. + // + // This distinction is important because the backward slice follows ALL + // operands including await timepoints. Ops that only contribute timepoints + // (like a timeline_op whose resource output is unused) should be considered + // "external" for synchronization purposes - their timepoints need to be + // awaited by the replacement op. + DenseSet resourceContributors; + for (Operation *op : expr.ops) { + for (Value result : op->getResults()) { + // Skip timepoint results - we only care about resource contributions. + if (isa(result.getType())) { + continue; + } + // Check if any user of this non-timepoint result is in the expression. 
+ for (Operation *user : result.getUsers()) { + if (expr.ops.contains(user)) { + resourceContributors.insert(op); + break; + } + } + if (resourceContributors.contains(op)) { + break; + } + } + } + + // A timepoint is "internal" only if its defining op contributes resources + // (not just timepoints) to the expression. + auto isInternalTimepoint = [&](Value tp) -> bool { + Operation *defOp = tp.getDefiningOp(); + return defOp && resourceContributors.contains(defOp); + }; + + // Collect external await timepoints from resource-contributing ops only. + // We only look at resource contributors because: + // 1. They represent the "core" data flow of the expression + // 2. Non-resource-contributor ops (like joins, unused timeline ops) are + // "synchronization helpers" whose await timepoints are transitively + // covered by the resource contributors' awaits + // This ensures we don't collect both a joined timepoint AND its component + // timepoints when a join is in the expression but doesn't contribute + // resources. + for (Operation *op : resourceContributors) { + auto timelineOp = dyn_cast(op); + if (!timelineOp) { + continue; + } + for (Value awaitTp : timelineOp.getAwaitTimepoints()) { + if (!isInternalTimepoint(awaitTp)) { + timepoints.insert(awaitTp); + } + } + } + + // A resource is "internal" only if its defining op contributes resources + // (not just timepoints) to the expression. + auto isInternalResource = [&](Value resource) -> bool { + Operation *defOp = resource.getDefiningOp(); + return defOp && resourceContributors.contains(defOp); + }; + + // Collect timepoints from external resource operands. + for (Operation *op : expr.ops) { + for (Value operand : op->getOperands()) { + if (!isa(operand.getType())) { + continue; + } + if (isInternalResource(operand)) { + continue; + } + + // Try to get timepoint from TimelineOpInterface. 
+ Value timepoint; + Operation *definingOp = operand.getDefiningOp(); + if (definingOp) { + if (auto timelineOp = + dyn_cast(definingOp)) { + timepoint = timelineOp.getResultTimepoint(); + } + } + + // If no timepoint available, insert barrier to extract it. + if (!timepoint) { + Value resourceSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + operand.getLoc(), operand, builder); + assert(resourceSize && "stream resource must have queryable size"); + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(definingOp); + auto barrierOp = IREE::Stream::TimepointBarrierOp::create( + builder, operand.getLoc(), operand.getType(), + builder.getType(), operand, + resourceSize, affinityAttr); + timepoint = barrierOp.getResultTimepoint(); + } + + if (timepoint) { + timepoints.insert(timepoint); + } + } + } + + if (timepoints.empty()) { + return {}; + } + return IREE::Stream::joinTimepoints( + expr.getLoc(), SmallVector(timepoints.begin(), timepoints.end()), + builder); +} + +// Finds all util.global.store-like ops that store constant resources in +// initializers. Stores are returned in program order. +// +// TODO: note that this does not check for stores in functions called by +// initializers and also does not currently check for variables (as they are +// usually uninitialized). +static SmallVector +findAllConstantStoreOps(mlir::ModuleOp moduleOp) { + SmallVector storeOps; + for (auto initializerOp : + moduleOp.getOps()) { + // Skip initializers that have CFGs. We don't handle conditional + // initialization of globals today. + auto ®ion = initializerOp.getInitializerRegion(); + if (!region.hasOneBlock()) { + LLVM_DEBUG(DBGS() << "ignoring initializer as it has multiple blocks\n"); + continue; + } + // Find all stores. Note that we purposefully skip nested regions today. 
+ for (auto storeOp : + region.front().getOps()) { + Type storedType = storeOp.getStoredGlobalValue().getType(); + if (auto resourceType = + dyn_cast(storedType)) { + if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) { + storeOps.push_back(storeOp); + } + } + } + } + return storeOps; +} + +// Returns true if the operation's memory effects allow it to be hoisted as a +// const-expr operation. We allow Allocate and Free effects (memory management) +// but reject Read/Write effects to external memory. +static bool hasHoistableMemoryEffects(Operation *op) { + auto effectInterface = dyn_cast(op); + if (!effectInterface) { + // No memory effect interface means no effects - hoistable. + return true; + } + + SmallVector effects; + effectInterface.getEffects(effects); + + for (const auto &effect : effects) { + // Allocate effects are fine (creating new memory). + if (isa(effect.getEffect())) { + continue; + } + // Free effects are also fine (releasing memory). + if (isa(effect.getEffect())) { + continue; + } + // Read or Write effects on non-result values are not const-expr. + // Operations can write to their own results (that's how they produce + // them), but reading/writing external memory is not allowed. + if (isa(effect.getEffect()) || + isa(effect.getEffect())) { + // Check if the effect is on a result of this op (allowed) or + // on external memory (not allowed). + if (Value value = llvm::dyn_cast_if_present(effect.getValue())) { + // If it's a result of this op, it's fine. + if (value.getDefiningOp() == op) { + continue; + } + } + // Read/Write to external memory - not const-expr. + return false; + } + } + + return true; +} + +static bool isConstExprOp(Operation *op) { + // Optimization barriers cannot be folded. + if (isa(op)) { + return false; + } + + // By default, ops without results are not const-expr. 
+ if (op->getNumResults() == 0) { + return false; + } + + // If implementing the HoistableOpInterface, just use the decision made by + // the interface. + if (auto hoistableOp = dyn_cast(op)) { + return hoistableOp.isHoistableOp(); + } + + // Forbid if part of a parent that should be treated atomically. + Operation *parent = op; + while (auto hoistableParent = + parent->getParentOfType()) { + if (hoistableParent.isAtomicallyHoistableOp()) { + return false; + } + parent = hoistableParent; + } + + // Check memory effects: we allow Allocate effects (creating new memory) + // but reject Read/Write effects to external memory. This is more permissive + // than OpOracle.cpp's isMemoryEffectFree check, allowing operations like + // stream.async.splat that allocate but don't have other side effects. + return hasHoistableMemoryEffects(op); +} + +static IREE::Stream::AffinityAttr +lookupConsumerAffinityAttr(Value storedValue) { + if (auto affinityOp = dyn_cast( + storedValue.getDefiningOp())) { + return affinityOp.getResultAffinityAttr(); + } + return IREE::Stream::AffinityAttr::lookupOrDefault( + storedValue.getDefiningOp()); +} + +// Returns true if the expression producing |storedValue| is an input without +// any modification (such as inlined constants/parameters). +static bool isPassThroughStore(Value storedValue) { + Operation *op = storedValue.getDefiningOp(); + do { + if (auto transferOp = dyn_cast(op)) { + op = transferOp.getSource().getDefiningOp(); + } else if (auto constantOp = dyn_cast(op)) { + return true; + } else { + return false; + } + } while (op); + return false; +} + +// Returns the result index of |result| in the parent operation. +// The result must be a valid result of op. 
+static unsigned findResultIndex(Operation *op, Value result) { + for (unsigned i = 0; i < op->getNumResults(); ++i) { + if (op->getResult(i) == result) { + return i; + } + } + llvm_unreachable("result not found in operation"); +} + +// Attempts to evaluate a size value to a constant integer. +// This handles direct constants and analyzes through control flow operations +// where the size is provably constant (e.g., scf.if with matching branch +// sizes). +static std::optional tryEvaluateConstantSize(Value sizeValue) { + if (!sizeValue) { + return std::nullopt; + } + + // Try direct constant match (existing behavior). + APInt size; + if (matchPattern(sizeValue, m_ConstantInt(&size))) { + return size.getZExtValue(); + } + + // For scf.if, check if both branches yield the same constant size. + if (auto ifOp = sizeValue.getDefiningOp()) { + unsigned resultIndex = findResultIndex(ifOp, sizeValue); + + // Get the yielded values from both regions. + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + // Recursively evaluate both branch sizes. + Value thenValue = thenYield.getOperand(resultIndex); + Value elseValue = elseYield.getOperand(resultIndex); + + // Find sizes for the yielded resource values. + auto thenSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + thenValue, &ifOp.getThenRegion().front(), Block::iterator(thenYield)); + auto elseSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + elseValue, &ifOp.getElseRegion().front(), Block::iterator(elseYield)); + + auto thenSize = tryEvaluateConstantSize(thenSizeValue); + auto elseSize = tryEvaluateConstantSize(elseSizeValue); + + // If both branches have the same constant size, return it. + if (thenSize && elseSize && *thenSize == *elseSize) { + return *thenSize; + } + + return std::nullopt; + } + + // For scf.for, check if the size is loop-invariant. 
+ if (auto forOp = sizeValue.getDefiningOp()) { + unsigned resultIndex = findResultIndex(forOp, sizeValue); + + // Check the initial value (iter_arg). + Value initArg = forOp.getInitArgs()[resultIndex]; + auto initSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + initArg, forOp->getBlock(), Block::iterator(forOp)); + auto initSize = tryEvaluateConstantSize(initSizeValue); + + if (!initSize) { + return std::nullopt; + } + + // Check the yielded value in the loop body. + auto yieldOp = + cast(forOp.getRegion().front().getTerminator()); + Value yieldedValue = yieldOp.getOperand(resultIndex); + + // Find size for the yielded resource value. + auto yieldedSizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + yieldedValue, &forOp.getRegion().front(), Block::iterator(yieldOp)); + auto yieldedSize = tryEvaluateConstantSize(yieldedSizeValue); + + // If the yielded size matches the initial size, it's invariant. + if (yieldedSize && *yieldedSize == *initSize) { + return *initSize; + } + + return std::nullopt; + } + + // Could not evaluate to a constant. + return std::nullopt; +} + +// Returns a constant pattern for a value derived entirely from a splatted +// value. Returns nullptr if the value is not derived from a splat or has a +// non-constant pattern. +static TypedAttr findConstantSplatPattern(Value storedValue) { + Operation *op = storedValue.getDefiningOp(); + do { + if (auto transferOp = dyn_cast(op)) { + op = transferOp.getSource().getDefiningOp(); + } else if (auto splatOp = dyn_cast(op)) { + TypedAttr pattern; + if (matchPattern(splatOp.getValue(), m_Constant(&pattern))) { + return pattern; + } + return {}; + } else { + return {}; + } + } while (op); + return {}; +} + +// Returns the last value produced that is non-metadata (according to us). +// This lets us skip meaningless ops like transfers and clones that change +// lifetime when cloning into the target program. 
Those ops, though valid, make +// the IR a lot more confusing to follow and prevent some early folding +// opportunities. +static Value findProducedValue(Value value) { + while (Operation *defOp = value.getDefiningOp()) { + if (auto transferOp = dyn_cast(defOp)) { + // We never care about transfers unless they are transferring to unknown. + auto resultType = + cast(transferOp.getResult().getType()); + if (resultType.getLifetime() != IREE::Stream::Lifetime::Unknown) { + value = transferOp.getSource(); + continue; + } + } else if (auto cloneOp = dyn_cast(defOp)) { + // Skip past clones to find the actual producing operation. + // Clones are just type/lifetime conversions, not data producers. + value = cloneOp.getSource(); + continue; + } + break; + } + return value; +} + +// Returns true if the expression should be hoisted based on policy. +static bool shouldHoistExpression(const EncodingExpr &expr, + const EncodingPolicy &policy) { + bool hasParams = expr.hasParameterInputs(); + bool hasConstants = expr.hasConstantInputs(); + + // Check if this expression type should be hoisted per policy. + if (hasParams && !policy.hoistParameterExpressions) { + LLVM_DEBUG(DBGS() << "skipping parameter expression per policy\n"); + return false; + } + if (!hasParams && hasConstants && !policy.hoistConstantExpressions) { + LLVM_DEBUG(DBGS() << "skipping constant expression per policy\n"); + return false; + } + if (!hasParams && !hasConstants) { + // No inputs at all - probably an error case or pure splat. + LLVM_DEBUG(DBGS() << "skipping expression with no inputs\n"); + return false; + } + + // Check size growth threshold. 
+ int64_t inputSize = expr.estimateInputSize(); + int64_t outputSize = expr.estimateOutputSize(); + if (inputSize > 0) { + float growthFactor = static_cast(outputSize) / inputSize; + if (growthFactor > policy.maxEncodingGrowthFactor) { + LLVM_DEBUG(DBGS() << "rejecting expression due to size growth: " + << growthFactor << "x (threshold: " + << policy.maxEncodingGrowthFactor << "x)\n"); + return false; + } + } + + return true; +} + +// Analyzes |moduleOp| to find all expressions producing global constants that +// we can turn into parameters, if any. +static EncodingExprSet gatherEncodingExprSet(mlir::ModuleOp moduleOp, + EncodingPolicy policy) { + auto constantStoreOps = findAllConstantStoreOps(moduleOp); + + EncodingExprSet exprSet; + + std::unique_ptr asmState; + LLVM_DEBUG(asmState = std::make_unique( + moduleOp, OpPrintingFlags().elideLargeElementsAttrs())); + + for (auto storeOp : constantStoreOps) { + LLVM_DEBUG({ + DBGS() << "evaluating store slice for inclusion: "; + storeOp->print(llvm::dbgs(), *asmState); + llvm::dbgs() << "\n"; + }); + Value storedValue = storeOp.getStoredGlobalValue(); + + BackwardSliceOptions sliceOptions; + sliceOptions.inclusive = true; + bool foundAnyNonConstExprOps = false; + sliceOptions.filter = [&](Operation *op) { + if (isConstExprOp(op)) { + return true; + } + foundAnyNonConstExprOps = true; + return false; + }; + // Collect all values that need to be included in the slice: + // - The stored value itself + // - Values used inside nested regions that are defined outside + // + // We compute backward slices for all of them into the same SetVector, + // which gives us proper topological ordering with deduplication. + SetVector rootValues; + rootValues.insert(storedValue); + + // Do a first pass to find region-containing operations. 
+ SetVector tempSlice; + if (failed(mlir::getBackwardSlice(storedValue, &tempSlice, sliceOptions)) || + foundAnyNonConstExprOps) { + LLVM_DEBUG(DBGS() << "failed to calculate backward slice for op or found " + "non-const-expr ops, skipping\n"); + continue; + } + + // Find external dependencies from nested regions using MLIR's standard API. + // getUsedValuesDefinedAbove returns all values used inside a region but + // defined outside of it - exactly what we need for region captures. + for (auto *op : tempSlice) { + for (Region ®ion : op->getRegions()) { + SetVector capturedValues; + mlir::getUsedValuesDefinedAbove(region, capturedValues); + LLVM_DEBUG({ + if (!capturedValues.empty()) { + DBGS() << "found " << capturedValues.size() + << " captured values in region of "; + op->print(llvm::dbgs()); + llvm::dbgs() << ":\n"; + for (Value captured : capturedValues) { + llvm::dbgs() << " "; + captured.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + } + } + }); + for (Value captured : capturedValues) { + rootValues.insert(captured); + } + } + } + + // Now compute backward slices for all root values. + // When we have multiple roots (due to captured values), calling + // getBackwardSlice iteratively can break topological order because new + // operations get appended. We need to sort after merging. + bool needsSort = rootValues.size() > 1; + SetVector slice; + for (Value rootValue : rootValues) { + if (failed(mlir::getBackwardSlice(rootValue, &slice, sliceOptions)) || + foundAnyNonConstExprOps) { + LLVM_DEBUG(DBGS() << "failed to calculate backward slice for op or " + "found non-const-expr ops, skipping\n"); + break; + } + } + + if (foundAnyNonConstExprOps) { + continue; + } + + // Sort only when we merged multiple slices (i.e., had captured values). + // This is a small set (one expression), not the whole program. 
+ // Use mlir::topologicalSort which correctly handles operations across + // different blocks and regions, unlike isBeforeInBlock which only works + // for operations within the same block. + if (needsSort) { + slice = mlir::topologicalSort(slice); + } + + LLVM_DEBUG({ + DBGS() << "slice:\n"; + llvm::interleave( + slice, llvm::dbgs(), + [&](Operation *op) { + llvm::dbgs() << " "; + op->print(llvm::dbgs(), *asmState); + }, + "\n"); + llvm::dbgs() << "\n"; + }); + + // Overlay mode optimization: When a slice is just a parameter load with no + // transformation (detected by isPassThroughStore below), we skip including + // it as an output in overlay mode since the original parameter is + // unchanged. This is controlled by policy.includeUnmodified: + // - Consolidate mode (includeUnmodified=true): includes all parameters + // - Overlay mode (includeUnmodified=false): skips pass-through parameters + // + // Future enhancement: Could add overlap detection to merge expressions that + // write to overlapping parameter regions, possibly requiring a two-pass + // approach. For now, non-overlapping expressions work correctly. + + EncodingExpr expr; + expr.affinityAttr = + lookupConsumerAffinityAttr(storeOp.getStoredGlobalValue()); + + for (auto *op : slice) { + if (auto constantOp = dyn_cast(op)) { + EncodingExpr::Input input; + input.constantOp = constantOp; + expr.inputs.push_back(input); + } + } + + if (!isPassThroughStore(storeOp.getStoredGlobalValue()) || + policy.includeUnmodified) { + // Check if the produced value prefers cloning (like pure splats). + // These should be included in the slice for cloning but not serialized + // as outputs. 
+ Value producedValue = findProducedValue(storeOp.getStoredGlobalValue()); + auto *producingOp = producedValue.getDefiningOp(); + if (producingOp) { + if (auto streamableOp = + dyn_cast(producingOp)) { + if (streamableOp.preferCloneToConsumers()) { + LLVM_DEBUG(DBGS() + << "skipping output for op that prefers cloning\n"); + continue; + } + } + } + + Value storedValue = storeOp.getStoredGlobalValue(); + Value sizeValue = IREE::Util::SizeAwareTypeInterface::findSizeValue( + storedValue, storeOp->getBlock(), Block::iterator(storeOp)); + + // If findSizeValue returns null, it might be because the value comes from + // a control flow operation (like scf.for or scf.if) that doesn't + // implement SizeAwareOpInterface. Try analyzing the control flow + // directly. + std::optional sizeOpt; + if (sizeValue) { + sizeOpt = tryEvaluateConstantSize(sizeValue); + } else if (auto *defOp = storedValue.getDefiningOp()) { + // Try analyzing control flow operations directly. + if (auto forOp = dyn_cast(defOp)) { + // Find which result this is. + unsigned resultIdx = 0; + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + if (forOp.getResult(i) == storedValue) { + resultIdx = i; + break; + } + } + // Get size from init arg. + Value initArg = forOp.getInitArgs()[resultIdx]; + Value initSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + initArg, forOp->getBlock(), Block::iterator(forOp)); + sizeOpt = tryEvaluateConstantSize(initSizeValue); + } else if (auto ifOp = dyn_cast(defOp)) { + // Find which result this is. + unsigned resultIdx = 0; + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) { + if (ifOp.getResult(i) == storedValue) { + resultIdx = i; + break; + } + } + + // Get sizes from both branches. 
+ auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + Value thenValue = thenYield.getOperand(resultIdx); + Value elseValue = elseYield.getOperand(resultIdx); + auto thenSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + thenValue, &ifOp.getThenRegion().front(), + Block::iterator(thenYield)); + auto elseSizeValue = + IREE::Util::SizeAwareTypeInterface::findSizeValue( + elseValue, &ifOp.getElseRegion().front(), + Block::iterator(elseYield)); + auto thenSize = tryEvaluateConstantSize(thenSizeValue); + auto elseSize = tryEvaluateConstantSize(elseSizeValue); + + // Both branches must have the same constant size. + if (thenSize && elseSize && *thenSize == *elseSize) { + sizeOpt = *thenSize; + } + } + } + + if (!sizeOpt) { + LLVM_DEBUG(DBGS() << "failed to find stored resource size, skipping\n"); + continue; + } + EncodingExpr::Output output; + output.size = *sizeOpt; + output.splatPattern = + findConstantSplatPattern(storeOp.getStoredGlobalValue()); + output.storeOp = storeOp; + output.producedValue = producedValue; + expr.outputs.push_back(output); + } + + if (expr.outputs.empty()) { + LLVM_DEBUG(DBGS() << "no outputs produced by policy, skipping\n"); + continue; + } + + expr.ops = std::move(slice); + exprSet.exprs.push_back(std::move(expr)); + } + + return exprSet; +} + +//===----------------------------------------------------------------------===// +// ParameterIndex and builders +//===----------------------------------------------------------------------===// + +// An entry in the parameter index describing a single output parameter. +// Entries can be either SPLAT (constant pattern fill) or DATA (computed bytes). +// A single EncodingExpr may produce multiple entries if it has multiple +// outputs. +struct ParameterEntry { + // Location of the parameter based on the original consumer op. 
+ std::optional loc; + enum class Type { + SPLAT = 0, + DATA = 1, + }; + // Type of the entry (indicates which value field is valid). + Type type; + // Key of the entry within the parameter scope. + StringAttr key; + // Optional metadata embedded with the entry. + SmallVector metadata; + // Total byte length of the parameter in memory. + int64_t length; + // Type-specific value. + union { + struct SplatEntry { + int64_t pattern; + int64_t patternLength; + } splat; + struct DataEntry { + int64_t minimumAlignment; + } data; + } value; + + static ParameterEntry createSplat(Location loc, StringAttr key, + int64_t length, int64_t pattern, + int64_t patternLength) { + ParameterEntry entry{loc}; + entry.type = Type::SPLAT; + entry.key = key; + entry.length = length; + entry.value.splat.pattern = pattern; + entry.value.splat.patternLength = patternLength; + return entry; + } + + static ParameterEntry createData(Location loc, StringAttr key, int64_t length, + int64_t minimumAlignment) { + ParameterEntry entry{loc}; + entry.type = Type::DATA; + entry.key = key; + entry.length = length; + entry.value.data.minimumAlignment = minimumAlignment; + return entry; + } + + Location getLoc() const { + return loc.has_value() ? loc.value() : UnknownLoc::get(key.getContext()); + } +}; + +// An IRPA parameter index. +struct ParameterIndex { + // Fused location derived from all included parameter locations. + Location loc; + // Scope name the index is referenced with, if any. + StringAttr scope; + // All parameter entries in the index. 
+ SmallVector entries; + + void dump(llvm::raw_ostream &os) const { + os << "ParameterIndex[" << scope << "]:\n"; + llvm::interleave( + entries, os, + [&](const ParameterEntry &entry) { + os << " '" << entry.key << "' " << entry.length << " bytes "; + if (!entry.metadata.empty()) { + os << "(metadata: " << entry.metadata.size() << "B) "; + } + switch (entry.type) { + case ParameterEntry::Type::SPLAT: + os << "splat: " + << APInt(entry.value.splat.patternLength * 8, + entry.value.splat.pattern); + break; + case ParameterEntry::Type::DATA: + os << "data: min alignment " << entry.value.data.minimumAlignment + << "B"; + break; + } + }, + "\n"); + os << "\n"; + } +}; + +struct ParameterBuilder { + MLIRContext *context; + StringAttr scope; + StringAttr key; + + ParameterBuilder() = delete; + explicit ParameterBuilder(MLIRContext *context, StringAttr scope, + StringAttr key) + : context(context), scope(scope), key(key) {} + virtual ~ParameterBuilder() = default; + virtual ParameterEntry finalize() = 0; +}; + +struct SplatParameterBuilder : public ParameterBuilder { + Location loc; + int64_t length = 0; + Attribute pattern; + + SplatParameterBuilder(StringAttr scope, StringAttr key, Location loc, + int64_t length, Attribute pattern) + : ParameterBuilder(loc.getContext(), scope, key), loc(loc), + length(length), pattern(pattern) {} + + ParameterEntry finalize() override { + APInt intValue; + APFloat floatValue(0.0f); + if (matchPattern(pattern, m_ConstantFloat(&floatValue))) { + intValue = floatValue.bitcastToAPInt(); + } else if (matchPattern(pattern, m_ConstantInt(&intValue))) { + } else { + assert(false && "ints/floats only; should have been verified"); + } + return ParameterEntry::createSplat( + loc, key, length, intValue.getZExtValue(), intValue.getBitWidth() / 8); + } +}; + +struct DataParameterBuilder : public ParameterBuilder { + IREE::Stream::AffinityAttr affinityAttr; + int64_t maxSize = 0; + int64_t offsetAlignment = 0; + int64_t rangeAlignment = 0; + int64_t 
currentOffset = 0; + SmallVector locs; + + DataParameterBuilder(StringAttr scope, StringAttr key, + IREE::Stream::AffinityAttr affinityAttr, + IREE::Stream::ResourceConfigAttr resourceConfigAttr) + : ParameterBuilder(resourceConfigAttr.getContext(), scope, key), + affinityAttr(affinityAttr), + maxSize(resourceConfigAttr.getMaxAllocationSize()), + offsetAlignment(resourceConfigAttr.getMinBufferOffsetAlignment()), + rangeAlignment(resourceConfigAttr.getMinBufferRangeAlignment()) {} + + // Reserves |length| bytes of storage in the parameter and returns the aligned + // offset within the parameter if there is sufficient capacity remaining. + std::optional tryReserve(Location loc, int64_t length) { + int64_t alignedOffset = IREE::Util::align(currentOffset, offsetAlignment); + int64_t alignedLength = IREE::Util::align(length, rangeAlignment); + int64_t newOffset = std::max(currentOffset, alignedOffset + alignedLength); + if (newOffset > maxSize) { + // Capacity exceeded. + return std::nullopt; + } + currentOffset = newOffset; + return alignedOffset; + } + + ParameterEntry finalize() override { + return ParameterEntry::createData( + FusedLoc::get(context, locs), key, + IREE::Util::align(currentOffset, rangeAlignment), offsetAlignment); + } +}; + +// A subrange of an output parameter produced by an encoding expression. +// Note that a single expression may produce multiple output subranges. +struct ParameterSubrange { + // Parameter index scope. + StringAttr scope; + // Parameter key the subrange is referencing. + StringAttr key; + // Offset within the parameter where the produced value will be placed. + // Aligned to the requirements of the parameter. + int64_t offset = 0; + // Length of subrange the produced value occupies. Note that if padding is + // present this may not extend to all of the parameter storage. 
+ int64_t length = 0; + + ParameterSubrange(StringAttr scope, StringAttr key, int64_t offset, + int64_t length) + : scope(scope), key(key), offset(offset), length(length) {} + + // Creates a named parameter attribute for this subrange with the given total + // length of the storage parameter. + IREE::Stream::NamedParameterAttr + createNamedParameterAttr(int64_t totalLength) const { + Type i8Type = IntegerType::get(scope.getContext(), 8); + auto parameterType = RankedTensorType::get({totalLength}, i8Type); + return IREE::Stream::NamedParameterAttr::get( + scope.getContext(), parameterType, scope, key, DictionaryAttr{}); + } +}; + +// Map of expression outputs to a reserved parameter subrange. +using OutputParameterSubrangeMap = + llvm::MapVector; + +// Incremental ParameterIndex builder with support for parameter combining. +class ParameterIndexBuilder { +public: + ParameterIndexBuilder(StringAttr scope, const EncodingPolicy &encodingPolicy) + : scope(scope), encodingPolicy(encodingPolicy) {} + + FailureOr insertExpr(const EncodingExpr *expr) { + OutputParameterSubrangeMap outputMap; + for (const auto &output : expr->outputs) { + FailureOr> subrangeOr; + if (output.isSupportedSplat() && + output.size > encodingPolicy.serializeSplatSizeThreshold) { + subrangeOr = insertSplatOutput(expr, &output); + } else { + subrangeOr = insertDataOutput(expr, &output); + } + if (failed(subrangeOr)) { + return failure(); + } + if (subrangeOr->has_value()) { + outputMap.insert( + std::make_pair(&output, std::move(subrangeOr->value()))); + } + } + return outputMap; + } + + ParameterIndex finalize() { + SmallVector parameterLocs; + SmallVector parameterEntries; + for (auto ¶meter : parameters) { + ParameterEntry parameterEntry = parameter->finalize(); + parameterLocs.push_back(parameterEntry.getLoc()); + parameterEntries.push_back(std::move(parameterEntry)); + } + ParameterIndex index{FusedLoc::get(scope.getContext(), parameterLocs)}; + index.scope = scope; + index.entries = 
std::move(parameterEntries); + return index; + } + +private: + StringAttr makeParameterName() { + return StringAttr::get(scope.getContext(), + Twine("parameter") + std::to_string(nextId++)); + } + + FailureOr> + insertSplatOutput(const EncodingExpr *expr, + const EncodingExpr::Output *output) { + auto splatBuilder = std::make_unique( + scope, makeParameterName(), expr->getLoc(), output->size, + output->splatPattern); + auto subrange = ParameterSubrange(splatBuilder->scope, splatBuilder->key, 0, + output->size); + parameters.push_back(std::move(splatBuilder)); + return {subrange}; + } + + // Inserts a data output into the parameter index, packing into existing + // parameters when possible. + // + // Uses first-fit allocation: iterates through existing parameters in order + // and places the output in the first one with matching affinity and available + // space. This is simple and fast for compilation, though not optimal for + // minimizing fragmentation. A best-fit or sorted-by-size approach could + // improve memory efficiency if parameter packing becomes a bottleneck. 
+ FailureOr> + insertDataOutput(const EncodingExpr *expr, + const EncodingExpr::Output *output) { + if (encodingPolicy.packParameters) { + for (auto *existingBuilder : dataParameters) { + if (existingBuilder->affinityAttr == expr->affinityAttr) { + std::optional offset = + existingBuilder->tryReserve(expr->getLoc(), output->size); + if (offset.has_value()) { + auto subrange = + ParameterSubrange(existingBuilder->scope, existingBuilder->key, + offset.value(), output->size); + return {subrange}; + } + } + } + } + + auto newBuilder = std::make_unique( + scope, makeParameterName(), expr->affinityAttr, + expr->getResourceConfigAttr()); + std::optional offset = + newBuilder->tryReserve(expr->getLoc(), output->size); + if (offset.has_value()) { + auto subrange = ParameterSubrange(newBuilder->scope, newBuilder->key, + offset.value(), output->size); + dataParameters.push_back(newBuilder.get()); + parameters.push_back(std::move(newBuilder)); + return {subrange}; + } + + LLVM_DEBUG(llvm::dbgs() + << " ! failed to reserve " << output->size + << " bytes for output at " << output->getLoc() << "\n"); + return mlir::emitError(output->getLoc(), + "failed to reserve parameter space for output\n"); + } + + StringAttr scope; + const EncodingPolicy &encodingPolicy; + SmallVector> parameters; + SmallVector dataParameters; + unsigned nextId = 0; +}; + +//===----------------------------------------------------------------------===// +// Encoder work scheduling +//===----------------------------------------------------------------------===// + +// A target configuration for a set of specialized encodings. +// Contains the parameter indices (what parameters will be produced), the +// execution schedule (steps), and a lookup map from (scope, key) to entries. +// Targets may specialize for multiple devices simultaneously if the +// configuration is for heterogeneous execution and may produce multiple +// parameter indices. Currently only a single "all" target is supported. 
+struct TargetPlan { + // Name of the target for the user to specify in tools. + std::string name; + + // Affinity of the device performing the encoding in the encoder module. + // When cross-targeting encoders this will differ from the devices in the + // original program. For consistency it always has a new name. + IREE::Stream::AffinityAttr affinityAttr; + + // Parameter indices produced by the target. + std::vector parameterIndices; + + // A map of (scope, key) to the parameter in the specified index. + DenseMap, ParameterEntry> parameterEntries; + + // A discrete step in the encoding process. + struct Step { + std::string description; + int64_t globalByteOffset = 0; + int64_t globalByteLength = 0; + const EncodingExpr *expr = nullptr; + OutputParameterSubrangeMap outputMap; + + Location getLoc() const { return expr->getLoc(); } + }; + + // An unordered sequence of encoding steps. + // Steps _generally_ start in order but may end in any order and can be + // considered more as "chunks of work" than some point on a timeline. + // Each step may encode more than one parameter. + SmallVector steps; + + // Cumulative size of all writes to all parameters in all scopes. + int64_t globalByteSize = 0; + + // Appends an encoding expression and its output mapping to the schedule. + void appendExpr(const EncodingExpr *expr, + OutputParameterSubrangeMap outputMap) { + Step step; + + // Today we just name the steps in sequence, but could use the parameter + // names in the output map. + step.description = "step" + std::to_string(steps.size()); + + // Since order is largely undefined and each step may produce multiple + // parameters we track a cumulative write offset in a virtual global + // parameter file and use that. Tools can present % completed or use the + // virtual subranges to indicate fine-grained progress. 
+ step.globalByteOffset = globalByteSize; + step.globalByteLength = std::accumulate( + expr->outputs.begin(), expr->outputs.end(), int64_t{0}, + [](int64_t sum, const EncodingExpr::Output &output) -> int64_t { + return sum + output.size; + }); + LLVM_DEBUG(DBGS() << "defining step `" << step.description << "` (at " + << step.globalByteOffset << " for " + << step.globalByteLength << ")\n"); + globalByteSize += step.globalByteLength; + + step.expr = expr; + step.outputMap = std::move(outputMap); + steps.push_back(std::move(step)); + } + + // Returns the named parameter reference attribute for the given subrange. + IREE::Stream::NamedParameterAttr + getNamedParameterAttr(const ParameterSubrange &subrange) const { + auto parameterEntryIt = + parameterEntries.find(std::make_pair(subrange.scope, subrange.key)); + assert(parameterEntryIt != parameterEntries.end() && + "map must contain all entries"); + const ParameterEntry ¶meterEntry = parameterEntryIt->second; + return subrange.createNamedParameterAttr(parameterEntry.length); + } +}; + +//===----------------------------------------------------------------------===// +// Parameter encoder construction +//===----------------------------------------------------------------------===// + +// Adds a function to the new encoder module that tries to automatically detect +// the target configuration given the list of HAL devices. The intent is that it +// performs the same device detection logic the main module performs at runtime +// but with a provided list instead of what the HAL module provides: the only +// device(s) we have at the global level are those of the host performing the +// encoding. 
+// +// Signature, returning a string constant target name: +// util.func public @__encode_parameter_detect_target( +// %devices: !util.list) -> !util.buffer +static void addAutoTargetDetectFunc(Location loc, + ArrayRef targetPlans, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameter_detect_target"; + LLVM_DEBUG(DBGS() << "emitting auto target detection function: " << funcName + << "...\n"); + + auto bufferType = encoderBuilder.getType(); + auto deviceType = encoderBuilder.getType(); + auto deviceListType = + encoderBuilder.getType(deviceType); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({deviceListType}, {bufferType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("detect_target")), + })); + + // Always unconditionally choose the first target today. + assert(!targetPlans.empty()); + Value targetName = IREE::Util::BufferConstantOp::create( + funcBuilder, loc, targetPlans.front().name); + IREE::Util::ReturnOp::create(funcBuilder, loc, targetName); +} + +// Builds a struct of `[scope name, [entries]]`. 
+// Supported entry types: +// +// SPLAT (iree_io_parameter_archive_builder_add_splat_entry): +// [0]: i64 type=0 +// [1]: !util.buffer key (not be NUL terminated) +// [2]: !util.buffer metadata (optional) +// [3]: i64 data length (total size of the parameter) +// [4]: i64 pattern (only up to pattern_bytes_length bytes used) +// [5]: i64 pattern_byte_length +// +// DATA (iree_io_parameter_archive_builder_add_data_entry): +// [0]: i64 type=1 +// [1]: !util.buffer key (not be NUL terminated) +// [2]: !util.buffer metadata (optional) +// [3]: i64 data length (total size of the parameter) +// [4]: i64 minimum alignment (or 0 if don't care) +static Value buildParameterIndexStruct(const ParameterIndex ¶meterIndex, + IntegerSet &i64Set, + OpBuilder &builder) { + LLVM_DEBUG({ + DBGS() << "emitting index with scope: `" << parameterIndex.scope << "` (" + << parameterIndex.entries.size() << " entries)\n"; + parameterIndex.dump(llvm::dbgs()); + }); + + auto loc = parameterIndex.loc; + auto listType = builder.getType(); + + Value scopeName = + IREE::Util::BufferConstantOp::create(builder, loc, parameterIndex.scope); + + SmallVector entryValues; + for (auto &entry : parameterIndex.entries) { + Location entryLoc = entry.getLoc(); + Value typeValue = i64Set.get(static_cast(entry.type)); + Value keyValue = + IREE::Util::BufferConstantOp::create(builder, entryLoc, entry.key); + Value metadataValue = IREE::Util::BufferConstantOp::createOrNull( + builder, entryLoc, entry.metadata); + SmallVector structFields = { + typeValue, + keyValue, + metadataValue, + i64Set.get(entry.length), + }; + switch (entry.type) { + case ParameterEntry::Type::SPLAT: + structFields.push_back(i64Set.get(entry.value.splat.pattern)); + structFields.push_back(i64Set.get(entry.value.splat.patternLength)); + break; + case ParameterEntry::Type::DATA: + structFields.push_back(i64Set.get(entry.value.data.minimumAlignment)); + break; + } + Value entryValue = IREE::Util::ListConstructOp::create( + builder, entryLoc, 
listType, structFields); + entryValues.push_back(entryValue); + } + + Value entryList = + IREE::Util::ListConstructOp::create(builder, loc, listType, entryValues); + Value indexStruct = IREE::Util::ListConstructOp::create( + builder, loc, listType, {scopeName, entryList}); + return indexStruct; +}; + +// Adds a function to the new encoder module that returns the parameter indices +// produced for a given target. A single target may result in more than one +// parameter file in cases where we want to shard parameters. +// +// Signature: +// util.func public @__encode_parameter_indices_TARGET() -> !util.list +static void addTargetIndexBuilderFunc(Location loc, + const TargetPlan &targetPlan, + OpBuilder &encoderBuilder) { + auto listType = encoderBuilder.getType(); + std::string funcName = "__encode_parameter_indices_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting index builder function: " << funcName + << "...\n"); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({}, {listType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + + // Reflection information lets the tool list available targets and required + // scopes without having to call each function. + // VM bytecode only supports string/integer reflection attributes, so we + // encode scopes as a comma-separated string. + // Note: Empty string values crash the flatbuffer serializer, so we only + // include the scopes attribute if there are non-empty scopes. 
+ std::string scopesStr; + for (const auto ¶meterIndex : targetPlan.parameterIndices) { + StringRef scope = parameterIndex.scope.getValue(); + if (scope.empty()) { + continue; + } + if (!scopesStr.empty()) { + scopesStr += ","; + } + scopesStr += scope; + } + SmallVector reflectionAttrs; + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.function", funcBuilder.getStringAttr("indices"))); + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.target", funcBuilder.getStringAttr(targetPlan.name))); + if (!scopesStr.empty()) { + reflectionAttrs.push_back(NamedAttribute( + "iree.encode.scopes", funcBuilder.getStringAttr(scopesStr))); + } + funcOp->setAttr("iree.reflection", + funcBuilder.getDictionaryAttr(reflectionAttrs)); + + IntegerSet i64Set(loc, funcBuilder); + SmallVector indicesStructs; + for (const auto ¶meterIndex : targetPlan.parameterIndices) { + indicesStructs.push_back( + buildParameterIndexStruct(parameterIndex, i64Set, funcBuilder)); + } + + Value indicesList = IREE::Util::ListConstructOp::create( + funcBuilder, loc, listType, indicesStructs); + IREE::Util::ReturnOp::create(funcBuilder, loc, {indicesList}); +} + +// Adds a function to the new encoder module that produces a list of steps +// involved in encoding the parameters for a specific target. Steps do not +// correspond 1:1 with parameters in either the input or output module and may +// complete in any order so we return a list of structs and fences that can be +// used to observe the state and report on progress. If progress capture is +// desired the list needs to be passed back into the encoder function so that it +// can instrument the encoding process with the fences. +// +// A global byte range is attached to each step for presentation purposes only: +// multiple parameter indices may be constructed and an individual step may +// produce values for each. Tools may only use the global byte range to denote +// cumulative bytes written by each step. 
+// +// Each step entry consists of: +// [0]: i64 reserved 0 +// [1]: !hal.fence indicating encoding has begun +// [2]: !hal.fence indicating encoding has ended +// [3]: !util.buffer descriptive comment (not be NUL terminated) +// [4]: i64 synthetic global byte offset +// [5]: i64 synthetic global byte length +// +// Signature: +// util.func public @__encode_parameter_steps_TARGET() -> !util.list +static void addTargetEncoderStepsFunc(Location loc, + const TargetPlan &targetPlan, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameter_steps_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting encoder steps function: " << funcName + << "...\n"); + auto listType = encoderBuilder.getType(); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({}, {listType})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("steps")), + NamedAttribute("iree.encode.target", + funcBuilder.getStringAttr(targetPlan.name)), + })); + + Type deviceType = funcBuilder.getType(); + Value deviceValue = + IREE::Stream::ContextResolveOp::create(funcBuilder, loc, {deviceType}, + targetPlan.affinityAttr) + .getResult(0); + + SmallVector stepStructs; + IntegerSet i64Set(loc, funcBuilder); + for (auto &step : targetPlan.steps) { + Value beginFence = IREE::HAL::FenceCreateOp::create( + funcBuilder, loc, deviceValue, IREE::HAL::FenceFlagBitfield::None); + Value endFence = IREE::HAL::FenceCreateOp::create( + funcBuilder, loc, deviceValue, IREE::HAL::FenceFlagBitfield::None); + Value descriptionValue = IREE::Util::BufferConstantOp::create( + funcBuilder, loc, step.description); + stepStructs.push_back(IREE::Util::ListConstructOp::create( + funcBuilder, loc, listType, + { + i64Set.get(0), + beginFence, + 
endFence, + descriptionValue, + i64Set.get(step.globalByteOffset), + i64Set.get(step.globalByteLength), + })); + } + + Value stepsList = IREE::Util::ListConstructOp::create(funcBuilder, loc, + listType, stepStructs); + IREE::Util::ReturnOp::create(funcBuilder, loc, stepsList); +} + +using MarkObjectReference = + std::function; + +// Adds a function to the new encoder module that encodes parameters for a +// specific target. Encoding will wait for the provided `wait_fence` prior to +// starting any processing and signal the provided `signal_fence` when all +// processing has completed. The steps list is the result of the paired +// `__encode_parameter_steps_TARGET` function and the fences within will be +// signaled as encoding progresses. +// +// Signature (follows standard coarse-fences ABI with fences at end): +// util.func public @__encode_parameters_TARGET( +// %steps: !util.list, +// %wait_fence: !hal.fence, +// %signal_fence: !hal.fence) +static LogicalResult +addTargetEncoderFunc(Location loc, const TargetPlan &targetPlan, + const MarkObjectReference &markObjectReference, + OpBuilder &encoderBuilder) { + std::string funcName = "__encode_parameters_" + targetPlan.name; + LLVM_DEBUG(DBGS() << "emitting encoder function: " << funcName << "...\n"); + auto fenceType = encoderBuilder.getType(); + auto listType = encoderBuilder.getType(); + auto funcOp = IREE::Util::FuncOp::create( + encoderBuilder, loc, funcName, + encoderBuilder.getFunctionType({listType, fenceType, fenceType}, {})); + funcOp.setVisibility(SymbolTable::Visibility::Public); + OpBuilder funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock()); + funcOp->setAttr( + "iree.reflection", + funcBuilder.getDictionaryAttr({ + NamedAttribute("iree.abi.model", + funcBuilder.getStringAttr("coarse-fences")), + NamedAttribute("iree.encode.function", + funcBuilder.getStringAttr("encode")), + NamedAttribute("iree.encode.target", + funcBuilder.getStringAttr(targetPlan.name)), + })); + + // TODO(benvanik): 
make steps optional, probably by just calling the steps + // function internally when not provided so that we can keep all the encoding + // code branch-free. For now we require it be provided. + + Value waitFence = funcOp.getArgument(1); + Value signalFence = funcOp.getArgument(2); + + Type timepointType = funcBuilder.getType(); + Value lastTimepoint = IREE::Stream::TimepointImportOp::create( + funcBuilder, loc, timepointType, waitFence, targetPlan.affinityAttr); + + // Use explicit transient lifetime for all output slab allocations. + // This storage is allocated at the start of each step and deallocated at the + // end, making transient the correct lifetime. + Type resourceType = funcBuilder.getType( + IREE::Stream::Lifetime::Transient); + IndexSet indexSet(loc, funcBuilder); + IntegerSet i64Set(loc, funcBuilder); + for (const auto &step : targetPlan.steps) { + Location stepLoc = step.getLoc(); + + // Build a map of scope name to the outputs going to it and their parameter + // references. Note that this mapping is target-specific (as each target may + // have a different mix of parameters and parameter sizes due to differences + // in encodings). 
+ struct OutputReservation { + const EncodingExpr::Output *output = nullptr; + const ParameterSubrange *parameterSubrange = nullptr; + IREE::Stream::NamedParameterAttr parameterAttr; + size_t slabOffsetOrdinal = 0; + }; + llvm::MapVector> scopeOutputs; + SmallVector outputSizes; + for (auto &output : step.expr->outputs) { + auto it = step.outputMap.find(&output); + if (it == step.outputMap.end()) { + continue; // no serialization required + } + const ParameterSubrange &subrange = it->second; + OutputReservation reservation; + reservation.output = &output; + reservation.parameterSubrange = &subrange; + reservation.parameterAttr = targetPlan.getNamedParameterAttr(subrange); + reservation.slabOffsetOrdinal = outputSizes.size(); + scopeOutputs[reservation.parameterAttr.getScope()].push_back(reservation); + outputSizes.push_back(indexSet.get(subrange.length)); + } + + // Allocate transient storage for all the parameter outputs. + // If we were overlapping we'd want to get this from a ringbuffer. + // TODO(benvanik): stream.async.ringbuffer-style ops for safely doing bump + // pointer allocation with timeline-awareness at this level. + auto reservationPackOp = IREE::Stream::ResourcePackOp::create( + funcBuilder, stepLoc, /*offset=*/nullptr, outputSizes, + targetPlan.affinityAttr); + Value outputSlabSize = reservationPackOp.getTotalLength(); + auto outputSlabAllocaOp = IREE::Stream::ResourceAllocaOp::create( + funcBuilder, stepLoc, resourceType, timepointType, outputSlabSize, + /*indeterminate_lifetime=*/nullptr, lastTimepoint, + targetPlan.affinityAttr); + Value outputSlab = outputSlabAllocaOp.getResult(); + + // Note: Input parameters are NOT included in this slab allocation. + // Inputs are loaded via stream.async.constant operations (cloned below) + // which reference external parameter storage and don't require allocation. + // Only outputs need slab allocation as transient working memory before + // being scattered to their final parameter locations. 
+ // + // Wait for the slab to be ready before we transition back into async IR. + outputSlab = IREE::Stream::TimepointAwaitOp::create( + funcBuilder, stepLoc, {outputSlab}, {outputSlabSize}, + outputSlabAllocaOp.getResultTimepoint()) + .getResult(0); + + // Clone the expression IR and fix it up for use in the new module. + // We have to remove any affinities referencing the devices in the source + // program and ensure we also bring along any referenced objects + // (executables, etc). + // + // The slice is already in topological order from getBackwardSlice, and + // all captured values from nested regions have been included via + // getUsedValuesDefinedAbove, so we can clone directly without sorting. + // + // AsyncConstantOp with parameter values are converted to + // AsyncParameterLoadOp during cloning because the lowering path through + // ResourceConstantsOp does not preserve await_timepoint. + // AsyncParameterLoadOp lowers directly to CmdParameterLoadOp which does + // preserve await_timepoint. + IRMapping exprMapping; + for (auto *sourceOp : step.expr->ops) { + auto *clonedOp = funcBuilder.clone(*sourceOp, exprMapping); + if (auto affinityOp = + dyn_cast(clonedOp)) { + affinityOp.removeAffinityAttrs(); + } + // Convert AsyncConstantOp with parameter values to AsyncParameterLoadOp. + // This ensures await_timepoint is preserved through lowering, since + // AsyncConstantOp goes through ResourceConstantsOp which drops await. + if (auto constantOp = dyn_cast(clonedOp)) { + if (auto parameterAttr = + dyn_cast(constantOp.getValue())) { + // Extract parameter scope and key from the attribute. + StringAttr scopeAttr = parameterAttr.getScope(); + StringAttr keyAttr = parameterAttr.getKey(); + // Create zero offset for full parameter load. + Value zeroOffset = i64Set.get(0); + Value resultSize = constantOp.getResultSize(); + // Create AsyncParameterLoadOp with the wait fence as await. 
+ auto paramLoadOp = IREE::Stream::AsyncParameterLoadOp::create( + funcBuilder, constantOp.getLoc(), + constantOp.getResult().getType(), + funcBuilder.getType(), + /*await_timepoint=*/lastTimepoint, scopeAttr, keyAttr, zeroOffset, + resultSize, targetPlan.affinityAttr); + // Await the result timepoint to get a resolved resource that can be + // used by streamable ops without explicit synchronization. + auto awaitOp = IREE::Stream::TimepointAwaitOp::create( + funcBuilder, constantOp.getLoc(), paramLoadOp.getResult(), + resultSize, paramLoadOp.getResultTimepoint()); + // Update mapping to use the awaited result. + exprMapping.map(sourceOp->getResult(0), awaitOp.getResults().front()); + // Erase the cloned AsyncConstantOp. + constantOp.erase(); + clonedOp = awaitOp; + } else { + // Non-parameter constant: just set await_timepoint. + constantOp.getAwaitTimepointMutable().assign(lastTimepoint); + } + } + auto symbolUses = SymbolTable::getSymbolUses(clonedOp); + if (symbolUses.has_value()) { + for (auto &use : symbolUses.value()) { + if (failed(markObjectReference(clonedOp, use.getSymbolRef()))) { + return failure(); + } + } + } + } + + // Scatter the outputs into the parameter(s) for each scope. 
+ for (auto [scope, outputReservations] : scopeOutputs) { + for (auto &reservation : outputReservations) { + Location outputLoc = reservation.output->getLoc(); + Value outputValue = + exprMapping.lookup(reservation.output->producedValue); + Value packedOffset = + reservationPackOp.getPackedOffsets()[reservation.slabOffsetOrdinal]; + Value packedEnd = + indexSet.add(packedOffset, reservation.parameterSubrange->length); + Value outputSize = indexSet.get(reservation.parameterSubrange->length); + auto updateOp = IREE::Stream::AsyncUpdateOp::create( + funcBuilder, outputLoc, outputSlab.getType(), outputSlab, + outputSlabSize, packedOffset, packedEnd, outputValue, outputSize, + targetPlan.affinityAttr); + outputSlab = updateOp.getResult(); + } + } + auto outputBarrierOp = IREE::Stream::TimepointBarrierOp::create( + funcBuilder, step.getLoc(), outputSlab, outputSlabSize, + targetPlan.affinityAttr); + outputSlab = outputBarrierOp.getResult(); + + // Scatter parameters from the transient slab into each target scope. + SmallVector scatterTimepoints; + for (auto [scope, outputReservations] : scopeOutputs) { + SmallVector outputLocs; + SmallVector sourceOffsets; + SmallVector sourceEnds; + SmallVector sourceLengths; + SmallVector targetKeys; + SmallVector targetOffsets; + for (auto &reservation : outputReservations) { + outputLocs.push_back(reservation.output->getLoc()); + Value packedOffset = + reservationPackOp.getPackedOffsets()[reservation.slabOffsetOrdinal]; + Value packedSize = indexSet.get(reservation.parameterSubrange->length); + sourceOffsets.push_back(packedOffset); + sourceLengths.push_back(packedSize); + targetKeys.push_back(reservation.parameterAttr.getKey()); + targetOffsets.push_back( + i64Set.get(reservation.parameterSubrange->offset)); + } + // Compute source ends (offset + length) for async parameter scatter. 
+ for (auto [offset, length] : + llvm::zip_equal(sourceOffsets, sourceLengths)) { + auto end = funcBuilder.createOrFold( + funcBuilder.getFusedLoc(outputLocs), offset, length); + sourceEnds.push_back(end); + } + auto scatterOp = IREE::Stream::AsyncParameterScatterOp::create( + funcBuilder, funcBuilder.getFusedLoc(outputLocs), outputSlab, + outputSlabSize, sourceOffsets, sourceEnds, sourceLengths, scope, + funcBuilder.getArrayAttr(targetKeys), targetOffsets, + outputBarrierOp.getResultTimepoint(), targetPlan.affinityAttr); + // AsyncParameterScatterOp returns (resource, timepoint) tuple. + outputSlab = scatterOp.getResult(); + scatterTimepoints.push_back(scatterOp.getResultTimepoint()); + } + Value scattersTimepoint = IREE::Stream::TimepointJoinOp::create( + funcBuilder, stepLoc, scatterTimepoints); + + // Deallocate the output slab (now the scattered resource). + Value deallocaTimepoint = IREE::Stream::ResourceDeallocaOp::create( + funcBuilder, stepLoc, outputSlab, outputSlabSize, + /*prefer_origin=*/false, scattersTimepoint, targetPlan.affinityAttr); + + lastTimepoint = deallocaTimepoint; + } + + // Chain the final timepoint (which depends on all steps via the loop above) + // with the external signal fence. This signals completion of all encoding + // steps. We use a single chain at the end rather than chaining after each + // step because: (1) the function has only one signal fence parameter, and + // (2) callers wait on the fence to know when all encoding is complete, not + // individual steps. + IREE::Stream::TimepointChainExternalOp::create(funcBuilder, funcOp.getLoc(), + lastTimepoint, {signalFence}, + targetPlan.affinityAttr); + + IREE::Util::ReturnOp::create(funcBuilder, loc); + + return success(); +} + +// Replaces all encoded exprs in the original module with loads/gathers from the +// new encoded parameters. 
+static void replaceEncodedExprs(ArrayRef targetPlans) { + // TODO: support multiple targets by emitting a big switch, a detection + // function, and then conditionally execute each plan. Each plan should + // encompass all the required expressions but heterogeneous makes things + // more complicated in a way I can't yet see. For now we assume all + // expressions are grouped into a single target and always evaluated (vs. + // conditionally evaluated per target). + const TargetPlan &targetPlan = targetPlans.front(); + + // Since expressions may share ops we accumulate all the root ops we believe + // are dead and then burn them down after we're done accessing them. + SmallVector deadOpWorklist; + + // Note that it's possible for targets to not have all expressions: if we are + // specializing a heterogeneous module we may produce one encoder module per + // target each with its own set of placed parameters. + IndexSetCollection indexSetCollection; + for (auto &step : targetPlan.steps) { + // Collect external timepoints once per expression (shared by all outputs). + Value expressionAwaitTimepoint; + if (!step.expr->outputs.empty()) { + OpBuilder timepointBuilder(step.expr->outputs.front().storeOp); + expressionAwaitTimepoint = + collectExternalTimepoints(*step.expr, timepointBuilder); + } + + for (auto &output : step.expr->outputs) { + auto it = step.outputMap.find(&output); + if (it == step.outputMap.end()) { + continue; // no serialization required + } + auto *indexSet = indexSetCollection.get(output.storeOp); + OpBuilder builder(output.storeOp); + + // Since each target may have a unique size and packing of their + // encoded parameters we need to reference the plan-specific parameter. 
+ const ParameterSubrange &subrange = it->second; + auto parameterAttr = targetPlan.getNamedParameterAttr(subrange); + const int64_t storageSize = parameterAttr.getStorageSize(); + Value storageSizeValue = indexSet->get(storageSize); + + // Embed an inline constant referencing the parameter and slice out the + // subrange (if any). + Value oldValue = output.storeOp.getStoredGlobalValue(); + + Value constantValue = IREE::Stream::AsyncConstantOp::create( + builder, output.getLoc(), oldValue.getType(), + expressionAwaitTimepoint, parameterAttr, storageSizeValue, + step.expr->affinityAttr); + Value newValue = constantValue; + if (subrange.offset != 0 || subrange.length != storageSize) { + // TODO(benvanik): use AsyncSliceOp instead; today ElideAsyncCopiesPass + // does not do any IPO and inserting slices here forces each parameter + // to be cloned at execution. Inserting ResourceSubviewOp is only barely + // safe here because we otherwise don't allow it and know we can run a + // propagation pass immediately after this pass. It's shady, though, and + // may block other optimizations. + // + // Should be: + // newValue = IREE::Stream::AsyncSliceOp::create( + // builder, output.getLoc(), constantValue, storageSizeValue, + // indexSet->get(subrange.offset), + // indexSet->add(subrange.offset, subrange.length), + // indexSet->get(subrange.length), step.expr->affinityAttr); + newValue = IREE::Stream::ResourceSubviewOp::create( + builder, output.getLoc(), constantValue, storageSizeValue, + indexSet->get(subrange.offset), indexSet->get(subrange.length)); + } + output.storeOp.setStoredGlobalValue(newValue); + + // Now that we've replaced a use (but maybe not all uses!) we may be able + // to kill one or more ops. Since expressions/outputs may share IR we + // enqueue the deletion check to the end. + if (auto *producerRootOp = oldValue.getDefiningOp()) { + // Enqueue ops with no uses for pruning - pruneDeadOps will determine + // if they're actually safe to delete. 
+ if (producerRootOp->use_empty()) { + deadOpWorklist.push_back(producerRootOp); + } + } + } + } + + // Recursively delete unused operations and their producers. + pruneDeadOps(std::move(deadOpWorklist)); +} + +//===----------------------------------------------------------------------===// +// --iree-stream-split-parameter-encoder +//===----------------------------------------------------------------------===// + +// Placeholder planning for taking an expression set and producing a +// target-specialized set of parameter indices and an encoding schedule. +// +// TODO: use analysis to identify a set of a target configurations. This may +// be too tricky to do automatically (what would we call the +// configurations?) and require the user to specify the exact names and +// constituent devices. We'd want to take the configuration and prune the +// expression set to those used with involved devices, potentially allow for +// a second specialization round, etc. For now we just have one default +// target and let the tool auto select it. +static FailureOr planDefaultTarget(const EncodingExprSet &exprSet, + StringAttr scope, + EncodingPolicy encodingPolicy) { + LLVM_DEBUG( + DBGS() + << "building parameter index and schedule for default target in scope `" + << scope << "`\n"); + + TargetPlan targetPlan; + targetPlan.name = "all"; + + // For now we leave the encoding host target unspecified. This allows the + // user to compile for any device they want. We could copy the device from + // the source module if we wanted to do 1:1 encoding:execution. 
+ targetPlan.affinityAttr = IREE::HAL::DevicePromiseAttr::get( + scope.getContext(), StringAttr::get(scope.getContext(), "__device_0"), + -1); + + ParameterIndexBuilder parameterIndexBuilder(scope, encodingPolicy); + for (int i = 0; i < exprSet.exprs.size(); ++i) { + const EncodingExpr &expr = exprSet.exprs[i]; + auto outputMapOr = parameterIndexBuilder.insertExpr(&expr); + if (failed(outputMapOr)) { + return mlir::emitError(expr.getLoc(), + "failed to add expression to parameter index"); + } + targetPlan.appendExpr(&expr, std::move(outputMapOr.value())); + } + ParameterIndex parameterIndex = parameterIndexBuilder.finalize(); + for (auto &entry : parameterIndex.entries) { + targetPlan.parameterEntries[std::make_pair(scope, entry.key)] = entry; + } + targetPlan.parameterIndices.push_back(std::move(parameterIndex)); + return targetPlan; +} + +struct SplitParameterEncoderPass + : public IREE::Stream::impl::SplitParameterEncoderPassBase< + SplitParameterEncoderPass> { + using IREE::Stream::impl::SplitParameterEncoderPassBase< + SplitParameterEncoderPass>::SplitParameterEncoderPassBase; + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::ModuleOp moduleOp = getOperation(); + + // Scan the program and find candidate expressions. + EncodingPolicy encodingPolicy; + encodingPolicy.includeUnmodified = + mode == IREE::Stream::ParameterEncoderMode::Consolidate; + encodingPolicy.hoistParameterExpressions = hoistParameterExpressions; + encodingPolicy.hoistConstantExpressions = hoistConstantExpressions; + encodingPolicy.maxEncodingGrowthFactor = maxEncodingGrowthFactor; + + EncodingExprSet exprSet = gatherEncodingExprSet(moduleOp, encodingPolicy); + + // Filter expressions by policy (size growth, expression type). 
+ EncodingExprSet filteredExprSet; + for (const auto &expr : exprSet.exprs) { + if (shouldHoistExpression(expr, encodingPolicy)) { + filteredExprSet.exprs.push_back(expr); + } else { + LLVM_DEBUG(DBGS() << "skipping expression based on policy\n"); + } + } + + if (filteredExprSet.empty()) { + // No candidates detected (or none the policy approves) so no-op. + // + // The user invoking this pass did ask for a new file, though, so we need + // to at least delete any existing one so the user doesn't get confused + // (old artifacts from a run where we did write something carried across). + LLVM_DEBUG(DBGS() << "no candidate expressions detected; skipping pass " + "and deleting existing output file\n"); + if (!outputFile.empty()) { + (void)llvm::sys::fs::remove(outputFile); + } + return; + } + + // Create the new encoder module we'll be populating. Note that we may have + // multiple targets that contribute functions to the module. + OwningOpRef encoderModuleOpRef = + mlir::ModuleOp::create(moduleOp.getLoc(), "encoder"); + mlir::ModuleOp encoderModuleOp = *encoderModuleOpRef; + encoderModuleOp->setAttr( + "iree.reflection", + DictionaryAttr::get( + context, { + NamedAttribute("iree.tool", + StringAttr::get( + context, "iree-encode-parameters")), + NamedAttribute("iree.encode.version", + IntegerAttr::get( + IntegerType::get(context, 32), 1)), + })); + OpBuilder encoderBuilder = + OpBuilder::atBlockBegin(encoderModuleOp.getBody()); + + // Today we only support a single target and build the index for that. + // A few things in here will need to change when we specialize but most of + // the data structures are set up for it. + std::string targetOutputScope = + outputScope.hasValue() ? 
outputScope.getValue() : ""; + auto defaultTargetOr = planDefaultTarget( + filteredExprSet, StringAttr::get(context, targetOutputScope), + encodingPolicy); + if (failed(defaultTargetOr)) { + return signalPassFailure(); + } + LLVM_DEBUG(DBGS() << "note: default target '" << defaultTargetOr->name + << "' used in place of target specialization\n"); + SmallVector targetPlans; + targetPlans.push_back(std::move(defaultTargetOr).value()); + + // Emit the target detection function used by tools to try to infer the host + // target (useful for post-deployment encoding). + addAutoTargetDetectFunc(moduleOp->getLoc(), targetPlans, encoderBuilder); + + // Emit the per-target metadata functions. + for (const auto &targetPlan : targetPlans) { + addTargetIndexBuilderFunc(moduleOp->getLoc(), targetPlan, encoderBuilder); + addTargetEncoderStepsFunc(moduleOp->getLoc(), targetPlan, encoderBuilder); + } + + // Accumulate object references during cloning so that we can deduplicate + // and clone them all afterward. This avoids interleaving the objects with + // the encoder functions - sometimes that is good, but it's easier to read + // the IR when they aren't. + SymbolTable sourceSymbolTable(moduleOp); + SetVector objectsToClone; + + // Capture the last op (if any) so we can insert after it later. + // This ensures objects go before any encoder functions we're about to add. 
+ Operation *lastOpBeforeEncoders = + &*std::prev(encoderModuleOp.getBody()->end(), 1); + auto markObjectReference = [&](Operation *userOp, + SymbolRefAttr symbolRef) -> LogicalResult { + auto objectNameAttr = symbolRef.getRootReference(); + auto *objectOp = sourceSymbolTable.lookup(objectNameAttr); + if (!objectOp) { + return userOp->emitOpError() + << "reference to undefined symbol " << symbolRef; + } + if (!objectOp->hasTrait()) { + return userOp->emitOpError() + << "reference to non-object-like symbol " << symbolRef; + } + objectsToClone.insert(objectOp); + return success(); + }; + + // Produce all of the encoder functions and gather the objects we need to + // clone. + for (const auto &targetPlan : targetPlans) { + if (failed(addTargetEncoderFunc(moduleOp->getLoc(), targetPlan, + markObjectReference, encoderBuilder))) { + return signalPassFailure(); + } + } + + // Clone all objects referenced by the encoder module. + // Object-like ops are isolated and safe to copy wholesale. + // Insert after the last op that existed before we added encoder functions. + encoderBuilder.setInsertionPointAfter(lastOpBeforeEncoders); + for (Operation *objectOp : objectsToClone) { + encoderBuilder.clone(*objectOp); + } + + // Replace the expressions in the original module with parameter lookups. + replaceEncodedExprs(targetPlans); + + // CSE to clean up the encoder IR before dumping. + // This is important for deduplicating operations shared across multiple + // encoding expressions. When expressions are cloned into the encoder + // module, shared intermediate operations get duplicated at clone time. CSE + // removes these duplicates, ensuring efficient encoder module output. The + // original module likely needs a bit of cleanup but as compilation + // continues that'll happen. 
+ { + IRRewriter rewriter(context); + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, encoderModuleOp); + } + + if (failed(mlir::verify(encoderModuleOp))) { + mlir::emitError(encoderModuleOp.getLoc()) + << "failed to verify produced encoder module"; + return signalPassFailure(); + } + + // Write module to the file specified, or stdout if empty. + if (outputFile.empty()) { + LLVM_DEBUG(DBGS() << "writing encoder module to stdout...\n"); + OpPrintingFlags flags; + encoderModuleOp.print(llvm::outs(), flags); + llvm::outs() << "\n"; + } else { + LLVM_DEBUG(DBGS() << "writing encoder module to '" << outputFile + << "'...\n"); + if (failed(writeModule(encoderModuleOp, outputFile))) { + LLVM_DEBUG(DBGS() << "MODULE WRITE FAILED\n"); + return signalPassFailure(); + } + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index c36bb57955ce..f186cdcc207b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -67,6 +67,7 @@ iree_lit_test_suite( "schedule_execution_timeline_aware.mlir", "specialize_dispatches.mlir", "specialize_encodings.mlir", + "split_parameter_encoder.mlir", "sync_initializers.mlir", "unify_encoding_for_globals.mlir", "verify_affinities.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 09c45c587c0a..16bbbb2da114 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -65,6 +65,7 @@ iree_lit_test_suite( "schedule_execution_timeline_aware.mlir" "specialize_dispatches.mlir" "specialize_encodings.mlir" + 
"split_parameter_encoder.mlir" "sync_initializers.mlir" "unify_encoding_for_globals.mlir" "verify_affinities.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir new file mode 100644 index 000000000000..8d3cd53aa5bf --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/split_parameter_encoder.mlir @@ -0,0 +1,1770 @@ +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=OVERLAY +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=OVERLAY-MIXED +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=consolidate' %s | FileCheck %s --check-prefix=COMPARE-CONSOLIDATE +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=COMPARE-OVERLAY +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='output-scope=my_custom_scope' %s | FileCheck %s --check-prefix=SCOPE +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='max-encoding-growth-factor=2.0' %s | FileCheck %s --check-prefix=GROWTH2 +// RUN: iree-opt --split-input-file --iree-stream-split-parameter-encoder='mode=overlay' %s | FileCheck %s --check-prefix=EMPTY + +// Tests simple constant with splat initialization. +// This is the most basic case - a global initialized with a constant splat. +// This should NOT be hoisted (no parameter input). 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @simple_constant : !stream.resource +util.global private @simple_constant : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + // CHECK: util.global.store %[[SPLAT]], @simple_constant : !stream.resource + util.global.store %splat, @simple_constant : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests computed constant with transformation. +// This tests a constant that undergoes some computation (fill operation). +// Should NOT be hoisted (no parameter input). + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @computed_constant : !stream.resource +util.global private @computed_constant : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK-DAG: %[[C42_I32:.+]] = arith.constant 42 : i32 + %c42_i32 = arith.constant 42 : i32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index + %c256 = arith.constant 256 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // Create base splat. + // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // Fill a region with different value. 
+ // CHECK: %[[FILLED:.+]] = stream.async.fill %[[C42_I32]], %[[SPLAT]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[SPLAT]] as !stream.resource{%[[C1024]]} + %filled = stream.async.fill %c42_i32, %splat[%c0 to %c256 for %c256] : i32 -> %splat as !stream.resource{%c1024} + + // CHECK: util.global.store %[[FILLED]], @computed_constant : !stream.resource + util.global.store %filled, @computed_constant : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests multiple constants with different patterns. +// This tests that the pass can handle multiple globals, some with splat and some +// with more complex initialization. + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @constant_a : !stream.resource +util.global private @constant_a : !stream.resource +// CHECK: util.global private @constant_b : !stream.resource +util.global private @constant_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK-DAG: %[[C1_I32:.+]] = arith.constant 1 : i32 + %c1_i32 = arith.constant 1 : i32 + // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index + %c512 = arith.constant 512 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // First constant: simple splat. + // CHECK: %[[SPLAT_A:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C512]]} + %splat_a = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c512} + // CHECK: util.global.store %[[SPLAT_A]], @constant_a : !stream.resource + util.global.store %splat_a, @constant_a : !stream.resource + + // Second constant: different splat value and size. 
+ // CHECK: %[[SPLAT_B:.+]] = stream.async.splat %[[C1_I32]] : i32 -> !stream.resource{%[[C1024]]} + %splat_b = stream.async.splat %c1_i32 : i32 -> !stream.resource{%c1024} + // CHECK: util.global.store %[[SPLAT_B]], @constant_b : !stream.resource + util.global.store %splat_b, @constant_b : !stream.resource + + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation. +// This tests loading a parameter and applying a transformation (fill operation). +// This SHOULD be hoisted since it has a parameter input. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @parameter_transformed : !stream.resource +util.global private @parameter_transformed : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"param0"> : vector<1024xi8> + + // Fill a region with different value. + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @parameter_transformed + util.global.store %filled, @parameter_transformed : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests pure splat should NOT be hoisted (negative case). +// This tests that a pure splat with no inputs and no transformation is not hoisted. 
+// The pass should leave this module unchanged. + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @pure_splat_only : !stream.resource +util.global private @pure_splat_only : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C99_I32:.+]] = arith.constant 99 : i32 + %c99_i32 = arith.constant 99 : i32 + // CHECK: %[[C2048:.+]] = arith.constant 2048 : index + %c2048 = arith.constant 2048 : index + + // Pure splat with no parameter input - should NOT be hoisted. + // CHECK: %[[SPLAT:.+]] = stream.async.splat %[[C99_I32]] : i32 -> !stream.resource{%[[C2048]]} + %splat = stream.async.splat %c99_i32 : i32 -> !stream.resource{%c2048} + + // CHECK: util.global.store %[[SPLAT]], @pure_splat_only : !stream.resource + util.global.store %splat, @pure_splat_only : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests a parameter transformed by a dispatch operation. +// Should be hoisted as it represents expensive computation on a parameter. +// Real-world: Elementwise operations, quantization, or encoding on weights. + +stream.executable private @executable { + stream.executable.export public @dispatch +} + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_dispatch : !stream.resource +util.global private @param_with_dispatch : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. 
+ %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"dispatch_param"> : vector<1024xi8> + + // Dispatch performing operation on parameter. + // CHECK-NOT: stream.async.dispatch + %result = stream.async.dispatch @executable::@dispatch[%c1, %c1, %c1](%param[%c0 to %c1024 for %c1024]) : + (!stream.resource{%c1024}) -> !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_dispatch : !stream.resource + util.global.store %result, @param_with_dispatch : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter + splat + dispatch pattern. +// Splat should be cloned to consumers but not serialized (preferCloneToConsumers). +// Real-world: Parameter combined with constant baseline (e.g., weight + bias). + +stream.executable private @executable { + stream.executable.export public @add_dispatch +} + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_splat_dispatch : !stream.resource +util.global private @param_splat_dispatch : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + // Load parameter from external source. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"weights"> : vector<1024xi8> + + // Create splat (should be cloned but not serialized). 
+ %c0_i32 = arith.constant 0 : i32 + // CHECK-NOT: stream.async.splat + %splat = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // Dispatch using both parameter and splat. + // CHECK-NOT: stream.async.dispatch + %result = stream.async.dispatch @executable::@add_dispatch[%c1, %c1, %c1]( + %param[%c0 to %c1024 for %c1024], + %splat[%c0 to %c1024 for %c1024] + ) : (!stream.resource{%c1024}, !stream.resource{%c1024}) -> !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_splat_dispatch : !stream.resource + util.global.store %result, @param_splat_dispatch : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter with metadata operations (subview). +// Metadata operations should not prevent hoisting. +// Real-world: Extract layer weights from combined parameter. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_metadata_ops : !stream.resource +util.global private @param_metadata_ops : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + // Load larger parameter. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"combined_param"> : vector<1024xi8> + + // Extract slice (metadata operation). + // CHECK-NOT: stream.async.slice + %slice = stream.async.slice %param[%c256 to %c512] : !stream.resource{%c1024} -> !stream.resource{%c256} + + // Apply transformation to slice. 
+ %c100_i32 = arith.constant 100 : i32 + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c100_i32, %slice[%c0 to %c256 for %c256] : i32 -> %slice as !stream.resource{%c256} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_metadata_ops : !stream.resource + util.global.store %filled, @param_metadata_ops : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests borderline growth (1.15x) - should pass. +// Within threshold growth should be allowed. +// Real-world: Small padding for alignment. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_acceptable_growth : !stream.resource +util.global private @param_acceptable_growth : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c1180 = arith.constant 1180 : index // ~1.15x growth + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"param_growth"> : vector<1024xi8> + + // Slight growth for padding - should be within 1.2x threshold + %c0 = arith.constant 0 : index + %c156 = arith.constant 156 : index + %c0_i32 = arith.constant 0 : i32 + + // Create slightly larger buffer + // CHECK-NOT: stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1180} + + // Copy parameter into padded buffer + // CHECK-NOT: stream.async.update + %result = stream.async.update %param, %padded[%c0 to %c1024] : + !stream.resource{%c1024} -> %padded as !stream.resource{%c1180} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: 
%[[RESULT:.+]] = stream.resource.subview %[[PARAM]]
+  // CHECK: util.global.store %[[RESULT]], @param_acceptable_growth : !stream.resource
+  util.global.store %result, @param_acceptable_growth : !stream.resource
+  // CHECK: util.return
+  util.return
+  // CHECK: }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Control Flow in Initializers
+//===----------------------------------------------------------------------===//
+
+// Tests scf.for loop with constant bounds.
+// Should hoist the loop and its body since bounds are constant;
+// the loop is moved to the encoder intact rather than unrolled,
+// and the original module loads the precomputed parameter instead
+// of executing the loop at runtime.
+// Real-world: Iterative parameter transformations.
+
+// Encoder module should be generated with scf.for hoisted.
+// CHECK: module @encoder
+// CHECK: util.func public @__encode_parameter_detect_target
+// CHECK: util.func public @__encode_parameter_indices_all
+// CHECK: util.func public @__encode_parameter_steps_all
+// CHECK: util.func public @__encode_parameters_all
+
+// Encoder should contain the scf.for loop with result scattered to parameter.
+// CHECK: %[[IMPORT_TP:.+]] = stream.timepoint.import {{.+}} %arg1 : (!hal.fence) => !stream.timepoint +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK: %[[PACK_SIZE:.+]]:2 = stream.resource.pack {{.+}} slices({ +// CHECK-NEXT: [0, 0] = %[[C1024]] +// CHECK-NEXT: }) : index +// CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized {{.+}} await(%[[IMPORT_TP]]) => !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[ALLOCA_READY:.+]] = stream.timepoint.await %[[ALLOCA_TP]] => %[[ALLOCA]] : !stream.resource{%[[PACK_SIZE]]#0} +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[C0_I64:.+]] = arith.constant 0 : i64 +// CHECK: %[[PARAM_RESOURCE:.+]], %[[PARAM_TP:.+]] = stream.async.parameter.load {{.+}} await(%[[IMPORT_TP]]) "model"::"iterative_param"[%[[C0_I64]]] : !stream.resource{%[[C1024]]} => !stream.timepoint +// CHECK: %[[INPUT:.+]] = stream.timepoint.await %[[PARAM_TP]] => %[[PARAM_RESOURCE]] : !stream.resource{%[[C1024]]} +// CHECK: %[[LOOP_RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG:.+]] = %[[INPUT]]) -> (!stream.resource) { +// CHECK: %[[C100:.+]] = arith.constant 100 : i32 +// CHECK: %[[FILLED:.+]] = stream.async.fill %[[C100]], %[[ARG]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[ARG]] as !stream.resource{%[[C1024]]} +// CHECK: scf.yield %[[FILLED]] : !stream.resource +// CHECK: } +// CHECK: %[[UPDATE_END:.+]] = arith.addi %[[PACK_SIZE]]#1, %[[C1024]] : index +// CHECK: %[[UPDATED:.+]] = stream.async.update {{.+}} %[[LOOP_RESULT]], %[[ALLOCA_READY]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]]] : !stream.resource{%[[C1024]]} -> %[[ALLOCA_READY]] as !stream.resource{%[[PACK_SIZE]]#0} +// CHECK: %[[BARRIER_RESULT:.+]], %[[BARRIER_TP:.+]] = stream.timepoint.barrier {{.+}} %[[UPDATED]] : 
!stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[SCATTER_RESULT:.+]], %[[SCATTER_TP:.+]] = stream.async.parameter.scatter {{.+}} await(%[[BARRIER_TP]]) { +// CHECK-NEXT: %[[BARRIER_RESULT]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]] for %[[C1024]]] : !stream.resource{%[[PACK_SIZE]]#0} -> ""::"parameter0"[%[[C0_I64]]] +// CHECK-NEXT: } : !stream.resource => !stream.timepoint +// CHECK: %[[JOIN_TP:.+]] = stream.timepoint.join max(%[[SCATTER_TP]]) => !stream.timepoint +// CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca {{.+}} await(%[[JOIN_TP]]) => %[[SCATTER_RESULT]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: stream.timepoint.chain_external {{.+}} %[[DEALLOCA_TP]] => (%arg2 : !hal.fence) + +// Original module should have parameter load instead of scf.for. +// CHECK-LABEL: util.global private @scf_for_fixed_bounds +util.global private @scf_for_fixed_bounds : !stream.resource + +util.initializer { + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"iterative_param"> : vector<1024xi8> + + // Fixed-bound loop that could be unrolled. + // CHECK-NOT: scf.for + // CHECK-NOT: stream.async.fill + %result = scf.for %i = %c0 to %c3 step %c1 + iter_args(%arg = %param) -> (!stream.resource) { + // Apply transformation in each iteration. + %c100_i32 = arith.constant 100 : i32 + %processed = stream.async.fill %c100_i32, %arg[%c0 to %c256 for %c256] : + i32 -> %arg as !stream.resource{%c1024} + scf.yield %processed : !stream.resource + } + + // Original module loads from parameter instead of executing loop. 
+ // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @scf_for_fixed_bounds + util.global.store %result, @scf_for_fixed_bounds : !stream.resource + util.return +} + +// ----- + +// Tests scf.if conditional with compile-time constant condition. +// Should hoist the taken branch if condition is constant. +// Real-world: Conditional initialization for specific target. + +// Encoder module should be generated with scf.if hoisted. +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// Encoder should contain the scf.if conditional with result scattered to parameter. +// CHECK: %[[IMPORT_TP:.+]] = stream.timepoint.import {{.+}} %arg1 : (!hal.fence) => !stream.timepoint +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK: %[[PACK_SIZE:.+]]:2 = stream.resource.pack {{.+}} slices({ +// CHECK-NEXT: [0, 0] = %[[C1024]] +// CHECK-NEXT: }) : index +// CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized {{.+}} await(%[[IMPORT_TP]]) => !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[ALLOCA_READY:.+]] = stream.timepoint.await %[[ALLOCA_TP]] => %[[ALLOCA]] : !stream.resource{%[[PACK_SIZE]]#0} +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index +// CHECK-DAG: %[[C0_I64:.+]] = arith.constant 0 : i64 +// CHECK: %[[PARAM_RESOURCE:.+]], %[[PARAM_TP:.+]] = stream.async.parameter.load {{.+}} await(%[[IMPORT_TP]]) "model"::"conditional_param"[%[[C0_I64]]] : !stream.resource{%[[C1024]]} => !stream.timepoint +// CHECK: %[[INPUT:.+]] = stream.timepoint.await %[[PARAM_TP]] => %[[PARAM_RESOURCE]] : !stream.resource{%[[C1024]]} +// 
CHECK: %[[IF_RESULT:.+]] = scf.if %[[TRUE]] -> (!stream.resource) { +// CHECK: %[[C42:.+]] = arith.constant 42 : i32 +// CHECK: %[[FILLED:.+]] = stream.async.fill %[[C42]], %[[INPUT]][%[[C0]] to %[[C256]] for %[[C256]]] : i32 -> %[[INPUT]] as !stream.resource{%[[C1024]]} +// CHECK: scf.yield %[[FILLED]] : !stream.resource +// CHECK: } else { +// CHECK: scf.yield %[[INPUT]] : !stream.resource +// CHECK: } +// CHECK: %[[UPDATE_END:.+]] = arith.addi %[[PACK_SIZE]]#1, %[[C1024]] : index +// CHECK: %[[UPDATED:.+]] = stream.async.update {{.+}} %[[IF_RESULT]], %[[ALLOCA_READY]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]]] : !stream.resource{%[[C1024]]} -> %[[ALLOCA_READY]] as !stream.resource{%[[PACK_SIZE]]#0} +// CHECK: %[[BARRIER_RESULT:.+]], %[[BARRIER_TP:.+]] = stream.timepoint.barrier {{.+}} %[[UPDATED]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: %[[SCATTER_RESULT:.+]], %[[SCATTER_TP:.+]] = stream.async.parameter.scatter {{.+}} await(%[[BARRIER_TP]]) { +// CHECK-NEXT: %[[BARRIER_RESULT]][%[[PACK_SIZE]]#1 to %[[UPDATE_END]] for %[[C1024]]] : !stream.resource{%[[PACK_SIZE]]#0} -> ""::"parameter0"[%[[C0_I64]]] +// CHECK-NEXT: } : !stream.resource => !stream.timepoint +// CHECK: %[[JOIN_TP:.+]] = stream.timepoint.join max(%[[SCATTER_TP]]) => !stream.timepoint +// CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca {{.+}} await(%[[JOIN_TP]]) => %[[SCATTER_RESULT]] : !stream.resource{%[[PACK_SIZE]]#0} => !stream.timepoint +// CHECK: stream.timepoint.chain_external {{.+}} %[[DEALLOCA_TP]] => (%arg2 : !hal.fence) + +// Original module should have parameter load instead of scf.if. 
+// CHECK-LABEL: util.global private @scf_if_constant_condition +util.global private @scf_if_constant_condition : !stream.resource + +util.initializer { + %true = arith.constant true + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"conditional_param"> : vector<1024xi8> + + // Conditional with compile-time constant. + // CHECK-NOT: scf.if + // CHECK-NOT: stream.async.fill + %result = scf.if %true -> (!stream.resource) { + // True branch - should be taken. + %c42_i32 = arith.constant 42 : i32 + %processed = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + scf.yield %processed : !stream.resource + } else { + // False branch - should be eliminated. + scf.yield %param : !stream.resource + } + + // Original module loads from parameter instead of executing conditional. + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @scf_if_constant_condition + util.global.store %result, @scf_if_constant_condition : !stream.resource + util.return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multiple Outputs from Single Parameter +//===----------------------------------------------------------------------===// + +// Tests single parameter producing multiple transformed outputs. +// Should hoist both transformations, outputs packed. +// Real-world: Different quantization formats for different layers. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @single_param_multi_output_a : !stream.resource +util.global private @single_param_multi_output_a : !stream.resource +// CHECK-DAG: util.global private @single_param_multi_output_b : !stream.resource +util.global private @single_param_multi_output_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c1024 = arith.constant 1024 : index + + // Single parameter input + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"shared_param"> : vector<1024xi8> + + // First transformation + %c100_i32 = arith.constant 100 : i32 + // CHECK-NOT: stream.async.fill + %output_a = stream.async.fill %c100_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // Second transformation + %c200_i32 = arith.constant 200 : i32 + %output_b = stream.async.fill %c200_i32, %param[%c512 to %c768 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // Both outputs are packed into a single parameter, loaded twice and extracted via subviews. 
+ // CHECK-DAG: %[[PACKED_SIZE:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[SUBVIEW_OFFSET_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[SUBVIEW_SIZE:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[SUBVIEW_OFFSET_0]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_A]], @single_param_multi_output_a : !stream.resource + util.global.store %output_a, @single_param_multi_output_a : !stream.resource + // CHECK-DAG: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[SUBVIEW_SIZE]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_B]], @single_param_multi_output_b : !stream.resource + util.global.store %output_b, @single_param_multi_output_b : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Device Specialization & Affinity +//===----------------------------------------------------------------------===// + +// Tests parameter with affinity annotation. +// Should hoist with affinity preserved in encoder. +// Real-world: GPU-specific parameter transformation. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_affinity : !stream.resource +util.global private @param_with_affinity : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter with device affinity + %param = stream.async.constant on(#hal.device.affinity<@device_0>) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"gpu_param"> : vector<1024xi8> + + // Transformation maintaining affinity + %c42_i32 = arith.constant 42 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill on(#hal.device.affinity<@device_0>) %c42_i32, + %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_affinity : !stream.resource + util.global.store %result, @param_with_affinity : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + + +// ----- + +//===----------------------------------------------------------------------===// +// Stress Tests +//===----------------------------------------------------------------------===// + +// Tests very small parameter (1 byte). +// Should handle minimum size parameters. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @minimum_size_param : !stream.resource +util.global private @minimum_size_param : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1 = arith.constant 1 : index + + // Tiny 1-byte parameter + %param = stream.async.constant : !stream.resource{%c1} = + #stream.parameter.named<"model"::"tiny"> : vector<1xi8> + + // Even tiny transform should work + %c0 = arith.constant 0 : index + %c42_i32 = arith.constant 42 : i32 + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c1 for %c1] : + i32 -> %param as !stream.resource{%c1} + + // CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C64]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C64]]} -> !stream.resource{%[[C1]]} + // CHECK: util.global.store %[[RESULT]], @minimum_size_param : !stream.resource + util.global.store %filled, @minimum_size_param : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests deep expression DAG (multiple levels of operations). +// Should handle deep computation chains. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @deep_expression_dag : !stream.resource +util.global private @deep_expression_dag : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Load parameter + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"deep_param"> : vector<1024xi8> + + // Deep chain of transformations + // CHECK-NOT: stream.async.fill + %c1_i32 = arith.constant 1 : i32 + %stage1 = stream.async.fill %c1_i32, %param[%c0 to %c64 for %c64] : + i32 -> %param as !stream.resource{%c1024} + + %c2_i32 = arith.constant 2 : i32 + %stage2 = stream.async.fill %c2_i32, %stage1[%c64 to %c128 for %c64] : + i32 -> %stage1 as !stream.resource{%c1024} + + %c3_i32 = arith.constant 3 : i32 + %stage3 = stream.async.fill %c3_i32, %stage2[%c128 to %c192 for %c64] : + i32 -> %stage2 as !stream.resource{%c1024} + + %c4_i32 = arith.constant 4 : i32 + %stage4 = stream.async.fill %c4_i32, %stage3[%c192 to %c256 for %c64] : + i32 -> %stage3 as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @deep_expression_dag : !stream.resource + util.global.store %stage4, @deep_expression_dag : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Advanced Growth Factor Tests 
+//===----------------------------------------------------------------------===// + +// Tests exact 1.2x growth threshold - should pass. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @exact_growth_threshold : !stream.resource +util.global private @exact_growth_threshold : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1000 = arith.constant 1000 : index + %c1200 = arith.constant 1200 : index // Exactly 1.2x + + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"exact_threshold"> : vector<1000xi8> + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // CHECK-NOT: stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1200} + // CHECK-NOT: stream.async.update + %result = stream.async.update %param, %padded[%c0 to %c1000] : + !stream.resource{%c1000} -> %padded as !stream.resource{%c1200} + + // CHECK-DAG: %[[C1216:.+]] = arith.constant 1216 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1200:.+]] = arith.constant 1200 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C1216]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C1216]]} -> !stream.resource{%[[C1200]]} + // CHECK: util.global.store %[[RESULT]], @exact_growth_threshold : !stream.resource + util.global.store %result, @exact_growth_threshold : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests just over 1.2x growth (1.21x) - should reject. 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @over_growth_threshold : !stream.resource +util.global private @over_growth_threshold : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK-DAG: %[[C1000:.+]] = arith.constant 1000 : index + %c1000 = arith.constant 1000 : index + // CHECK-DAG: %[[C1210:.+]] = arith.constant 1210 : index + %c1210 = arith.constant 1210 : index // 1.21x - over threshold + + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C1000]]} = #stream.parameter.named<"model"::"over_threshold"> + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"over_threshold"> : vector<1000xi8> + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[PADDED:.+]] = stream.async.splat + %padded = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1210} + // CHECK: %[[RESULT:.+]] = stream.async.update %[[PARAM]], %[[PADDED]] + %result = stream.async.update %param, %padded[%c0 to %c1000] : + !stream.resource{%c1000} -> %padded as !stream.resource{%c1210} + + // CHECK: util.global.store %[[RESULT]], @over_growth_threshold : !stream.resource + util.global.store %result, @over_growth_threshold : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Complex Data Flow Patterns +//===----------------------------------------------------------------------===// + +// Tests parameter used by multiple operations (wide DAG). +// Single parameter with many consumers. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @wide_expression_dag_a : !stream.resource +util.global private @wide_expression_dag_a : !stream.resource +// CHECK-DAG: util.global private @wide_expression_dag_b : !stream.resource +util.global private @wide_expression_dag_b : !stream.resource +// CHECK-DAG: util.global private @wide_expression_dag_c : !stream.resource +util.global private @wide_expression_dag_c : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c100 = arith.constant 100 : index + %c200 = arith.constant 200 : index + %c300 = arith.constant 300 : index + %c1024 = arith.constant 1024 : index + + // Single parameter used by many operations + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"wide_param"> : vector<1024xi8> + + // Many transformations using the same parameter + %c10_i32 = arith.constant 10 : i32 + // CHECK-NOT: stream.async.fill + %out_a = stream.async.fill %c10_i32, %param[%c0 to %c100 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + %c20_i32 = arith.constant 20 : i32 + %out_b = stream.async.fill %c20_i32, %param[%c100 to %c200 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + %c30_i32 = arith.constant 30 : i32 + %out_c = stream.async.fill %c30_i32, %param[%c200 to %c300 for %c100] : + i32 -> %param as !stream.resource{%c1024} + + // All outputs packed into a single parameter and extracted via subviews. 
+ // CHECK-DAG: %[[PACKED_SIZE:.+]] = arith.constant 3072 : index + // CHECK-DAG: %[[OFFSET_0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[SUBVIEW_SIZE:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[OFFSET_2048:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[OFFSET_0]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_A]], @wide_expression_dag_a : !stream.resource + util.global.store %out_a, @wide_expression_dag_a : !stream.resource + // CHECK-DAG: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[SUBVIEW_SIZE]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_B]], @wide_expression_dag_b : !stream.resource + util.global.store %out_b, @wide_expression_dag_b : !stream.resource + // CHECK-DAG: %[[PARAM_C:.+]] = stream.async.constant : !stream.resource{%[[PACKED_SIZE]]} = #stream.parameter.named<""::"parameter0"> + // CHECK-DAG: %[[RESULT_C:.+]] = stream.resource.subview %[[PARAM_C]][%[[OFFSET_2048]]] : !stream.resource{%[[PACKED_SIZE]]} -> !stream.resource{%[[SUBVIEW_SIZE]]} + // CHECK-DAG: util.global.store %[[RESULT_C]], @wide_expression_dag_c : !stream.resource + util.global.store %out_c, @wide_expression_dag_c : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation with clone operation. +// Clone operations should be handled (may have preferCloneToConsumers). 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_clone : !stream.resource +util.global private @param_with_clone : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"clone_param"> : vector<1024xi8> + + // Clone operation (might have preferCloneToConsumers) + // CHECK-NOT: stream.async.clone + %cloned = stream.async.clone %param : !stream.resource{%c1024} -> + !stream.resource{%c1024} + + // Transform the clone + %c99_i32 = arith.constant 99 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill %c99_i32, %cloned[%c0 to %c256 for %c256] : + i32 -> %cloned as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_clone : !stream.resource + util.global.store %result, @param_with_clone : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter transformation with clone at END of expression. +// This tests findProducedValue skipping past final clone to find producer. +// Pattern: param → clone(to *) → dispatch → clone(to constant) → store. 
+ +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +stream.executable private @dispatch_for_clone_test { + stream.executable.export public @fill +} + +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_trailing_clone : !stream.resource +util.global private @param_with_trailing_clone : !stream.resource + +// The original ops (clone → dispatch → clone) should all be hoisted to encoder. +// CHECK: util.initializer { +// CHECK-NOT: stream.async.clone +// CHECK-NOT: stream.async.dispatch +// CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> +// CHECK: util.global.store %[[PARAM]], @param_with_trailing_clone +// CHECK: util.return +// CHECK: } +util.initializer { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"trailing_clone_param"> : vector<1024xi8> + + // Clone to unknown lifetime for dispatch input. + %for_dispatch = stream.async.clone %param : + !stream.resource{%c1024} -> !stream.resource<*>{%c1024} + + // Dispatch transforms the parameter. + %dispatched = stream.async.dispatch @dispatch_for_clone_test::@fill[%c1, %c1, %c1](%for_dispatch[%c0 to %c1024 for %c1024]) : + (!stream.resource<*>{%c1024}) -> !stream.resource<*>{%c1024} + + // Clone at END of expression back to constant lifetime. + // findProducedValue must skip this to find the dispatch as the producer. 
+ %result = stream.async.clone %dispatched : + !stream.resource<*>{%c1024} -> !stream.resource{%c1024} + + util.global.store %result, @param_with_trailing_clone : !stream.resource + util.return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Transfer Operations +//===----------------------------------------------------------------------===// + +// Tests parameter with transfer operations. +// Transfers should be handled correctly. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_transfer : !stream.resource +util.global private @param_transfer : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"transfer_param"> : vector<1024xi8> + + // Transfer to different lifetime (if needed) + // CHECK-NOT: stream.async.transfer + %transferred = stream.async.transfer %param : + !stream.resource{%c1024} -> !stream.resource{%c1024} + + // Transform transferred value + %c88_i32 = arith.constant 88 : i32 + // CHECK-NOT: stream.async.fill + %result = stream.async.fill %c88_i32, %transferred[%c0 to %c256 for %c256] : + i32 -> %transferred as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_transfer : !stream.resource + util.global.store %result, @param_transfer : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests parameter copying (combining two parameters). 
+// Copy operations should be hoisted to encoder. +// Real-world: Combining parameter shards into single buffer. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK: util.global private @param_copy_combine : !stream.resource +util.global private @param_copy_combine : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c0 = arith.constant 0 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + // Load two parameters that will be combined. + %param1 = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"shard0"> : vector<512xi8> + %param2 = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"shard1"> : vector<512xi8> + + // Create destination buffer. + %c0_i32 = arith.constant 0 : i32 + // CHECK-NOT: stream.async.splat + %combined = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // Copy first parameter. + // CHECK-NOT: stream.async.copy + %with_first = stream.async.copy %param1[%c0 to %c512], %combined[%c0 to %c512], %c512 : + !stream.resource{%c512} -> %combined as !stream.resource{%c1024} + + // Copy second parameter. 
+ // CHECK-NOT: stream.async.copy + %result = stream.async.copy %param2[%c0 to %c512], %with_first[%c512 to %c1024], %c512 : + !stream.resource{%c512} -> %with_first as !stream.resource{%c1024} + + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_copy_combine : !stream.resource + util.global.store %result, @param_copy_combine : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multiple Initializers +//===----------------------------------------------------------------------===// + +// Tests multiple initializers in same module. +// All should be processed independently. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all +// CHECK: util.func public @__encode_parameters_all + +// CHECK-LABEL: module { +// CHECK-DAG: util.global private @multi_init_a : !stream.resource +util.global private @multi_init_a : !stream.resource +// CHECK-DAG: util.global private @multi_init_b : !stream.resource +util.global private @multi_init_b : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %param_a = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"init_a"> : vector<512xi8> + + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c11_i32 = arith.constant 11 : i32 + // CHECK-NOT: stream.async.fill + %result_a = stream.async.fill %c11_i32, %param_a[%c0 to %c256 for %c256] : + i32 -> %param_a as !stream.resource{%c512} + + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index + // CHECK: 
%[[PARAM_A:.+]] = stream.async.constant : !stream.resource{%[[C1024]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT_A:.+]] = stream.resource.subview %[[PARAM_A]][%[[C0]]] : !stream.resource{%[[C1024]]} -> !stream.resource{%[[C512]]} + // CHECK: util.global.store %[[RESULT_A]], @multi_init_a : !stream.resource + util.global.store %result_a, @multi_init_a : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// CHECK: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %param_b = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"init_b"> : vector<512xi8> + + %c256 = arith.constant 256 : index + %c512_0 = arith.constant 512 : index + %c22_i32 = arith.constant 22 : i32 + // CHECK-NOT: stream.async.fill + %result_b = stream.async.fill %c22_i32, %param_b[%c256 to %c512_0 for %c256] : + i32 -> %param_b as !stream.resource{%c512} + + // CHECK-DAG: %[[C1024_0:.+]] = arith.constant 1024 : index + // CHECK-DAG: %[[C512_0:.+]] = arith.constant 512 : index + // CHECK: %[[PARAM_B:.+]] = stream.async.constant : !stream.resource{%[[C1024_0]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[RESULT_B:.+]] = stream.resource.subview %[[PARAM_B]][%[[C512_0]]] : !stream.resource{%[[C1024_0]]} -> !stream.resource{%[[C512_0]]} + // CHECK: util.global.store %[[RESULT_B]], @multi_init_b : !stream.resource + util.global.store %result_b, @multi_init_b : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Resource Lifetime Tests +//===----------------------------------------------------------------------===// + +// Tests non-constant resource lifetime (should skip). 
+ +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +// CHECK: util.global private @non_constant_lifetime : !stream.resource +util.global private @non_constant_lifetime : !stream.resource + +// CHECK: util.initializer { +util.initializer { + // CHECK: %[[C1024:.+]] = arith.constant 1024 : index + %c1024 = arith.constant 1024 : index + + // Transient resource (not constant) - should skip + // CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[TRANSIENT:.+]] = stream.async.splat %[[C0_I32]] : i32 -> !stream.resource{%[[C1024]]} + %transient = stream.async.splat %c0_i32 : i32 -> !stream.resource{%c1024} + + // CHECK: util.global.store %[[TRANSIENT]], @non_constant_lifetime : !stream.resource + util.global.store %transient, @non_constant_lifetime : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Mode Testing: Consolidate vs Overlay +//===----------------------------------------------------------------------===// + +// Tests pass-through parameter in consolidate mode (default). +// A parameter loaded and stored directly with no transformation should be +// included in the encoder output when in consolidate mode. + +// CHECK: module @encoder +// CHECK: util.func public @__encode_parameter_detect_target +// CHECK: util.func public @__encode_parameter_indices_all +// CHECK: util.func public @__encode_parameter_steps_all + +// CHECK-LABEL: module { +// CHECK: util.global private @passthrough_consolidate : !stream.resource +util.global private @passthrough_consolidate : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + + // Load parameter directly without transformation. 
+ %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"passthrough_param"> : vector<1024xi8> + + // Store directly - this is a pass-through (no transformation). + // In consolidate mode, this should be included in encoder output. + // CHECK: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @passthrough_consolidate + util.global.store %param, @passthrough_consolidate : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests pass-through parameter in overlay mode. +// Same as previous test but with overlay mode enabled. +// The parameter should NOT be included in the encoder output since it's +// unmodified (includeUnmodified=false in overlay mode). + +// Anchor to this specific test's main module +// OVERLAY-LABEL: util.global private @passthrough_overlay : !stream.resource +util.global private @passthrough_overlay : !stream.resource + +// OVERLAY: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + + // Load parameter directly without transformation. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"passthrough_param_overlay"> : vector<1024xi8> + + // Store directly - pass-through with no transformation. + // In overlay mode, this should NOT be in encoder output. + // The original parameter load should remain unchanged. + // OVERLAY: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<"model"::"passthrough_param_overlay"> + // OVERLAY: util.global.store %[[PARAM]], @passthrough_overlay + util.global.store %param, @passthrough_overlay : !stream.resource + // OVERLAY: util.return + util.return + // OVERLAY: } +} + +// ----- + +// Tests mixed parameters in consolidate mode. +// One parameter with transformation, one pass-through. +// Consolidate mode should include both in encoder output. 
+ +// CHECK-LABEL: util.global private @mixed_transformed : !stream.resource +util.global private @mixed_transformed : !stream.resource +// CHECK: util.global private @mixed_passthrough : !stream.resource +util.global private @mixed_passthrough : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter 1: Transformed with fill operation. + %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param1"> : vector<1024xi8> + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Parameter 2: Pass-through (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param2"> : vector<1024xi8> + + // In consolidate mode, both should be loaded from encoder output. 
+ // CHECK-DAG: %[[C2048:.+]] = arith.constant 2048 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index + // CHECK: %[[PARAM:.+]] = stream.async.constant : !stream.resource{%[[C2048]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[SUBVIEW1:.+]] = stream.resource.subview %[[PARAM]][%[[C0]]] : !stream.resource{%[[C2048]]} -> !stream.resource{%[[C1024]]} + // CHECK: util.global.store %[[SUBVIEW1]], @mixed_transformed + util.global.store %filled, @mixed_transformed : !stream.resource + + // CHECK: %[[PARAM_0:.+]] = stream.async.constant : !stream.resource{%[[C2048]]} = #stream.parameter.named<""::"parameter0"> + // CHECK: %[[SUBVIEW2:.+]] = stream.resource.subview %[[PARAM_0]][%[[C1024]]] : !stream.resource{%[[C2048]]} -> !stream.resource{%[[C1024]]} + // CHECK: util.global.store %[[SUBVIEW2]], @mixed_passthrough + util.global.store %param2, @mixed_passthrough : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests mixed parameters in overlay mode. +// One parameter with transformation, one pass-through. +// Overlay mode should only include the transformed parameter. + +// Anchor to the main module's first global to scope checks to this section +// OVERLAY-MIXED-LABEL: util.global private @mixed_transformed_overlay : !stream.resource +util.global private @mixed_transformed_overlay : !stream.resource +// OVERLAY-MIXED: util.global private @mixed_passthrough_overlay : !stream.resource +util.global private @mixed_passthrough_overlay : !stream.resource + +// OVERLAY-MIXED: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter 1: Transformed with fill operation. 
+ %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param1_overlay"> : vector<1024xi8> + // OVERLAY-MIXED-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Parameter 2: Pass-through (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"mixed_param2_overlay"> : vector<1024xi8> + + // Overlay mode: transformed parameter from encoder, pass-through from original. + // Parameters can be loaded in any order (SSA), use DAG to allow flexibility. + // OVERLAY-MIXED-DAG: %{{.+}} = stream.async.constant {{.+}} #stream.parameter.named<"model"::"mixed_param2_overlay"> + // OVERLAY-MIXED-DAG: %{{.+}} = stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + + // Stores should happen in this order. + // OVERLAY-MIXED: util.global.store %{{.+}}, @mixed_transformed_overlay + util.global.store %filled, @mixed_transformed_overlay : !stream.resource + + // OVERLAY-MIXED: util.global.store %{{.+}}, @mixed_passthrough_overlay + util.global.store %param2, @mixed_passthrough_overlay : !stream.resource + // OVERLAY-MIXED: util.return + util.return + // OVERLAY-MIXED: } +} + +// ----- + +// Tests side-by-side mode comparison. +// Same input tested with both consolidate and overlay modes using different +// check prefixes to verify behavioral differences. + +// Anchor to this test's unique globals. 
+// COMPARE-CONSOLIDATE-LABEL: util.global private @compare_transformed : !stream.resource +// COMPARE-OVERLAY-LABEL: util.global private @compare_transformed : !stream.resource +util.global private @compare_transformed : !stream.resource +// COMPARE-CONSOLIDATE: util.global private @compare_passthrough : !stream.resource +// COMPARE-OVERLAY: util.global private @compare_passthrough : !stream.resource +util.global private @compare_passthrough : !stream.resource + +// COMPARE-CONSOLIDATE: util.initializer { +// COMPARE-OVERLAY: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Transformed parameter. + %param1 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"compare_param1"> : vector<1024xi8> + // COMPARE-CONSOLIDATE-NOT: stream.async.fill + // COMPARE-OVERLAY-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param1[%c0 to %c256 for %c256] : i32 -> %param1 as !stream.resource{%c1024} + + // Pass-through parameter. + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"compare_param2"> : vector<1024xi8> + + // Consolidate: Both from encoder output, packed into single parameter0, then subviewed. + // Just verify key operations exist without strict ordering. + // COMPARE-CONSOLIDATE-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // COMPARE-CONSOLIDATE-DAG: stream.resource.subview + // COMPARE-CONSOLIDATE-DAG: util.global.store %{{.+}}, @compare_transformed + util.global.store %filled, @compare_transformed : !stream.resource + + // COMPARE-CONSOLIDATE-DAG: util.global.store %{{.+}}, @compare_passthrough + + // Overlay: Transformed from encoder (parameter0), pass-through from original. 
+ // COMPARE-OVERLAY-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // COMPARE-OVERLAY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"compare_param2"> + // COMPARE-OVERLAY-DAG: util.global.store %{{.+}}, @compare_transformed + + // COMPARE-OVERLAY-DAG: util.global.store %{{.+}}, @compare_passthrough + util.global.store %param2, @compare_passthrough : !stream.resource + + // COMPARE-CONSOLIDATE: util.return + // COMPARE-OVERLAY: util.return + util.return + // COMPARE-CONSOLIDATE: } + // COMPARE-OVERLAY: } +} + +// ----- + +// Tests custom output scope. +// Verifies that the encoder uses a custom scope name instead of default "encoded". + +// Anchor to this test's unique global. +// SCOPE-LABEL: util.global private @custom_scope_global : !stream.resource +util.global private @custom_scope_global : !stream.resource + +// SCOPE: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Parameter with transformation. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"custom_scope_param"> : vector<1024xi8> + // SCOPE-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : i32 -> %param as !stream.resource{%c1024} + + // Should load from custom scope "my_custom_scope" instead of default "encoded". + // SCOPE: %[[PARAM:.+]] = stream.async.constant {{.+}} #stream.parameter.named<"my_custom_scope"::"parameter0"> + // SCOPE: util.global.store %[[PARAM]], @custom_scope_global + util.global.store %filled, @custom_scope_global : !stream.resource + // SCOPE: util.return + util.return + // SCOPE: } +} + +// ----- + +// Tests growth factor threshold with increased limit. +// A parameter that grows 1.8x should be rejected with default threshold (1.2x) +// but accepted with custom threshold (2.0x). 
+ +// Anchor to this test's unique global. +// GROWTH2-LABEL: util.global private @growth_factor_test : !stream.resource +util.global private @growth_factor_test : !stream.resource + +// GROWTH2: util.initializer { +util.initializer { + %c42_i32 = arith.constant 42 : i32 + %c0 = arith.constant 0 : index + %c1000 = arith.constant 1000 : index + %c1800 = arith.constant 1800 : index + + // Parameter that grows from 1000 bytes (input) to 1800 bytes (after fill/pad). + // 1.8x growth exceeds default 1.2x threshold but passes with 2.0x threshold. + %param = stream.async.constant : !stream.resource{%c1000} = + #stream.parameter.named<"model"::"growth_param"> : vector<1000xi8> + + // Fill operation that expands the parameter size (1000 -> 1800 bytes). + // GROWTH2-NOT: stream.async.fill + %expanded = stream.async.fill %c42_i32, %param[%c0 to %c1800 for %c1800] : i32 -> %param as !stream.resource{%c1800} + + // With growth factor 2.0, this should be hoisted (1.8x < 2.0). + // Verify transformation was hoisted: parameter loads from encoder output. + // GROWTH2-DAG: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // GROWTH2-DAG: util.global.store %{{.+}}, @growth_factor_test + util.global.store %expanded, @growth_factor_test : !stream.resource + // GROWTH2: util.return + util.return + // GROWTH2: } +} + +// ----- + +// Tests empty encoder module in overlay mode. +// When all parameters are pass-through (no transformations) and in overlay mode, +// no encoder module should be generated since there's nothing to encode. + +// Anchor to this test's unique global. 
+// EMPTY-LABEL: util.global private @empty_test_1 : !stream.resource +util.global private @empty_test_1 : !stream.resource +// EMPTY: util.global private @empty_test_2 : !stream.resource +util.global private @empty_test_2 : !stream.resource + +// EMPTY: util.initializer { +util.initializer { + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + // Pass-through parameter 1 (no transformation). + %param1 = stream.async.constant : !stream.resource{%c512} = + #stream.parameter.named<"model"::"empty_param1"> : vector<512xi8> + + // Pass-through parameter 2 (no transformation). + %param2 = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"empty_param2"> : vector<1024xi8> + + // Both should load from original parameters (no encoder output). + // EMPTY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"empty_param1"> + // EMPTY-DAG: stream.async.constant {{.+}} #stream.parameter.named<"model"::"empty_param2"> + // EMPTY-DAG: util.global.store %{{.+}}, @empty_test_1 + // EMPTY-DAG: util.global.store %{{.+}}, @empty_test_2 + util.global.store %param1, @empty_test_1 : !stream.resource + + util.global.store %param2, @empty_test_2 : !stream.resource + + // EMPTY: util.return + util.return + // EMPTY: } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multi-Block Slice Ordering Tests +//===----------------------------------------------------------------------===// +// These tests exercise the slice ordering logic when operations span multiple +// blocks or regions. The backward slice collection must maintain proper +// topological order even when captured values from nested regions are involved. + +// Tests that captured values from scf.if regions are handled correctly. +// This exercises the multi-root slice ordering logic where values defined +// outside an scf.if are used inside its regions. 
+ +stream.executable private @captured_dispatch { + stream.executable.export public @dispatch +} + +// CHECK-LABEL: util.global private @captured_value_if_ordering : !stream.resource +util.global private @captured_value_if_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Load parameter (will be in slice). + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"captured_param"> : vector<1024xi8> + + // Value defined outside scf.if but captured inside - this tests that + // the slice ordering handles captured values correctly. + %outside_value = arith.constant 42 : i32 + + // The scf.if captures %outside_value and %param from outside. + // When building the backward slice, we collect both the stored value + // and captured values. The ordering must ensure %outside_value's producer + // (arith.constant) comes before any op inside the region that uses it. + %cond = arith.constant true + // CHECK-NOT: scf.if + %result = scf.if %cond -> !stream.resource { + // Uses %outside_value (captured) and %param. + %filled = stream.async.fill %outside_value, %param[%c0 to %c1024 for %c1024] + : i32 -> %param as !stream.resource{%c1024} + scf.yield %filled : !stream.resource + } else { + scf.yield %param : !stream.resource + } + + // Encoder should transform this to load from encoded parameter. + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @captured_value_if_ordering + util.global.store %result, @captured_value_if_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests captured values with scf.for loop. +// Similar to the scf.if test but with loop-carried values. 
+ +stream.executable private @for_dispatch { + stream.executable.export public @dispatch +} + +// CHECK-LABEL: util.global private @captured_value_for_ordering : !stream.resource +util.global private @captured_value_for_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c256 = arith.constant 256 : index + + // Load parameter. + %param = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"for_param"> : vector<1024xi8> + + // Value captured by the loop body. + %fill_pattern = arith.constant 7 : i32 + + // Loop that captures %fill_pattern from outside. + // CHECK-NOT: scf.for + %result = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc = %param) -> !stream.resource { + // Uses captured %fill_pattern. + %offset = arith.muli %i, %c256 : index + %end = arith.addi %offset, %c256 : index + %filled = stream.async.fill %fill_pattern, %acc[%offset to %end for %c256] + : i32 -> %acc as !stream.resource{%c1024} + scf.yield %filled : !stream.resource + } + + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @captured_value_for_ordering + util.global.store %result, @captured_value_for_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests nested scf.if with dispatch that uses multiple captured values. +// This more complex case exercises ordering across multiple region levels. 
+ +stream.executable private @nested_dispatch { + stream.executable.export public @compute +} + +// CHECK-LABEL: util.global private @nested_captured_ordering : !stream.resource +util.global private @nested_captured_ordering : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c2048 = arith.constant 2048 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Load two parameters that will both be used inside nested regions. + %param_a = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"nested_a"> : vector<1024xi8> + %param_b = stream.async.constant : !stream.resource{%c1024} = + #stream.parameter.named<"model"::"nested_b"> : vector<1024xi8> + + // Dispatch using both parameters - creates slice with multiple inputs. + // CHECK-NOT: stream.async.dispatch + %combined = stream.async.dispatch @nested_dispatch::@compute[%c1, %c1, %c1]( + %param_a[%c0 to %c1024 for %c1024], + %param_b[%c0 to %c1024 for %c1024] + ) : (!stream.resource{%c1024}, !stream.resource{%c1024}) -> !stream.resource{%c2048} + + // CHECK: stream.async.constant {{.+}} #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %{{.+}}, @nested_captured_ordering + util.global.store %combined, @nested_captured_ordering : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests that an empty module with no parameter expressions runs cleanly. +// This verifies that when no output file is specified (the default) and +// no encoding work is found, the pass completes without errors. + +// CHECK-LABEL: module { +// CHECK-NOT: module @encoder +module { + // A simple global that doesn't involve any parameters. 
+ // CHECK: util.global private @no_params : i32 + util.global private @no_params : i32 + util.initializer { + // CHECK: %[[C42:.+]] = arith.constant 42 : i32 + %c42 = arith.constant 42 : i32 + // CHECK: util.global.store %[[C42]], @no_params + util.global.store %c42, @no_params : i32 + // CHECK: util.return + util.return + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// External Timepoint Synchronization Tests +//===----------------------------------------------------------------------===// + +// Tests that when a parameter load awaits on an external timepoint, the +// replacement async.constant also awaits on that timepoint. +// This exercises Source A of collectExternalTimepoints: external await +// timepoints from TimelineOpInterface ops in the expression. + +// CHECK: module @encoder +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_external_await : !stream.resource +util.global private @param_with_external_await : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c42_i32 = arith.constant 42 : i32 + + // External timeline op that produces a timepoint we must wait on. + // This is NOT part of the encoding expression (doesn't feed into the store). + // CHECK: %[[EXTERNAL_RESOURCE:.+]], %[[EXTERNAL_TP:.+]] = stream.test.timeline_op + %external_resource, %external_tp = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + + // Parameter load that awaits on the external timepoint. + // The expression starts here - this op and the fill below form the expression. + %param = stream.async.constant await(%external_tp) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"awaiting_param"> : vector<1024xi8> + + // Transform the parameter so it gets hoisted. 
+ // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // The replacement should await on the external timepoint. + // CHECK: %[[PARAM:.+]] = stream.async.constant await(%[[EXTERNAL_TP]]) + // CHECK-SAME: #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_external_await + util.global.store %filled, @param_with_external_await : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} + +// ----- + +// Tests that when a parameter load awaits on a joined timepoint from multiple +// external timeline ops, the replacement async.constant awaits on that same +// joined timepoint. This exercises the case where the join is in the expression +// slice but is not a resource contributor (it only produces a timepoint). + +// CHECK: module @encoder +// CHECK-LABEL: module { +// CHECK: util.global private @param_with_joined_external_timepoints : !stream.resource +util.global private @param_with_joined_external_timepoints : !stream.resource + +// CHECK: util.initializer { +util.initializer { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c42_i32 = arith.constant 42 : i32 + + // Two external timeline ops that produce timepoints we must wait on. + // Their resources are unused, so they're not resource contributors. + // CHECK-DAG: %[[EXT_R1:.+]], %[[EXT_TP1:.+]] = stream.test.timeline_op + %ext_r1, %ext_tp1 = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + // CHECK-DAG: %[[EXT_R2:.+]], %[[EXT_TP2:.+]] = stream.test.timeline_op + %ext_r2, %ext_tp2 = stream.test.timeline_op + with() : () -> !stream.resource{%c1024} => !stream.timepoint + + // Join the timepoints. The join is in the expression but doesn't contribute + // resources, so its result timepoint should be considered external. 
+ // CHECK: %[[JOINED_TP:.+]] = stream.timepoint.join max(%[[EXT_TP1]], %[[EXT_TP2]]) => !stream.timepoint + %joined_tp = stream.timepoint.join max(%ext_tp1, %ext_tp2) => !stream.timepoint + + // Parameter load that awaits on the joined timepoint. + %param = stream.async.constant await(%joined_tp) : + !stream.resource{%c1024} = + #stream.parameter.named<"model"::"joined_await_param"> : vector<1024xi8> + + // Transform the parameter so it gets hoisted. + // CHECK-NOT: stream.async.fill + %filled = stream.async.fill %c42_i32, %param[%c0 to %c256 for %c256] : + i32 -> %param as !stream.resource{%c1024} + + // The replacement should await on the same joined timepoint. + // CHECK: %[[PARAM:.+]] = stream.async.constant await(%[[JOINED_TP]]) + // CHECK-SAME: #stream.parameter.named<""::"parameter0"> + // CHECK: util.global.store %[[PARAM]], @param_with_joined_external_timepoints + util.global.store %filled, @param_with_joined_external_timepoints : !stream.resource + // CHECK: util.return + util.return + // CHECK: } +} diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 698b983d160e..9f6b6f4adb2e 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -2221,6 +2221,58 @@ void BufferConstantOp::getAsmResultNames( setNameFn(getResult(), getName().value_or("buffer_cst")); } +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + Attribute value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", value); +} + +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + StringRef value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", builder.getStringAttr(value)); +} + +void BufferConstantOp::build(OpBuilder &builder, OperationState &state, + ArrayRef value) { + state.addTypes({builder.getType()}); + state.addAttribute("value", + 
DenseIntElementsAttr::get( + VectorType::get(static_cast(value.size()), + builder.getI8Type()), + value)); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + Attribute value) { + if (!value) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + StringRef value) { + if (value.empty()) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + +// static +Value BufferConstantOp::createOrNull(OpBuilder &builder, Location loc, + ArrayRef value) { + if (value.empty()) { + auto bufferType = builder.getType(); + return IREE::Util::NullOp::create(builder, loc, bufferType).getResult(); + } + return IREE::Util::BufferConstantOp::create(builder, loc, value); +} + LogicalResult BufferConstantOp::verify() { if (!isa(getValue())) { return emitOpError("unsupported non-serializable constant attribute type"); diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 208d2e958cc0..abb4c0667ecb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -1418,6 +1418,20 @@ def Util_BufferConstantOp : Util_PureOp<"buffer.constant", [ ($name^)? attr-dict `:` type($result) `=` $value }]; + let builders = [ + OpBuilder<(ins "Attribute":$value)>, + OpBuilder<(ins "StringRef":$value)>, + OpBuilder<(ins "ArrayRef":$value)>, + ]; + + let extraClassDeclaration = [{ + // Returns a new buffer op with the given contents unless they are + // nullptr/empty in which case returns util.null. 
+ static Value createOrNull(OpBuilder &builder, Location loc, Attribute value); + static Value createOrNull(OpBuilder &builder, Location loc, StringRef value); + static Value createOrNull(OpBuilder &builder, Location loc, ArrayRef value); + }]; + let hasVerifier = 1; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td index 8a158100c80e..add5dbbef6df 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.td @@ -58,6 +58,9 @@ def Util_ListType : Util_TypeDef<"List"> { ); let builders = [ + TypeBuilder<(ins), [{ + return $_get($_ctxt, IREE::Util::VariantType::get($_ctxt)); + }]>, TypeBuilderWithInferredContext<(ins "Type":$element_type ), [{ diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel index 4bb133040940..57b34e43879d 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel @@ -52,6 +52,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:LinalgOpsIncGen", "@llvm-project//mlir:LinalgStructuredOpsIncGen", "@llvm-project//mlir:MLProgramDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ValueBoundsOpInterface", diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt index a183c5539427..a52be8559525 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt @@ -39,6 +39,7 @@ iree_cc_library( MLIRLinalgOpsIncGenLib MLIRLinalgStructuredOpsIncGenLib MLIRMLProgramDialect + MLIRMemRefDialect MLIRSCFDialect MLIRTensorDialect MLIRValueBoundsOpInterface diff --git 
a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp index 3f0be78a7a1f..762b03a09c28 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp @@ -80,6 +80,10 @@ struct OptionalOpAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } }; struct FlowBarrierTargetAffinityAttrExternalModel @@ -108,6 +112,8 @@ struct FlowBarrierTargetAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("target"); } }; struct FlowTransferTargetAffinityAttrExternalModel @@ -132,6 +138,8 @@ struct FlowTransferTargetAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("target"); } }; template @@ -164,6 +172,8 @@ struct HALTensorAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { op->removeAttr("affinity"); } }; template @@ -197,6 +207,10 @@ struct GlobalOpAffinityAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } }; template @@ -227,6 +241,10 @@ struct AffinityOpAttrExternalModel IREE::Stream::AffinityAttr getResultAffinityAttr(Operation *op) const { return getAffinityAttr(op); } + + void removeAffinityAttrs(Operation *op) const { + op->removeAttr("stream.affinity"); + } 
}; struct TensorAffinityTypeExternalModel diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index 75d442ba463d..f39469f502d5 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -902,6 +903,55 @@ struct SCFIndexSwitchOpMutableRegionBranchOpInterface } }; +// Hoistable interface for region-containing control flow operations. +// Control flow is hoistable if control operands are constant and +// nested operations are hoistable (checked via atomic hoisting). +template +struct RegionControlFlowHoistableOpInterface + : public IREE::Util::HoistableOpInterface::ExternalModel< + RegionControlFlowHoistableOpInterface, OpTy> { + bool isHoistableOp(Operation *op) const { + // Control flow is hoistable if all nested operations are hoistable. + for (Region ®ion : op->getRegions()) { + WalkResult result = region.walk([](Operation *nestedOp) { + // Check if nested op is hoistable. + bool isHoistable = false; + if (auto hoistable = + dyn_cast(nestedOp)) { + isHoistable = hoistable.isHoistableOp(); + } else { + // Ops without interface must be memory-effect-free to be hoistable. + isHoistable = mlir::isMemoryEffectFree(nestedOp); + } + if (!isHoistable) { + return WalkResult::interrupt(); + } + // Don't descend into IsolatedFromAbove ops - treat them atomically. + return nestedOp->hasTrait() + ? 
WalkResult::skip() + : WalkResult::advance(); + }); + if (result.wasInterrupted()) { + return false; + } + } + return true; + } + + bool isHoistableLeafOp(Operation *) const { return false; } + bool isAtomicallyHoistableOp(Operation *) const { return true; } + bool isOperandHoistable(Operation *, OpOperand *) const { return true; } +}; + +template +struct RegionControlFlowHoistableOpInterfaceHelper { + static void registerOpInterface(MLIRContext *context) { + (Ops::template attachInterface>( + *context), + ...); + } +}; + } // namespace void registerUtilExternalModels(DialectRegistry ®istry) { @@ -1022,6 +1072,7 @@ void registerUtilExternalModels(DialectRegistry ®istry) { }); // Register MutableRegionBranchOpInterface for SCF ops. + // Register hoistable op interfaces for SCF control flow ops. registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *dialect) { scf::ForOp::attachInterface( *context); @@ -1030,6 +1081,8 @@ void registerUtilExternalModels(DialectRegistry ®istry) { *context); scf::IndexSwitchOp::attachInterface< SCFIndexSwitchOpMutableRegionBranchOpInterface>(*context); + RegionControlFlowHoistableOpInterfaceHelper< + scf::ForOp, scf::IfOp, scf::WhileOp>::registerOpInterface(context); }); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 4f38d12c88a9..58398be0f896 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -273,10 +273,10 @@ void buildGlobalOptimizationPassPipeline( exportParametersOptions)); } - if (!transformOptions.parameterSplatExportFile.empty()) { + if (!transformOptions.parameterSplatPath.empty()) { IREE::IO::Parameters::GenerateSplatParameterArchivePassOptions generateSplatOptions; - generateSplatOptions.filePath = transformOptions.parameterSplatExportFile; + generateSplatOptions.filePath = transformOptions.parameterSplatPath; mainPassManager.addPass( 
IREE::IO::Parameters::createGenerateSplatParameterArchivePass( generateSplatOptions)); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h index 317e2615abc2..d1e1c925c1b7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h @@ -49,9 +49,9 @@ struct TransformOptions : public PassPipelineOptions { llvm::cl::desc("Minimum size of constants to export as parameters."), llvm::cl::init(0), }; - Option parameterSplatExportFile{ + Option parameterSplatPath{ *this, - "parameter-splat-export-file", + "parameter-splat-path", llvm::cl::desc("File path to create a splat parameter archive out of all " "parameters in the module."), llvm::cl::init(""), diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index c44d7affc5e0..eb88f483bdd2 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -11,6 +11,7 @@ IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::BindingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::InputDialectOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS( mlir::iree_compiler::GlobalOptimizationOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::ParameterOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::SchedulingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::PreprocessingOptions); IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::GlobalPipelineOptions); @@ -161,6 +162,84 @@ void PreprocessingOptions::bindOptions(OptionsBinder &binder) { llvm::cl::cat(category)); } +void ParameterOptions::bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory category("IREE Parameter Options"); + + // Parameter import/export options. 
+ binder.list( + "iree-parameter-import", importPaths, + llvm::cl::desc("File paths to archives to import parameters from with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.list( + "iree-parameter-import-keys", importKeys, + llvm::cl::desc("List of parameter keys to import. Any matching keys from " + "any scope will be imported."), + llvm::cl::cat(category)); + binder.opt( + "iree-parameter-import-maximum-size", importMaximumSize, + llvm::cl::desc("Maximum size of parameters to import or 0 to disable " + "automatic import."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-export", exportPath, + llvm::cl::desc("File path to an archive to export parameters to with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.opt( + "iree-parameter-export-minimum-size", exportMinimumSize, + llvm::cl::desc("Minimum size of constants to export to the parameter " + "archive."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-splat", splatPath, + llvm::cl::desc("File path to create a parameter archive of splat values " + "from all parameter backed globals."), + llvm::cl::cat(category)); + + // Parameter encoder options. 
+ binder.opt( + "iree-parameter-encoder-mode", encoderMode, + llvm::cl::desc("Controls how the encoder manages parameters."), + llvm::cl::values( + clEnumValN(ParameterEncoderMode::Consolidate, "consolidate", + "Merge all encoded and original parameters into a single " + "consolidated scope."), + clEnumValN(ParameterEncoderMode::Overlay, "overlay", + "Only produce encoded parameters and leave original " + "parameters untouched.")), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-encoder-output-file", encoderOutputFile, + llvm::cl::desc(".mlir/.mlirbc file path to write the split parameter " + "encoder module to (empty = disabled)."), + llvm::cl::cat(category)); + + binder.opt( + "iree-parameter-encoder-output-scope", encoderOutputScope, + llvm::cl::desc("Parameter scope for the encoder output parameters."), + llvm::cl::cat(category)); + + // Deprecated flags aliasing the new ones above. + binder.opt( + "iree-opt-export-parameters", exportPath, + deprecated("use --iree-parameter-export= instead"), + llvm::cl::Hidden, + llvm::cl::desc("File path to an archive to export parameters to with an " + "optional `scope=` prefix."), + llvm::cl::cat(category)); + binder.opt( + "iree-opt-splat-parameters", splatPath, + deprecated("use --iree-parameter-splat= instead"), llvm::cl::Hidden, + llvm::cl::desc( + "File path to create a parameter archive of splat values out of all " + "parameter backed globals."), + llvm::cl::cat(category)); +} + void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { static llvm::cl::OptionCategory category( "IREE options for controlling global optimizations."); @@ -216,39 +295,6 @@ void GlobalOptimizationOptions::bindOptions(OptionsBinder &binder) { "information has been extracted."), llvm::cl::cat(category)); - binder.list( - "iree-opt-import-parameters", parameterImportPaths, - llvm::cl::desc("File paths to archives to import parameters from with an " - "optional `scope=` prefix."), - llvm::cl::cat(category)); - 
binder.list("iree-opt-import-parameter-keys", - parameterImportKeys, - llvm::cl::desc("List of parameter keys to import."), - llvm::cl::cat(category)); - binder.opt("iree-opt-import-parameter-maximum-size", - parameterImportMaximumSize, - llvm::cl::desc("Maximum size of parameters to import."), - llvm::cl::cat(category)); - - binder.opt( - "iree-opt-export-parameters", parameterExportPath, - llvm::cl::desc("File path to an archive to export parameters to with an " - "optional `scope=` prefix."), - llvm::cl::cat(category)); - binder.opt( - "iree-opt-export-parameter-minimum-size", parameterExportMinimumSize, - llvm::cl::desc( - "Minimum size of constants to export to the archive created in " - "`iree-opt-export-parameter-archive-export-file`."), - llvm::cl::cat(category)); - - binder.opt( - "iree-opt-splat-parameters", parameterSplatExportFile, - llvm::cl::desc( - "File path to create a parameter archive of splat values out of all " - "parameter backed globals."), - llvm::cl::cat(category)); - binder.opt( "iree-opt-generalize-matmul", generalizeMatmul, {init_at_opt(llvm::OptimizationLevel::O0, false), diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index 6a7a9803c349..811bd5b17f1a 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -105,28 +105,58 @@ struct PreprocessingOptions { using FromFlags = OptionsFromFlags; }; -// Options controlling high level optimizations. -struct GlobalOptimizationOptions { - llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; +// Defines the mode for parameter encoding. +enum class ParameterEncoderMode { + // Merge all encoded and original parameters into a single consolidated scope. + Consolidate = 0, + // Only produce encoded parameters and leave original parameters untouched. + Overlay = 1, +}; + +// Options controlling parameter management (import/export and encoding). 
+struct ParameterOptions { + //===--------------------------------------------------------------------===// + // Parameter Import/Export + //===--------------------------------------------------------------------===// // File paths to archives to import parameters from with an optional // `scope=` prefix. - std::vector parameterImportPaths; + std::vector importPaths; // List of parameter keys to import. Any matching keys from any scope will be // imported. - std::vector parameterImportKeys; + std::vector importKeys; // Maximum size of parameters to import or 0 to disable automatic import. - int64_t parameterImportMaximumSize = 0; + int64_t importMaximumSize = 0; // File path to an archive to export parameters to with an optional // `scope=` prefix. - std::string parameterExportPath; + std::string exportPath; // Minimum size of constants to export as parameters. - int64_t parameterExportMinimumSize = 0; + int64_t exportMinimumSize = 0; // File path to create a splat parameter archive out of all parameters in the // module. - std::string parameterSplatExportFile = ""; + std::string splatPath = ""; + + //===--------------------------------------------------------------------===// + // Parameter Encoder + //===--------------------------------------------------------------------===// + + // Controls how the encoder manages parameters. + ParameterEncoderMode encoderMode = ParameterEncoderMode::Consolidate; + // .mlir/.mlirbc file path to write the split parameter encoder module to + // (empty = disabled). + std::string encoderOutputFile; + // Parameter scope for the encoder output parameters. + std::string encoderOutputScope = "encoded"; + + void bindOptions(OptionsBinder &binder); + using FromFlags = OptionsFromFlags; +}; + +// Options controlling high level optimizations. 
+struct GlobalOptimizationOptions { + llvm::OptimizationLevel optLevel = llvm::OptimizationLevel::O0; // Enables aggressive propagation of transposes to the inputs of named ops, // rewriting named ops as fused generics. diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index c5b471f0711a..9ae20d73a6b5 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -76,6 +76,7 @@ void buildIREEPrecompileTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions globalOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -175,18 +176,14 @@ void buildIREEPrecompileTransformPassPipeline( halAssignmentOptions); GlobalOptimization::TransformOptions globalTransformOptions; - globalTransformOptions.parameterImportPaths = - globalOptimizationOptions.parameterImportPaths; - globalTransformOptions.parameterImportKeys = - globalOptimizationOptions.parameterImportKeys; + globalTransformOptions.parameterImportPaths = parameterOptions.importPaths; + globalTransformOptions.parameterImportKeys = parameterOptions.importKeys; globalTransformOptions.parameterImportMaximumSize = - globalOptimizationOptions.parameterImportMaximumSize; - globalTransformOptions.parameterExportPath = - globalOptimizationOptions.parameterExportPath; + parameterOptions.importMaximumSize; + globalTransformOptions.parameterExportPath = parameterOptions.exportPath; globalTransformOptions.parameterExportMinimumSize = - globalOptimizationOptions.parameterExportMinimumSize; - globalTransformOptions.parameterSplatExportFile = - globalOptimizationOptions.parameterSplatExportFile; + parameterOptions.exportMinimumSize; 
+ globalTransformOptions.parameterSplatPath = parameterOptions.splatPath; globalTransformOptions.aggressiveTransposePropagation = globalOptimizationOptions.aggressiveTransposePropagation; globalTransformOptions.propagateTransposesThroughConv = @@ -281,6 +278,7 @@ void buildIREEVMTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions globalOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -290,9 +288,9 @@ void buildIREEVMTransformPassPipeline( IREEVMPipelinePhase compileTo) { buildIREEPrecompileTransformPassPipeline( targetRegistry, pipelineOptions, bindingOptions, inputOptions, - preprocessingOptions, globalOptimizationOptions, dispatchCreationOptions, - schedulingOptions, halTargetOptions, hooks, passManager, compileFrom, - compileTo); + preprocessingOptions, parameterOptions, globalOptimizationOptions, + dispatchCreationOptions, schedulingOptions, halTargetOptions, hooks, + passManager, compileFrom, compileTo); if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) return; // early-exit @@ -306,6 +304,18 @@ void buildIREEVMTransformPassPipeline( (IREE::Stream::DumpOutputFormat)schedulingOptions.dumpStatisticsFormat; streamOptions.dumpStatisticsFile = schedulingOptions.dumpStatisticsFile; + // Set parameter encoder options. These are mapped to + // SplitParameterEncoderPassOptions when the pass is created in + // Stream/Transforms/Passes.cpp. 
+ if (!parameterOptions.encoderOutputFile.empty()) { + streamOptions.parameterEncoderMode = + (IREE::Stream::ParameterEncoderMode)parameterOptions.encoderMode; + streamOptions.parameterEncoderOutputFile = + parameterOptions.encoderOutputFile; + streamOptions.parameterEncoderOutputScope = + parameterOptions.encoderOutputScope; + } + switch (schedulingOptions.executionModel) { case SchedulingOptions::ExecutionModel::HostOnly: // No flow/stream processing (implies no tensors). @@ -451,7 +461,8 @@ void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager) { IREE::HAL::TargetRegistry::getGlobal(), GlobalPipelineOptions::FromFlags::get(), BindingOptions::FromFlags::get(), InputDialectOptions::FromFlags::get(), - PreprocessingOptions::FromFlags::get(), highLevelOptimizations, + PreprocessingOptions::FromFlags::get(), + ParameterOptions::FromFlags::get(), highLevelOptimizations, DispatchCreationOptions::FromFlags::get(), SchedulingOptions::FromFlags::get(), IREE::HAL::TargetOptions::FromFlags::get(), diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h index 104cc42875b5..c17f9bd893f0 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.h +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h @@ -103,6 +103,7 @@ void buildIREEPrecompileTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + ParameterOptions parameterOptions, GlobalOptimizationOptions highLevelOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, @@ -120,6 +121,7 @@ void buildIREEVMTransformPassPipeline( const IREE::HAL::TargetRegistry &targetRegistry, GlobalPipelineOptions pipelineOptions, BindingOptions bindingOptions, InputDialectOptions inputOptions, PreprocessingOptions preprocessingOptions, + 
ParameterOptions parameterOptions, GlobalOptimizationOptions highLevelOptimizationOptions, DispatchCreationOptions dispatchCreationOptions, SchedulingOptions schedulingOptions, diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel index 9b5109c2a4f7..8471e3665f49 100644 --- a/compiler/src/iree/compiler/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/BUILD.bazel @@ -53,7 +53,6 @@ iree_compiler_cc_library( "EncodingUtils.h", "EquivalenceUtils.h", "FlatbufferUtils.h", - "Folding.h", "Indexing.h", "IntegerSet.h", "ModuleUtils.h", @@ -79,6 +78,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt index b7f0fee64fdf..7842841b068e 100644 --- a/compiler/src/iree/compiler/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt @@ -19,7 +19,6 @@ iree_cc_library( "EncodingUtils.h" "EquivalenceUtils.h" "FlatbufferUtils.h" - "Folding.h" "Indexing.h" "IntegerSet.h" "ModuleUtils.h" @@ -54,6 +53,7 @@ iree_cc_library( MLIRAffineDialect MLIRAnalysis MLIRArithDialect + MLIRBytecodeWriter MLIRFuncDialect MLIRFunctionInterfaces MLIRIR diff --git a/compiler/src/iree/compiler/Utils/Folding.h b/compiler/src/iree/compiler/Utils/Folding.h deleted file mode 100644 index 02184cf9f084..000000000000 --- a/compiler/src/iree/compiler/Utils/Folding.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_UTILS_FOLDING_H_ -#define IREE_COMPILER_UTILS_FOLDING_H_ - -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/OpDefinition.h" - -namespace mlir::iree_compiler { - -// Convert a `Value` or an `Attribute` range to a range of `OpFoldResult`. -template -void toOpFoldResults(Range &&range, OutIt outIt) { - llvm::transform(std::forward(range), outIt, - [](auto v) { return OpFoldResult(v); }); -} - -template -SmallVector toOpFoldResults(Range &&range) { - SmallVector res; - toOpFoldResults(std::forward(range), std::back_inserter(res)); - return res; -} - -} // namespace mlir::iree_compiler - -#endif // IREE_COMPILER_UTILS_FOLDING_H_ diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp index f4972a5aeb21..2006a93dd87b 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp @@ -7,13 +7,17 @@ #include "iree/compiler/Utils/ModuleUtils.h" #include "iree/compiler/Utils/StringUtils.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Path.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" +#include "mlir/Support/FileUtilities.h" #include "mlir/Support/LLVM.h" namespace mlir::iree_compiler { @@ -221,4 +225,34 @@ LogicalResult mergeSourceModuleInto(Location loc, StringRef source, return mergeModuleInto(*sourceModuleRef, targetOp, targetBuilder); } +LogicalResult writeModule(mlir::ModuleOp moduleOp, StringRef path) { + // Ensure the parent paths exist. 
+ llvm::sys::fs::create_directories(llvm::sys::path::parent_path(path)); + + // Attempt to open file - should succeed as long as permissions are ok. + std::string error; + auto file = mlir::openOutputFile(path, &error); + if (!file) { + return mlir::emitError(moduleOp.getLoc()) + << "while dumping to '" << path << "': " << error << "\n"; + } + + // If going to binary serialize out and otherwise print as text. + if (llvm::sys::path::extension(path) == ".mlirbc") { + BytecodeWriterConfig config; + if (failed(mlir::writeBytecodeToFile(moduleOp, file->os(), config))) { + return mlir::emitError(moduleOp.getLoc()) + << "failed to serialize module to '" << path << "'\n"; + } + } else { + OpPrintingFlags flags; + moduleOp.print(file->os(), flags); + } + + // Keep the temporary file after the write succeeds. + file->keep(); + + return success(); +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.h b/compiler/src/iree/compiler/Utils/ModuleUtils.h index 882266bfaa35..18b0765b2f6b 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.h +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.h @@ -38,6 +38,10 @@ LogicalResult mergeSourceModuleInto(Location loc, StringRef source, Operation *targetOp, OpBuilder &targetBuilder); +// Writes |moduleOp| to the file at |path|. +// The module will be written as MLIR text unless it has the .mlirbc extension. +LogicalResult writeModule(mlir::ModuleOp moduleOp, StringRef path); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_UTILS_MODULEUTILS_H_ diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.h b/compiler/src/iree/compiler/Utils/OptionUtils.h index c6fa25cff27b..6e6e722b0131 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.h +++ b/compiler/src/iree/compiler/Utils/OptionUtils.h @@ -69,6 +69,18 @@ struct opt_scope { } }; +// Modifier to mark an option as deprecated with a warning message. 
+// When the option is parsed, a deprecation warning will be printed to stderr. +// The apply method is a no-op since OptionsBinder handles the deprecation +// warning through the callback mechanism, but it's required because LLVM's +// applicator will try to call it when the modifier is forwarded. +struct deprecated { + llvm::StringRef message; + explicit deprecated(llvm::StringRef msg) : message(msg) {} + template + void apply(Opt &) const {} +}; + // Base class that can bind named options to fields of structs. // // Typically use by adding the following to your struct: @@ -98,7 +110,8 @@ class OptionsBinder { template void opt(llvm::StringRef name, V &value, Mods... Ms) { - auto [changedCallback, clCallback] = makeChangedCallback(); + const deprecated *dep = filterDeprecated(Ms...); + auto [changedCallback, clCallback] = makeChangedCallback(name, dep); OptionInfo &info = getOptionsStorage()[name]; if (!scope) { // Bind global options. @@ -402,14 +415,25 @@ class OptionsBinder { // Returns a pair of callbacks, the first returns if the option has been // parsed and the second is passed to llvm::cl to track if the option has been - // parsed. + // parsed. If a deprecation message is provided, it will be printed to stderr + // when the option is parsed. template static std::pair> - makeChangedCallback() { + makeChangedCallback(llvm::StringRef name = "", + const deprecated *dep = nullptr) { std::shared_ptr changed = std::make_shared(false); + // Capture name and message by value for lambda lifetime. + std::string optName = name.str(); + std::string depMsg = dep ? dep->message.str() : ""; return std::pair{ [changed]() -> bool { return *changed; }, - llvm::cl::cb([changed](const V &) { *changed = true; })}; + llvm::cl::cb([changed, optName, depMsg](const V &) { + *changed = true; + if (!depMsg.empty()) { + llvm::errs() << "warning: --" << optName << " is deprecated; " + << depMsg << "\n"; + } + })}; } // Scalar default specialization. 
@@ -446,7 +470,7 @@ class OptionsBinder { }; } - // Finds the description in args + // Finds the description in args. template static llvm::cl::desc &filterDescription(Args &...args) { llvm::cl::desc *result = nullptr; @@ -463,6 +487,20 @@ class OptionsBinder { return *result; } + // Extracts deprecated modifier from args (returns nullptr if not found). + template + static const deprecated *filterDeprecated(const Args &...args) { + const deprecated *result = nullptr; + ( + [&] { + if constexpr (std::is_same_v, deprecated>) { + result = &args; + } + }(), + ...); + return result; + } + std::unique_ptr scope; OptionsStorage localOptions; diff --git a/runtime/src/iree/hal/utils/file_transfer.c b/runtime/src/iree/hal/utils/file_transfer.c index 6089af66c9b2..6be18daa8e44 100644 --- a/runtime/src/iree/hal/utils/file_transfer.c +++ b/runtime/src/iree/hal/utils/file_transfer.c @@ -593,7 +593,8 @@ static iree_status_t iree_hal_transfer_operation_launch_read( for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { iree_hal_transfer_worker_t* worker = &operation->workers[i]; alloca_semaphore_list.semaphores[i] = worker->semaphore; - alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; + alloca_semaphore_list.payload_values[i] = signal_timepoint; } IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_device_queue_alloca( @@ -666,16 +667,17 @@ static iree_status_t iree_hal_transfer_worker_copy_buffer_to_staging( IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)transfer_length); // Timeline increments by one. 
+ uint64_t wait_timepoint = worker->pending_timepoint; iree_hal_semaphore_list_t wait_semaphore_list = { .count = 1, .semaphores = &worker->semaphore, - .payload_values = &worker->pending_timepoint, + .payload_values = &wait_timepoint, }; - ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; iree_hal_semaphore_list_t signal_semaphore_list = { .count = 1, .semaphores = &worker->semaphore, - .payload_values = &worker->pending_timepoint, + .payload_values = &signal_timepoint, }; // Track the pending copy operation so we know where to place it in the file. @@ -692,8 +694,7 @@ static iree_status_t iree_hal_transfer_worker_copy_buffer_to_staging( // Wait for the copy to complete so we can write it to the file. if (iree_status_is_ok(status)) { status = iree_loop_wait_one( - loop, - iree_hal_semaphore_await(worker->semaphore, worker->pending_timepoint), + loop, iree_hal_semaphore_await(worker->semaphore, signal_timepoint), iree_infinite_timeout(), iree_hal_transfer_worker_copy_staging_to_file, worker); } @@ -785,7 +786,8 @@ static iree_status_t iree_hal_transfer_operation_launch_write( for (iree_host_size_t i = 0; i < operation->worker_count; ++i) { iree_hal_transfer_worker_t* worker = &operation->workers[i]; alloca_semaphore_list.semaphores[i] = worker->semaphore; - alloca_semaphore_list.payload_values[i] = ++worker->pending_timepoint; + uint64_t signal_timepoint = ++worker->pending_timepoint; + alloca_semaphore_list.payload_values[i] = signal_timepoint; } IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_device_queue_alloca( diff --git a/runtime/src/iree/io/formats/irpa/irpa_builder.c b/runtime/src/iree/io/formats/irpa/irpa_builder.c index e1459e898874..b8ca7001c598 100644 --- a/runtime/src/iree/io/formats/irpa/irpa_builder.c +++ b/runtime/src/iree/io/formats/irpa/irpa_builder.c @@ -48,6 +48,14 @@ iree_io_parameter_archive_builder_storage_offset( iree_io_parameter_archive_builder_storage_alignment(builder)); } +IREE_API_EXPORT 
iree_io_physical_size_t +iree_io_parameter_archive_builder_header_size( + const iree_io_parameter_archive_builder_t* builder) { + IREE_ASSERT_ARGUMENT(builder); + return (iree_io_physical_size_t) + iree_io_parameter_archive_builder_storage_offset(builder); +} + IREE_API_EXPORT iree_io_physical_size_t iree_io_parameter_archive_builder_total_size( const iree_io_parameter_archive_builder_t* builder) { diff --git a/runtime/src/iree/io/formats/irpa/irpa_builder.h b/runtime/src/iree/io/formats/irpa/irpa_builder.h index 46d9a7f11888..767c6d5c69e1 100644 --- a/runtime/src/iree/io/formats/irpa/irpa_builder.h +++ b/runtime/src/iree/io/formats/irpa/irpa_builder.h @@ -57,6 +57,13 @@ IREE_API_EXPORT void iree_io_parameter_archive_builder_deinitialize( IREE_API_EXPORT bool iree_io_parameter_archive_builder_is_empty( const iree_io_parameter_archive_builder_t* builder); +// Returns the size required to store the parameter archive header and +// associated metadata (excluding parameters). Adding new parameters will +// invalidate this value. +IREE_API_EXPORT iree_io_physical_size_t +iree_io_parameter_archive_builder_header_size( + const iree_io_parameter_archive_builder_t* builder); + // Returns the total file size required to store the parameter archive header // and contents of all added parameters. Adding new parameters will invalidate // this value. 
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c index aa3b9941d580..51a87d2b9347 100644 --- a/runtime/src/iree/tooling/context_util.c +++ b/runtime/src/iree/tooling/context_util.c @@ -468,7 +468,8 @@ static iree_status_t iree_tooling_resolve_module_dependency_callback( } else if (iree_string_view_equal(dependency->name, IREE_SV("io_parameters"))) { IREE_RETURN_IF_ERROR(iree_tooling_create_parameters_module_from_flags( - state->instance, state->host_allocator, &module)); + state->instance, /*additional_provider_count=*/0, + /*additional_providers=*/NULL, state->host_allocator, &module)); } else { // Defer to the generic module resolver registry. IREE_RETURN_IF_ERROR(iree_tooling_resolve_module_dependency( diff --git a/runtime/src/iree/tooling/parameter_util.c b/runtime/src/iree/tooling/parameter_util.c index 9499abe6e4d9..34dd964aa168 100644 --- a/runtime/src/iree/tooling/parameter_util.c +++ b/runtime/src/iree/tooling/parameter_util.c @@ -125,8 +125,9 @@ iree_status_t iree_tooling_build_parameter_indices_from_flags( } iree_status_t iree_tooling_create_parameters_module_from_flags( - iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t** out_module) { + iree_vm_instance_t* instance, iree_host_size_t additional_provider_count, + iree_io_parameter_provider_t** additional_providers, + iree_allocator_t host_allocator, iree_vm_module_t** out_module) { IREE_TRACE_ZONE_BEGIN(z0); iree_io_scope_map_t scope_map; @@ -136,9 +137,9 @@ iree_status_t iree_tooling_create_parameters_module_from_flags( iree_status_t status = iree_tooling_build_parameter_indices_from_flags(&scope_map); - // Create one provider per scope. - iree_host_size_t provider_count = 0; - iree_io_parameter_provider_t** providers = + // Create one provider per scope from flags. 
+ iree_host_size_t flag_provider_count = 0; + iree_io_parameter_provider_t** flag_providers = (iree_io_parameter_provider_t**)iree_alloca( scope_map.count * sizeof(iree_io_parameter_provider_t*)); if (iree_status_is_ok(status)) { @@ -146,21 +147,36 @@ iree_status_t iree_tooling_create_parameters_module_from_flags( status = iree_io_parameter_index_provider_create( scope_map.entries[i]->scope, scope_map.entries[i]->index, IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS, - host_allocator, &providers[i]); + host_allocator, &flag_providers[i]); if (!iree_status_is_ok(status)) break; - ++provider_count; + ++flag_provider_count; } } - // Create the module with the list of providers. + // Merge flag-created providers with additional providers. + iree_host_size_t total_provider_count = + flag_provider_count + additional_provider_count; + iree_io_parameter_provider_t** all_providers = + (iree_io_parameter_provider_t**)iree_alloca( + total_provider_count * sizeof(iree_io_parameter_provider_t*)); + for (iree_host_size_t i = 0; i < flag_provider_count; ++i) { + all_providers[i] = flag_providers[i]; + } + for (iree_host_size_t i = 0; i < additional_provider_count; ++i) { + all_providers[flag_provider_count + i] = additional_providers[i]; + } + + // Create the module with the merged list of providers. if (iree_status_is_ok(status)) { - status = iree_io_parameters_module_create( - instance, provider_count, providers, host_allocator, out_module); + status = iree_io_parameters_module_create(instance, total_provider_count, + all_providers, host_allocator, + out_module); } // Cleanup (module owns providers which own indices/etc). - for (iree_host_size_t i = 0; i < provider_count; ++i) { - iree_io_parameter_provider_release(providers[i]); + // Only release flag providers - additional providers are owned by caller. 
+ for (iree_host_size_t i = 0; i < flag_provider_count; ++i) { + iree_io_parameter_provider_release(flag_providers[i]); } iree_io_scope_map_deinitialize(&scope_map); diff --git a/runtime/src/iree/tooling/parameter_util.h b/runtime/src/iree/tooling/parameter_util.h index a0e46c12fd0e..f1633416d0e4 100644 --- a/runtime/src/iree/tooling/parameter_util.h +++ b/runtime/src/iree/tooling/parameter_util.h @@ -21,10 +21,17 @@ typedef struct iree_io_scope_map_t iree_io_scope_map_t; iree_status_t iree_tooling_build_parameter_indices_from_flags( iree_io_scope_map_t* scope_map); +typedef struct iree_io_parameter_provider_t iree_io_parameter_provider_t; + // Builds an I/O parameters module based on the runtime flags provided. +// If |additional_provider_count| is non-zero then |additional_providers| +// contains providers that will be added to the module in addition to those +// parsed from --parameters= flags. Additional providers are retained by the +// module and can be released by the caller after this call returns. 
iree_status_t iree_tooling_create_parameters_module_from_flags( - iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t** out_module); + iree_vm_instance_t* instance, iree_host_size_t additional_provider_count, + iree_io_parameter_provider_t** additional_providers, + iree_allocator_t host_allocator, iree_vm_module_t** out_module); #ifdef __cplusplus } // extern "C" diff --git a/tests/e2e/parameters/BUILD.bazel b/tests/e2e/parameters/BUILD.bazel index 81e36b107eb9..3bdb9d6daf98 100644 --- a/tests/e2e/parameters/BUILD.bazel +++ b/tests/e2e/parameters/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "encode_parameters.mlir", "export_parameters.mlir", "generate_splat_archive.mlir", ], @@ -29,6 +30,7 @@ iree_lit_test_suite( tools = [ "//tools:iree-compile", "//tools:iree-dump-parameters", + "//tools:iree-encode-parameters", "//tools:iree-run-module", "@llvm-project//llvm:FileCheck", ], diff --git a/tests/e2e/parameters/CMakeLists.txt b/tests/e2e/parameters/CMakeLists.txt index 61ad996589da..e24056bae27c 100644 --- a/tests/e2e/parameters/CMakeLists.txt +++ b/tests/e2e/parameters/CMakeLists.txt @@ -15,12 +15,14 @@ iree_lit_test_suite( NAME lit SRCS + "encode_parameters.mlir" "export_parameters.mlir" "generate_splat_archive.mlir" TOOLS FileCheck iree-compile iree-dump-parameters + iree-encode-parameters iree-run-module LABELS "driver=local-task" diff --git a/tests/e2e/parameters/encode_parameters.mlir b/tests/e2e/parameters/encode_parameters.mlir new file mode 100644 index 000000000000..9613be11f520 --- /dev/null +++ b/tests/e2e/parameters/encode_parameters.mlir @@ -0,0 +1,68 @@ +// RUN: rm -f %t_main.vmfb %t_encoder.mlir %t_encoder.vmfb %t_input.irpa %t_output.irpa +// +// Compile main module with encoder MLIR output and splat parameter export. 
+// RUN: iree-compile %s \ +// RUN: --iree-hal-target-device=local \ +// RUN: --iree-hal-local-target-device-backends=vmvx \ +// RUN: --iree-parameter-encoder-output-file=%t_encoder.mlir \ +// RUN: --iree-parameter-splat=%t_input.irpa \ +// RUN: -o %t_main.vmfb +// +// Compile the encoder module separately. +// RUN: iree-compile %t_encoder.mlir \ +// RUN: --iree-hal-target-device=local \ +// RUN: --iree-hal-local-target-device-backends=vmvx \ +// RUN: -o %t_encoder.vmfb +// +// Run the encoder to transform parameters. +// RUN: iree-encode-parameters \ +// RUN: --module=%t_encoder.vmfb \ +// RUN: --parameters=model=%t_input.irpa \ +// RUN: --output=encoded=%t_output.irpa \ +// RUN: --quiet +// +// Run the main module with both input and encoded parameters. +// The encoded parameters contain the pre-computed transformed values. +// RUN: iree-run-module \ +// RUN: --device=local-sync \ +// RUN: --module=%t_main.vmfb \ +// RUN: --function=main \ +// RUN: --parameters=model=%t_input.irpa \ +// RUN: --parameters=encoded=%t_output.irpa | \ +// RUN: FileCheck %s + +// Test parameter transformation with encoder. +// The global loads a parameter and applies an add operation to transform it. +// The encoder runs the add offline, and the main module loads the +// pre-computed result from the encoded parameter scope. + +// CHECK-LABEL: EXEC @main +// CHECK: 256xi32=42 42 42 42 + +// Parameter loaded from input archive (model scope). +// The splat export creates this with all zeros. +util.global private @raw_param = #flow.parameter.named<"model"::"param_global"> : tensor<256xi32> + +// This global holds the transformed value. +util.global private @transformed : tensor<256xi32> + +util.initializer { + // Load the raw parameter (all zeros from splat). + %raw = util.global.load @raw_param : tensor<256xi32> + // Add 42 to each element - this uses the parameter values and can be encoded. + // With input of 0s, result is 42s. 
+ %c42 = arith.constant 42 : i32 + %init = tensor.empty() : tensor<256xi32> + %c42_tensor = linalg.fill ins(%c42 : i32) outs(%init : tensor<256xi32>) -> tensor<256xi32> + %added = linalg.add ins(%raw, %c42_tensor : tensor<256xi32>, tensor<256xi32>) outs(%init : tensor<256xi32>) -> tensor<256xi32> + util.global.store %added, @transformed : tensor<256xi32> + util.return +} + +func.func @main() -> tensor<256xi32> { + // Load and return the full transformed tensor. + // If encoding worked, all elements should be 42 (0 + 42). + // If encoding didn't work, all elements would be 0 (splat init). + %tensor = util.global.load @transformed : tensor<256xi32> + return %tensor : tensor<256xi32> +} diff --git a/tests/e2e/parameters/export_parameters.mlir b/tests/e2e/parameters/export_parameters.mlir index 4288a90bd823..ef8321ee6ffd 100644 --- a/tests/e2e/parameters/export_parameters.mlir +++ b/tests/e2e/parameters/export_parameters.mlir @@ -1,8 +1,8 @@ // RUN: iree-compile %s \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --iree-opt-export-parameters=scope=%t.irpa \ -// RUN: --iree-opt-export-parameter-minimum-size=0 | \ +// RUN: --iree-parameter-export=scope=%t.irpa \ +// RUN: --iree-parameter-export-minimum-size=0 | \ // RUN: iree-run-module \ // RUN: --device=local-sync \ // RUN: --module=- \ diff --git a/tests/e2e/parameters/generate_splat_archive.mlir b/tests/e2e/parameters/generate_splat_archive.mlir index 79a8b632ddd9..fb888c7aef65 100644 --- a/tests/e2e/parameters/generate_splat_archive.mlir +++ b/tests/e2e/parameters/generate_splat_archive.mlir @@ -2,7 +2,7 @@ // RUN: iree-compile %s \ // RUN: --iree-hal-target-device=local \ // RUN: --iree-hal-local-target-device-backends=vmvx \ -// RUN: --iree-opt-splat-parameters=%t.irpa | \ +// RUN: --iree-parameter-splat=%t.irpa | \ // RUN: iree-run-module \ // RUN: --device=local-sync \ // RUN: --module=- \ diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index 
666779b937c5..32fce9c6f8c5 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -154,6 +154,27 @@ iree_runtime_cc_binary( ], ) +iree_runtime_cc_binary( + name = "iree-encode-parameters", + srcs = ["iree-encode-parameters-main.c"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal:flags", + "//runtime/src/iree/hal", + "//runtime/src/iree/io:file_handle", + "//runtime/src/iree/io:parameter_index", + "//runtime/src/iree/io:parameter_index_provider", + "//runtime/src/iree/io:scope_map", + "//runtime/src/iree/io:stream", + "//runtime/src/iree/io/formats/irpa", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/tooling:context_util", + "//runtime/src/iree/tooling:function_util", + "//runtime/src/iree/tooling:parameter_util", + "//runtime/src/iree/vm", + ], +) + iree_runtime_cc_binary( name = "iree-fatelf", srcs = ["iree-fatelf.c"], diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 7a70cfb36300..329f589cc70d 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -221,6 +221,30 @@ iree_cc_binary( INSTALL_COMPONENT IREETools-Runtime ) +iree_cc_binary( + NAME + iree-encode-parameters + SRCS + "iree-encode-parameters-main.c" + DEPS + iree::base + iree::base::internal::flags + iree::hal + iree::io::file_handle + iree::io::formats::irpa + iree::io::parameter_index + iree::io::parameter_index_provider + iree::io::scope_map + iree::io::stream + iree::modules::hal + iree::tooling::context_util + iree::tooling::function_util + iree::tooling::parameter_util + iree::vm + COVERAGE ${IREE_ENABLE_RUNTIME_COVERAGE} + INSTALL_COMPONENT IREETools-Runtime +) + # Only enable fatelf tool when we're compiling it in. # Currently it requires that the host and target both support embedded ELFs as # the ELF implementation is only compiled when the target supports it. 
diff --git a/tools/iree-encode-parameters-main.c b/tools/iree-encode-parameters-main.c new file mode 100644 index 000000000000..447bf188cec6 --- /dev/null +++ b/tools/iree-encode-parameters-main.c @@ -0,0 +1,1116 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "iree/base/api.h" +#include "iree/base/internal/flags.h" +#include "iree/hal/api.h" +#include "iree/io/file_handle.h" +#include "iree/io/formats/irpa/irpa_builder.h" +#include "iree/io/parameter_index.h" +#include "iree/io/parameter_index_provider.h" +#include "iree/io/scope_map.h" +#include "iree/io/stream.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/context_util.h" +#include "iree/tooling/function_util.h" +#include "iree/tooling/parameter_util.h" +#include "iree/vm/api.h" + +//===----------------------------------------------------------------------===// +// Flags +//===----------------------------------------------------------------------===// + +IREE_FLAG(bool, list_targets, false, + "Lists the targets an encoding module can produce parameters for and " + "exit."); + +IREE_FLAG(bool, list_parameters, false, + "Lists the parameters that will be encoded and exit."); + +IREE_FLAG(string, target, "", + "Target to use for encoding. If not specified, uses auto-detection."); + +IREE_FLAG(bool, quiet, false, + "Suppress output except for errors. 
Exit code indicates success."); + +IREE_FLAG_LIST(string, output, + "Specifies an output parameter file per scope.\n" + "Format: `scope=path.irpa` or `path.irpa` for default scope.\n" + "Example: `--output=encoded=output.irpa`"); + +//===----------------------------------------------------------------------===// +// Encoder target discovery +//===----------------------------------------------------------------------===// + +// Encoder function set for a single target. +typedef struct iree_encode_target_t { + iree_string_view_t target; + iree_vm_function_t indices_fn; + iree_vm_function_t steps_fn; + iree_vm_function_t encode_fn; +} iree_encode_target_t; + +// Storage for discovered encoder targets. +typedef struct iree_encode_target_set_t { + iree_vm_function_t detect_target_fn; + iree_host_size_t target_count; + iree_host_size_t target_capacity; + iree_encode_target_t* targets; + iree_allocator_t allocator; +} iree_encode_target_set_t; + +static void iree_encode_target_set_initialize( + iree_allocator_t allocator, iree_encode_target_set_t* out_target_set) { + memset(out_target_set, 0, sizeof(*out_target_set)); + out_target_set->allocator = allocator; +} + +static void iree_encode_target_set_deinitialize( + iree_encode_target_set_t* target_set) { + if (target_set->targets) { + iree_allocator_free(target_set->allocator, target_set->targets); + } + memset(target_set, 0, sizeof(*target_set)); +} + +static iree_status_t iree_encode_target_set_add( + iree_encode_target_set_t* target_set, iree_string_view_t target_name, + iree_encode_target_t** out_target) { + // Check if target already exists. + for (iree_host_size_t i = 0; i < target_set->target_count; ++i) { + if (iree_string_view_equal(target_set->targets[i].target, target_name)) { + *out_target = &target_set->targets[i]; + return iree_ok_status(); + } + } + // Grow if needed. + if (target_set->target_count >= target_set->target_capacity) { + iree_host_size_t new_capacity = + target_set->target_capacity ? 
target_set->target_capacity * 2 : 4; + IREE_RETURN_IF_ERROR(iree_allocator_realloc( + target_set->allocator, new_capacity * sizeof(iree_encode_target_t), + (void**)&target_set->targets)); + target_set->target_capacity = new_capacity; + } + // Add new target. + iree_encode_target_t* target = &target_set->targets[target_set->target_count]; + memset(target, 0, sizeof(*target)); + target->target = target_name; + ++target_set->target_count; + *out_target = target; + return iree_ok_status(); +} + +// Looks up a reflection attribute value by key. +static iree_string_view_t iree_encode_lookup_reflection_attr( + iree_vm_function_t* function, iree_string_view_t key) { + return iree_vm_function_lookup_attr_by_name(function, key); +} + +// Discovers encoder functions from the module by scanning exported function +// attributes. +static iree_status_t iree_encode_discover_functions( + iree_vm_module_t* module, iree_encode_target_set_t* target_set) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_module_signature_t signature = iree_vm_module_signature(module); + + for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) { + iree_vm_function_t function; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_module_lookup_function_by_ordinal( + module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function)); + + // Check for iree.encode.function attribute. + iree_string_view_t encode_function = iree_encode_lookup_reflection_attr( + &function, IREE_SV("iree.encode.function")); + if (iree_string_view_is_empty(encode_function)) continue; + + if (iree_string_view_equal(encode_function, IREE_SV("detect_target"))) { + target_set->detect_target_fn = function; + } else { + // Get target name for indices/steps/encode functions. 
+ iree_string_view_t target_name = iree_encode_lookup_reflection_attr( + &function, IREE_SV("iree.encode.target")); + if (iree_string_view_is_empty(target_name)) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "encoder function missing iree.encode.target"); + } + + iree_encode_target_t* target = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_encode_target_set_add(target_set, target_name, &target)); + + if (iree_string_view_equal(encode_function, IREE_SV("indices"))) { + target->indices_fn = function; + } else if (iree_string_view_equal(encode_function, IREE_SV("steps"))) { + target->steps_fn = function; + } else if (iree_string_view_equal(encode_function, IREE_SV("encode"))) { + target->encode_fn = function; + } + } + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Output scope/archive types +//===----------------------------------------------------------------------===// + +typedef struct iree_output_scope_t { + iree_string_view_t scope; + iree_string_view_t path; +} iree_output_scope_t; + +typedef struct iree_output_scope_list_t { + iree_host_size_t count; + iree_output_scope_t* entries; + iree_allocator_t allocator; +} iree_output_scope_list_t; + +static void iree_output_scope_list_initialize(iree_allocator_t allocator, + iree_output_scope_list_t* list) { + memset(list, 0, sizeof(*list)); + list->allocator = allocator; +} + +static void iree_output_scope_list_deinitialize( + iree_output_scope_list_t* list) { + if (list->entries) { + iree_allocator_free(list->allocator, list->entries); + } + memset(list, 0, sizeof(*list)); +} + +// Archive context for a single output scope. 
+typedef struct iree_output_archive_t { + iree_string_view_t scope; + iree_string_view_t path; + iree_io_parameter_archive_builder_t builder; + iree_io_file_handle_t* file_handle; + iree_io_parameter_index_t* index; + iree_io_parameter_provider_t* provider; +} iree_output_archive_t; + +static void iree_output_archive_deinitialize(iree_output_archive_t* archive) { + iree_io_parameter_provider_release(archive->provider); + iree_io_parameter_index_release(archive->index); + iree_io_file_handle_release(archive->file_handle); + iree_io_parameter_archive_builder_deinitialize(&archive->builder); +} + +//===----------------------------------------------------------------------===// +// Load modules and discover encoder functions +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_load_and_discover( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_tooling_module_list_t* out_module_list, + iree_vm_module_t** out_encoder_module, + iree_encode_target_set_t* out_target_set) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_tooling_module_list_initialize(out_module_list); + iree_encode_target_set_initialize(host_allocator, out_target_set); + + // Load modules from flags. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_load_modules_from_flags(instance, host_allocator, + out_module_list)); + + if (out_module_list->count == 0) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "no modules specified; use --module=path.vmfb"); + } + + // Encoder module is the last module (by convention). + *out_encoder_module = out_module_list->values[out_module_list->count - 1]; + + // Discover encoder functions. 
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_encode_discover_functions(*out_encoder_module, out_target_set)); + + if (out_target_set->target_count == 0) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "no encoder functions found in module; ensure the module was produced " + "by iree-compile with --iree-parameter-encoder-output-file"); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Select target +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_select_target( + iree_encode_target_set_t* target_set, + iree_encode_target_t** out_selected_target) { + iree_string_view_t target_flag = iree_make_cstring_view(FLAG_target); + + if (iree_string_view_is_empty(target_flag)) { + // Use first target. + *out_selected_target = &target_set->targets[0]; + return iree_ok_status(); + } + + // Find matching target. 
+ for (iree_host_size_t i = 0; i < target_set->target_count; ++i) { + if (iree_string_view_equal(target_set->targets[i].target, target_flag)) { + *out_selected_target = &target_set->targets[i]; + return iree_ok_status(); + } + } + + return iree_make_status(IREE_STATUS_NOT_FOUND, + "target '%s' not found in encoder module; " + "use --list-targets to see available targets", + FLAG_target); +} + +static iree_status_t iree_encode_validate_target(iree_encode_target_t* target) { + if (!target->indices_fn.module) { + return iree_make_status(IREE_STATUS_NOT_FOUND, + "indices function not found for target '%.*s'; " + "encoder module may be incomplete", + (int)target->target.size, target->target.data); + } + if (!target->encode_fn.module) { + return iree_make_status(IREE_STATUS_NOT_FOUND, + "encode function not found for target '%.*s'; " + "encoder module may be incomplete", + (int)target->target.size, target->target.data); + } + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// --list_targets implementation +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_print_targets( + iree_vm_module_t* encoder_module, iree_encode_target_set_t* target_set) { + iree_string_view_t module_name = iree_vm_module_name(encoder_module); + fprintf(stdout, "Encoder module: %.*s\n", (int)module_name.size, + module_name.data); + fprintf(stdout, "Available targets:\n"); + + for (iree_host_size_t i = 0; i < target_set->target_count; ++i) { + iree_encode_target_t* target = &target_set->targets[i]; + fprintf(stdout, " %.*s\n", (int)target->target.size, target->target.data); + + iree_string_view_t scopes = iree_encode_lookup_reflection_attr( + &target->indices_fn, IREE_SV("iree.encode.scopes")); + if (!iree_string_view_is_empty(scopes)) { + fprintf(stdout, " scopes: %.*s\n", (int)scopes.size, scopes.data); + } + } + + return iree_ok_status(); +} + 
+//===----------------------------------------------------------------------===// +// Call indices function +//===----------------------------------------------------------------------===// + +// Creates a temporary context and calls the indices function. +// The indices function returns constant data and doesn't need parameters. +// TODO(benvanik): Consider calling without full context if function has no +// imports. +static iree_status_t iree_encode_call_indices( + iree_vm_instance_t* instance, iree_tooling_module_list_t* module_list, + iree_encode_target_t* target, iree_allocator_t host_allocator, + iree_vm_list_t** out_indices_list) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_context_t* context = NULL; + iree_hal_device_t* device = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_create_context_from_flags( + instance, module_list->count, module_list->values, + /*default_device_uri=*/iree_string_view_empty(), host_allocator, + &context, &device, /*out_device_allocator=*/NULL)); + + // Invoke indices function. + iree_vm_list_t* outputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 1, host_allocator, &outputs); + if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->indices_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, /*inputs=*/NULL, outputs, host_allocator); + } + + // Extract result list. 
+ if (iree_status_is_ok(status)) { + iree_vm_ref_t list_ref = iree_vm_ref_null(); + status = iree_vm_list_get_ref_assign(outputs, 0, &list_ref); + if (iree_status_is_ok(status)) { + *out_indices_list = iree_vm_list_deref(list_ref); + if (*out_indices_list) { + iree_vm_list_retain(*out_indices_list); + } + } + } + + iree_vm_list_release(outputs); + iree_hal_device_release(device); + iree_vm_context_release(context); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// --list_parameters implementation +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_print_parameters( + iree_vm_list_t* indices_list) { + iree_host_size_t scope_count = iree_vm_list_size(indices_list); + + for (iree_host_size_t scope_i = 0; scope_i < scope_count; ++scope_i) { + iree_vm_ref_t scope_struct_ref = iree_vm_ref_null(); + if (!iree_status_is_ok(iree_vm_list_get_ref_assign(indices_list, scope_i, + &scope_struct_ref))) { + continue; + } + iree_vm_list_t* scope_struct = iree_vm_list_deref(scope_struct_ref); + if (!scope_struct || iree_vm_list_size(scope_struct) < 2) continue; + + // Get scope name. + iree_vm_ref_t scope_name_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(scope_struct, 0, &scope_name_ref); + iree_vm_buffer_t* scope_name_buffer = iree_vm_buffer_deref(scope_name_ref); + iree_string_view_t scope_name = + scope_name_buffer ? iree_vm_buffer_as_string(scope_name_buffer) + : IREE_SV(""); + + fprintf(stdout, "Scope: \"%.*s\"\n", (int)scope_name.size, scope_name.data); + + // Get entries list. + iree_vm_ref_t entries_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(scope_struct, 1, &entries_ref); + iree_vm_list_t* entries = iree_vm_list_deref(entries_ref); + if (!entries) continue; + + // Print each entry. 
+ iree_host_size_t entry_count = iree_vm_list_size(entries); + for (iree_host_size_t entry_i = 0; entry_i < entry_count; ++entry_i) { + iree_vm_ref_t entry_ref = iree_vm_ref_null(); + if (!iree_status_is_ok( + iree_vm_list_get_ref_assign(entries, entry_i, &entry_ref))) { + continue; + } + iree_vm_list_t* entry = iree_vm_list_deref(entry_ref); + if (!entry || iree_vm_list_size(entry) < 5) continue; + + iree_vm_value_t type_value, length_value; + iree_vm_list_get_value(entry, 0, &type_value); + iree_vm_list_get_value(entry, 3, &length_value); + + iree_vm_ref_t key_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(entry, 1, &key_ref); + iree_vm_buffer_t* key_buffer = iree_vm_buffer_deref(key_ref); + iree_string_view_t key = key_buffer ? iree_vm_buffer_as_string(key_buffer) + : IREE_SV(""); + + if (type_value.i64 == 0) { + // SPLAT entry. + iree_vm_value_t pattern_value, pattern_length_value; + iree_vm_list_get_value(entry, 4, &pattern_value); + iree_vm_list_get_value(entry, 5, &pattern_length_value); + fprintf(stdout, + " %.*s: SPLAT, %" PRIu64 " bytes, pattern=0x%0*" PRIx64 "\n", + (int)key.size, key.data, (uint64_t)length_value.i64, + (int)pattern_length_value.i64 * 2, (uint64_t)pattern_value.i64); + } else { + // DATA entry. 
+ iree_vm_value_t alignment_value; + iree_vm_list_get_value(entry, 4, &alignment_value); + fprintf(stdout, + " %.*s: DATA, %" PRIu64 " bytes, alignment %" PRIu64 "\n", + (int)key.size, key.data, (uint64_t)length_value.i64, + (uint64_t)alignment_value.i64); + } + } + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Parse output flags +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_parse_output_flags( + iree_output_scope_list_t* list) { + iree_host_size_t count = FLAG_output_list().count; + if (count == 0) return iree_ok_status(); + + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + list->allocator, count * sizeof(iree_output_scope_t), + (void**)&list->entries)); + list->count = count; + + for (iree_host_size_t i = 0; i < count; ++i) { + iree_string_view_t flag = FLAG_output_list().values[i]; + iree_string_view_t scope, path; + if (iree_string_view_split(flag, '=', &scope, &path) == -1) { + // No scope provided - use empty scope. + path = scope; + scope = iree_string_view_empty(); + } + list->entries[i].scope = scope; + list->entries[i].path = path; + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Create output archives +//===----------------------------------------------------------------------===// + +// Parses parameter indices and populates archive builders. +static iree_status_t iree_encode_parse_indices_into_archives( + iree_vm_list_t* indices_list, iree_output_archive_t* archives, + iree_host_size_t archive_count) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_host_size_t scope_count = iree_vm_list_size(indices_list); + for (iree_host_size_t scope_i = 0; scope_i < scope_count; ++scope_i) { + // Get scope struct: [scope_name, entries_list]. 
+ iree_vm_ref_t scope_struct_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_vm_list_get_ref_assign(indices_list, scope_i, &scope_struct_ref)); + iree_vm_list_t* scope_struct = iree_vm_list_deref(scope_struct_ref); + if (!scope_struct || iree_vm_list_size(scope_struct) < 2) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid scope struct in indices"); + } + + // Get scope name. + iree_vm_ref_t scope_name_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(scope_struct, 0, &scope_name_ref)); + iree_vm_buffer_t* scope_name_buffer = iree_vm_buffer_deref(scope_name_ref); + iree_string_view_t scope_name = + scope_name_buffer ? iree_vm_buffer_as_string(scope_name_buffer) + : iree_string_view_empty(); + + // Find matching archive. + iree_output_archive_t* archive = NULL; + for (iree_host_size_t j = 0; j < archive_count; ++j) { + if (iree_string_view_equal(archives[j].scope, scope_name)) { + archive = &archives[j]; + break; + } + } + if (!archive) continue; // Scope not in output list. + + // Get entries list. + iree_vm_ref_t entries_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(scope_struct, 1, &entries_ref)); + iree_vm_list_t* entries = iree_vm_list_deref(entries_ref); + if (!entries) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid entries in scope struct"); + } + + // Process each parameter entry. 
+ iree_host_size_t entry_count = iree_vm_list_size(entries); + for (iree_host_size_t entry_i = 0; entry_i < entry_count; ++entry_i) { + iree_vm_ref_t entry_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(entries, entry_i, &entry_ref)); + iree_vm_list_t* entry = iree_vm_list_deref(entry_ref); + if (!entry || iree_vm_list_size(entry) < 5) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid entry in entries list"); + } + + // Parse entry fields: [type, key, metadata, length, ...]. + iree_vm_value_t type_value, length_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 0, &type_value)); + + iree_vm_ref_t key_ref = iree_vm_ref_null(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_ref_assign(entry, 1, &key_ref)); + iree_vm_buffer_t* key_buffer = iree_vm_buffer_deref(key_ref); + if (!key_buffer) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parameter entry missing key"); + } + iree_string_view_t key = iree_vm_buffer_as_string(key_buffer); + + iree_vm_ref_t metadata_ref = iree_vm_ref_null(); + iree_vm_list_get_ref_assign(entry, 2, &metadata_ref); + iree_vm_buffer_t* metadata_buffer = iree_vm_buffer_deref(metadata_ref); + iree_const_byte_span_t metadata = iree_const_byte_span_empty(); + if (metadata_buffer) { + metadata = iree_vm_buffer_const_contents(metadata_buffer); + } + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 3, &length_value)); + uint64_t length = (uint64_t)length_value.i64; + + if (type_value.i64 == 0) { + // SPLAT entry: [type, key, metadata, length, pattern, pattern_length]. 
+ iree_vm_value_t pattern_value, pattern_length_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 4, &pattern_value)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 5, &pattern_length_value)); + + uint64_t pattern = (uint64_t)pattern_value.i64; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_parameter_archive_builder_add_splat_entry( + &archive->builder, key, metadata, &pattern, + (uint8_t)pattern_length_value.i64, length)); + } else { + // DATA entry: [type, key, metadata, length, alignment]. + iree_vm_value_t alignment_value; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_value(entry, 4, &alignment_value)); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_parameter_archive_builder_add_data_entry( + &archive->builder, key, metadata, + (uint64_t)alignment_value.i64, length)); + } + } + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Creates archive files and providers for each output scope. +static iree_status_t iree_encode_create_archives( + iree_vm_list_t* indices_list, iree_output_scope_list_t* output_list, + iree_allocator_t host_allocator, iree_output_archive_t** out_archives) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Allocate archive array. + iree_output_archive_t* archives = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, + output_list->count * sizeof(iree_output_archive_t), + (void**)&archives)); + + // Initialize archive builders. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + memset(&archives[i], 0, sizeof(archives[i])); + archives[i].scope = output_list->entries[i].scope; + archives[i].path = output_list->entries[i].path; + status = iree_io_parameter_archive_builder_initialize(host_allocator, + &archives[i].builder); + if (!iree_status_is_ok(status)) break; + } + + // Parse indices into archive builders. 
+ if (iree_status_is_ok(status)) { + status = iree_encode_parse_indices_into_archives(indices_list, archives, + output_list->count); + } + + // Create files and write headers. + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + iree_output_archive_t* archive = &archives[i]; + + iree_io_physical_size_t archive_size = + iree_io_parameter_archive_builder_total_size(&archive->builder); + + // Create null-terminated path. + char* path_cstr = NULL; + status = iree_allocator_malloc(host_allocator, archive->path.size + 1, + (void**)&path_cstr); + if (!iree_status_is_ok(status)) break; + memcpy(path_cstr, archive->path.data, archive->path.size); + path_cstr[archive->path.size] = '\0'; + + // Create output file. + status = iree_io_file_handle_create( + IREE_IO_FILE_MODE_READ | IREE_IO_FILE_MODE_WRITE, + iree_make_cstring_view(path_cstr), archive_size, host_allocator, + &archive->file_handle); + iree_allocator_free(host_allocator, path_cstr); + if (!iree_status_is_ok(status)) break; + + // Create stream and index. + iree_io_stream_t* stream = NULL; + status = + iree_io_stream_open(IREE_IO_STREAM_MODE_WRITABLE, + archive->file_handle, 0, host_allocator, &stream); + if (!iree_status_is_ok(status)) break; + + status = iree_io_parameter_index_create(host_allocator, &archive->index); + if (!iree_status_is_ok(status)) { + iree_io_stream_release(stream); + break; + } + + // Write archive header. + status = iree_io_parameter_archive_builder_write( + &archive->builder, archive->file_handle, 0, stream, archive->index); + iree_io_stream_release(stream); + if (!iree_status_is_ok(status)) break; + + // Create provider backed by the archive. 
+ status = iree_io_parameter_index_provider_create( + archive->scope, archive->index, + IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS, + host_allocator, &archive->provider); + if (!iree_status_is_ok(status)) break; + } + } + + if (!iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < output_list->count; ++i) { + iree_output_archive_deinitialize(&archives[i]); + } + iree_allocator_free(host_allocator, archives); + IREE_TRACE_ZONE_END(z0); + return status; + } + + *out_archives = archives; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Create encoding context with output providers +//===----------------------------------------------------------------------===// + +// Creates the encoding context with output providers attached. +// TODO(benvanik): Allow adding providers to existing parameters module to avoid +// recreating context. +static iree_status_t iree_encode_create_encoding_context( + iree_vm_instance_t* instance, iree_tooling_module_list_t* module_list, + iree_output_archive_t* archives, iree_host_size_t archive_count, + iree_allocator_t host_allocator, iree_vm_context_t** out_context, + iree_hal_device_t** out_device) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Collect output providers. + iree_host_size_t provider_count = 0; + for (iree_host_size_t i = 0; i < archive_count; ++i) { + if (archives[i].provider) ++provider_count; + } + + iree_io_parameter_provider_t** providers = + (iree_io_parameter_provider_t**)iree_alloca( + provider_count * sizeof(iree_io_parameter_provider_t*)); + for (iree_host_size_t i = 0, j = 0; i < archive_count; ++i) { + if (archives[i].provider) { + providers[j++] = archives[i].provider; + } + } + + // Create parameters module with output providers. 
+ iree_vm_module_t* params_module = NULL; + iree_status_t status = iree_tooling_create_parameters_module_from_flags( + instance, provider_count, providers, host_allocator, ¶ms_module); + + // Pre-populate resolved_list with params module so resolver won't create + // default. + iree_tooling_module_list_t resolved_list; + iree_tooling_module_list_initialize(&resolved_list); + + if (iree_status_is_ok(status)) { + status = iree_tooling_module_list_push_back(&resolved_list, params_module); + } + + // Resolve dependencies (adds HAL, etc.). + if (iree_status_is_ok(status)) { + status = iree_tooling_resolve_modules( + instance, module_list->count, module_list->values, + /*default_device_uri=*/iree_string_view_empty(), host_allocator, + &resolved_list, out_device, /*out_device_allocator=*/NULL); + } + + // Create context. + if (iree_status_is_ok(status)) { + status = iree_vm_context_create_with_modules( + instance, IREE_VM_CONTEXT_FLAG_NONE, resolved_list.count, + resolved_list.values, host_allocator, out_context); + } + + iree_tooling_module_list_reset(&resolved_list); + iree_vm_module_release(params_module); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Call steps function +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_call_steps(iree_vm_context_t* context, + iree_encode_target_t* target, + iree_allocator_t host_allocator, + iree_vm_list_t** out_steps_list) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_steps_list = NULL; + if (!target->steps_fn.module) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); // Steps function is optional. 
+ } + + iree_vm_list_t* outputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 1, host_allocator, &outputs); + if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->steps_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, /*inputs=*/NULL, outputs, host_allocator); + } + + if (iree_status_is_ok(status)) { + iree_vm_ref_t list_ref = iree_vm_ref_null(); + status = iree_vm_list_get_ref_assign(outputs, 0, &list_ref); + if (iree_status_is_ok(status)) { + *out_steps_list = iree_vm_list_deref(list_ref); + if (*out_steps_list) { + iree_vm_list_retain(*out_steps_list); + } + } + } + + iree_vm_list_release(outputs); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Execute encoder +//===----------------------------------------------------------------------===// + +static iree_status_t iree_encode_execute(iree_vm_context_t* context, + iree_hal_device_t* device, + iree_encode_target_t* target, + iree_vm_list_t* steps_list, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Build inputs: [steps_list, wait_fence, signal_fence]. + iree_vm_list_t* inputs = NULL; + iree_status_t status = iree_vm_list_create(iree_vm_make_undefined_type_def(), + 3, host_allocator, &inputs); + + // Push steps list (may be NULL). + if (iree_status_is_ok(status)) { + if (steps_list) { + iree_vm_ref_t steps_ref = iree_vm_list_retain_ref(steps_list); + status = iree_vm_list_push_ref_move(inputs, &steps_ref); + } else { + iree_vm_ref_t null_ref = iree_vm_ref_null(); + status = iree_vm_list_push_ref_move(inputs, &null_ref); + } + } + + // Append async fences. + iree_hal_fence_t* signal_fence = NULL; + if (iree_status_is_ok(status)) { + status = + iree_tooling_append_async_fences(inputs, target->encode_fn, device, + /*wait_fence=*/NULL, &signal_fence); + } + + // Invoke encoder. 
+ if (iree_status_is_ok(status)) { + status = iree_vm_invoke( + context, target->encode_fn, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, inputs, /*outputs=*/NULL, host_allocator); + } + + iree_vm_list_release(inputs); + + // Wait for completion. + if (iree_status_is_ok(status) && signal_fence) { + status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout(), + IREE_HAL_WAIT_FLAG_DEFAULT); + } + + iree_hal_fence_release(signal_fence); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Dump output parameters +//===----------------------------------------------------------------------===// + +// Dumps the contents of output archives similar to iree-dump-parameters. +static iree_status_t iree_encode_dump_outputs(iree_output_archive_t* archives, + iree_host_size_t archive_count, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_string_builder_t sb; + iree_string_builder_initialize(host_allocator, &sb); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < archive_count && iree_status_is_ok(status); + ++i) { + iree_output_archive_t* archive = &archives[i]; + if (!archive->index) continue; + + status = iree_string_builder_append_cstring( + &sb, + "//" + "===-----------------------------------------------------------------" + "---------------------------------------------===//\n"); + if (!iree_status_is_ok(status)) break; + + // Print archive header. + iree_io_physical_size_t archive_size = + iree_io_parameter_archive_builder_total_size(&archive->builder); + status = iree_string_builder_append_format( + &sb, "// Output: %.*s (%" PRIu64 " bytes)\n", (int)archive->path.size, + archive->path.data, archive_size); + if (!iree_status_is_ok(status)) break; + + // Dump parameter index. 
+ status = iree_io_parameter_index_dump(archive->scope, archive->index, &sb); + } + + if (iree_status_is_ok(status)) { + fprintf(stdout, "%.*s", (int)iree_string_builder_size(&sb), + iree_string_builder_buffer(&sb)); + } + + iree_string_builder_deinitialize(&sb); + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Main encoding workflow +//===----------------------------------------------------------------------===// + +static iree_status_t iree_tooling_encode_parameters( + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); + + // Create VM instance. + iree_vm_instance_t* instance = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_create_instance(host_allocator, &instance)); + + // Load modules and discover encoder functions. + iree_tooling_module_list_t module_list; + iree_vm_module_t* encoder_module = NULL; + iree_encode_target_set_t target_set; + status = iree_encode_load_and_discover(instance, host_allocator, &module_list, + &encoder_module, &target_set); + + // Select target. + iree_encode_target_t* selected_target = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_select_target(&target_set, &selected_target); + } + if (iree_status_is_ok(status)) { + status = iree_encode_validate_target(selected_target); + } + + // Handle --list_targets (early exit). + if (iree_status_is_ok(status) && FLAG_list_targets) { + status = iree_encode_print_targets(encoder_module, &target_set); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Call indices function. 
+ iree_vm_list_t* indices_list = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_call_indices(instance, &module_list, selected_target, + host_allocator, &indices_list); + } + + // Handle --list_parameters (early exit). + if (iree_status_is_ok(status) && FLAG_list_parameters) { + status = iree_encode_print_parameters(indices_list); + iree_vm_list_release(indices_list); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Parse output flags. + iree_output_scope_list_t output_list; + iree_output_scope_list_initialize(host_allocator, &output_list); + if (iree_status_is_ok(status)) { + status = iree_encode_parse_output_flags(&output_list); + } + if (iree_status_is_ok(status) && output_list.count == 0) { + status = iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "no output specified; use --output=[scope=]path.irpa " + "(e.g., --output=encoded=output.irpa or --output=output.irpa)"); + } + + // Create output archives. + iree_output_archive_t* archives = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_create_archives(indices_list, &output_list, + host_allocator, &archives); + } + + // Create encoding context with output providers. + iree_vm_context_t* context = NULL; + iree_hal_device_t* device = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_create_encoding_context( + instance, &module_list, archives, output_list.count, host_allocator, + &context, &device); + } + + // Call steps function. + iree_vm_list_t* steps_list = NULL; + if (iree_status_is_ok(status)) { + status = iree_encode_call_steps(context, selected_target, host_allocator, + &steps_list); + } + + // Execute encoder. + if (iree_status_is_ok(status)) { + status = iree_encode_execute(context, device, selected_target, steps_list, + host_allocator); + } + + // Dump output parameters (unless quiet mode). 
+ if (iree_status_is_ok(status) && !FLAG_quiet) { + status = + iree_encode_dump_outputs(archives, output_list.count, host_allocator); + } + + // Cleanup. + iree_vm_list_release(steps_list); + iree_vm_list_release(indices_list); + if (archives) { + for (iree_host_size_t i = 0; i < output_list.count; ++i) { + iree_output_archive_deinitialize(&archives[i]); + } + iree_allocator_free(host_allocator, archives); + } + iree_hal_device_release(device); + iree_vm_context_release(context); + iree_output_scope_list_deinitialize(&output_list); + iree_encode_target_set_deinitialize(&target_set); + iree_tooling_module_list_reset(&module_list); + iree_vm_instance_release(instance); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + +int main(int argc, char** argv) { + IREE_TRACE_APP_ENTER(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_t host_allocator = iree_allocator_system(); + int exit_code = EXIT_SUCCESS; + + iree_flags_set_usage( + "iree-encode-parameters", + "Encodes parameter files using an encoding module.\n" + "\n" + "This tool transforms model parameters using an encoder module produced\n" + "by iree-compile with --iree-parameter-encoder-output-file. The encoder\n" + "pre-computes parameter transformations (packing, encoding, dispatches)\n" + "that would otherwise run at model load time.\n" + "\n" + "WORKFLOW:\n" + " 1. Compile main module with encoder output:\n" + " iree-compile model.mlir \\\n" + " --iree-parameter-encoder-output-file=encoder.mlir \\\n" + " --iree-parameter-splat-path=input.irpa \\\n" + " -o main.vmfb\n" + "\n" + " 2. Compile the encoder module:\n" + " iree-compile encoder.mlir -o encoder.vmfb\n" + "\n" + " 3. 
Run the encoder to transform parameters:\n" + " iree-encode-parameters \\\n" + " --module=encoder.vmfb \\\n" + " --parameters=model=input.irpa \\\n" + " --output=encoded=output.irpa\n" + "\n" + " 4. Run the main module with encoded parameters:\n" + " iree-run-module \\\n" + " --module=main.vmfb \\\n" + " --parameters=model=input.irpa \\\n" + " --parameters=encoded=output.irpa\n" + "\n" + "FLAGS:\n" + " --module=path.vmfb Encoder module (required)\n" + " --parameters=scope=path Input parameter file(s)\n" + " --output=scope=path.irpa Output encoded parameter file(s)\n" + " --list-targets List available encoding targets\n" + " --list-parameters List parameters that will be encoded\n" + " --target=name Select specific target (default: auto-detect)\n" + " --quiet Suppress output except errors\n"); + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); + + if (argc > 1) { + fprintf(stderr, "Error: no positional arguments expected.\n"); + fprintf(stderr, + "Use one or more --parameters=file.ext flags to specify parameter " + "files.\n"); + IREE_TRACE_ZONE_END(z0); + IREE_TRACE_APP_EXIT(exit_code); + return EXIT_FAILURE; + } + + iree_status_t status = iree_tooling_encode_parameters(host_allocator); + + fflush(stdout); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + iree_status_free(status); + exit_code = EXIT_FAILURE; + } + fflush(stderr); + + IREE_TRACE_ZONE_END(z0); + IREE_TRACE_APP_EXIT(exit_code); + return exit_code; +} From e440a88303a8de8c09823245e8ddc42b8b1458bb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:45:38 -0800 Subject: [PATCH 36/71] Bump sarisia/actions-status-discord from 1.15.5 to 1.16.0 in the github-actions group (#23106) --- .github/workflows/ci_linux_arm64_clang.yml | 2 +- .github/workflows/ci_linux_x64_clang_byollvm.yml | 2 +- .github/workflows/ci_linux_x64_clang_debug.yml | 2 +- .github/workflows/ci_linux_x64_clang_tsan.yml 
| 2 +- .github/workflows/ci_linux_x64_gcc.yml | 2 +- .github/workflows/ci_macos_arm64_clang.yml | 2 +- .github/workflows/ci_macos_x64_clang.yml | 2 +- .github/workflows/workflow_summary.yml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci_linux_arm64_clang.yml b/.github/workflows/ci_linux_arm64_clang.yml index 0972afe3a2c6..2229b4a9b656 100644 --- a/.github/workflows/ci_linux_arm64_clang.yml +++ b/.github/workflows/ci_linux_arm64_clang.yml @@ -63,7 +63,7 @@ jobs: run: ./build_tools/cmake/test_iree_dialects.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_byollvm.yml b/.github/workflows/ci_linux_x64_clang_byollvm.yml index a322e9c83891..e866c3262620 100644 --- a/.github/workflows/ci_linux_x64_clang_byollvm.yml +++ b/.github/workflows/ci_linux_x64_clang_byollvm.yml @@ -30,7 +30,7 @@ jobs: run: ./build_tools/cmake/build_and_test_byo_llvm.sh - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_debug.yml b/.github/workflows/ci_linux_x64_clang_debug.yml index 33582148d477..493379d08c2d 100644 --- a/.github/workflows/ci_linux_x64_clang_debug.yml +++ b/.github/workflows/ci_linux_x64_clang_debug.yml @@ -47,7 +47,7 @@ jobs: # would add 10+ minutes to the job. 
- name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_clang_tsan.yml b/.github/workflows/ci_linux_x64_clang_tsan.yml index a1dbe97509f5..2b1a2a6db568 100644 --- a/.github/workflows/ci_linux_x64_clang_tsan.yml +++ b/.github/workflows/ci_linux_x64_clang_tsan.yml @@ -50,7 +50,7 @@ jobs: sccache --show-stats - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_linux_x64_gcc.yml b/.github/workflows/ci_linux_x64_gcc.yml index d5600d5bc250..d453bbbd88d7 100644 --- a/.github/workflows/ci_linux_x64_gcc.yml +++ b/.github/workflows/ci_linux_x64_gcc.yml @@ -38,7 +38,7 @@ jobs: run: ./build_tools/cmake/build_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_macos_arm64_clang.yml b/.github/workflows/ci_macos_arm64_clang.yml index e924f9fa5ca9..d18a8bb532bc 100644 --- a/.github/workflows/ci_macos_arm64_clang.yml +++ b/.github/workflows/ci_macos_arm64_clang.yml @@ -60,7 +60,7 @@ jobs: run: bash ./build_tools/cmake/build_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: 
sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/ci_macos_x64_clang.yml b/.github/workflows/ci_macos_x64_clang.yml index d6aa3fa29e06..722de844ebb7 100644 --- a/.github/workflows/ci_macos_x64_clang.yml +++ b/.github/workflows/ci_macos_x64_clang.yml @@ -53,7 +53,7 @@ jobs: run: bash ./build_tools/cmake/ctest_all.sh "${BUILD_DIR}" - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} diff --git a/.github/workflows/workflow_summary.yml b/.github/workflows/workflow_summary.yml index 2a184b66a0ff..23ba06067479 100644 --- a/.github/workflows/workflow_summary.yml +++ b/.github/workflows/workflow_summary.yml @@ -55,7 +55,7 @@ jobs: exit 1 fi - name: Post to Discord on Failure - uses: sarisia/actions-status-discord@b8381b25576cb341b2af39926ab42c5056cc44ed # v1.15.5 + uses: sarisia/actions-status-discord@eb045afee445dc055c18d3d90bd0f244fd062708 # v1.16.0 if: failure() && github.ref_name == 'main' && github.repository_owner == 'iree-org' with: webhook: ${{ secrets.DISCORD_WEBHOOK }} From 5b7ceb23068b0b9c0cb0ce0c6107bc98eb9f1428 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 14 Jan 2026 09:56:35 -0800 Subject: [PATCH 37/71] Adding libFuzzer-based fuzzing infrastructure. (#23122) This PR adds support for coverage-guided fuzz testing using libFuzzer. The infrastructure integrates with both Bazel and CMake build systems, automatically enabling AddressSanitizer for memory error detection. 
New fuzz targets are defined using `iree_runtime_cc_fuzz` (or `iree_compiler_cc_fuzz` for compiler code). These targets implement the standard `LLVMFuzzerTestOneInput` entry point and are excluded from the default build to avoid slowing down normal development. Two initial fuzz targets are included as proof-of-life: `unicode_fuzz` for UTF-8 and Unicode utilities, and `string_view_fuzz` for string parsing functions. Running these for just a few seconds immediately found two bugs in string_view.c: an out-of-bounds read in `iree_string_view_find_last_of` when given a position past the string length, and an exponential backtracking issue in `iree_string_view_match_pattern` with pathological wildcard patterns like `*?*?*?*?*`. Both are fixed in this PR. See `docs/website/docs/developers/debugging/fuzzing.md` for full documentation. Bazel example: ```shell bazel build --config=fuzzer //runtime/src/iree/base/internal:unicode_fuzz ./bazel-bin/runtime/src/iree/base/internal/unicode_fuzz -max_total_time=60 ``` CMake example: ```shell cmake -B build -DIREE_ENABLE_FUZZING=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo cmake --build build --target unicode_fuzz ./build/runtime/src/iree/base/internal/unicode_fuzz -max_total_time=60 ``` In the future supporting AFL is easy - for now, this with a to-be-landed iree-bazel-fuzz are effective enough. 
--- CMakeLists.txt | 2 + build_tools/bazel/build_defs.oss.bzl | 14 +- build_tools/bazel/iree.bazelrc | 29 ++- build_tools/bazel/iree_cc_fuzz.bzl | 109 +++++++++ .../bazel_to_cmake_converter.py | 42 ++++ build_tools/cmake/iree_cc_fuzz.cmake | 115 ++++++++++ build_tools/cmake/iree_setup_toolchain.cmake | 13 ++ .../docs/developers/debugging/fuzzing.md | 182 +++++++++++++++ docs/website/mkdocs.yml | 1 + runtime/src/iree/base/BUILD.bazel | 10 +- runtime/src/iree/base/CMakeLists.txt | 9 + runtime/src/iree/base/internal/BUILD.bazel | 10 +- runtime/src/iree/base/internal/CMakeLists.txt | 9 + .../src/iree/base/internal/unicode_fuzz.cc | 217 ++++++++++++++++++ runtime/src/iree/base/string_view.c | 87 +++++-- runtime/src/iree/base/string_view_fuzz.cc | 185 +++++++++++++++ runtime/src/iree/base/string_view_test.cc | 76 ++++++ 17 files changed, 1080 insertions(+), 30 deletions(-) create mode 100644 build_tools/bazel/iree_cc_fuzz.bzl create mode 100644 build_tools/cmake/iree_cc_fuzz.cmake create mode 100644 docs/website/docs/developers/debugging/fuzzing.md create mode 100644 runtime/src/iree/base/internal/unicode_fuzz.cc create mode 100644 runtime/src/iree/base/string_view_fuzz.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ad29f2646ed..ab747f26c7c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -523,6 +523,7 @@ option(IREE_ENABLE_ASAN "Enable address sanitizer" OFF) option(IREE_ENABLE_MSAN "Enable memory sanitizer" OFF) option(IREE_ENABLE_TSAN "Enable thread sanitizer" OFF) option(IREE_ENABLE_UBSAN "Enable undefined behavior sanitizer" OFF) +option(IREE_ENABLE_FUZZING "Enable libFuzzer-based fuzz targets" OFF) option(IREE_ENABLE_SPLIT_DWARF "Enable gsplit-dwarf for debug information if the platform supports it" OFF) option(IREE_ENABLE_THIN_ARCHIVES "Enables thin ar archives (elf systems only). 
Disable for released static archives" OFF) option(IREE_LINK_COMPILER_SHARED_LIBRARY "Links IREE tools using the compiler compiled into a shared library" ON) @@ -629,6 +630,7 @@ include(iree_copts) include(iree_cc_binary) include(iree_cc_library) include(iree_cc_test) +include(iree_cc_fuzz) include(iree_import_binary) include(iree_install_support) include(iree_external_cmake_options) diff --git a/build_tools/bazel/build_defs.oss.bzl b/build_tools/bazel/build_defs.oss.bzl index 6cf23934c622..c8f3c9c2cc48 100644 --- a/build_tools/bazel/build_defs.oss.bzl +++ b/build_tools/bazel/build_defs.oss.bzl @@ -4,11 +4,23 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") + +# All load statements must come first in Starlark. load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") +load( + "//build_tools/bazel:iree_cc_fuzz.bzl", + _iree_cc_fuzz = "iree_cc_fuzz", + _iree_compiler_cc_fuzz = "iree_compiler_cc_fuzz", + _iree_runtime_cc_fuzz = "iree_runtime_cc_fuzz", +) """Common Bazel definitions for IREE.""" -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +# Re-export fuzz rules for external use. +iree_cc_fuzz = _iree_cc_fuzz +iree_compiler_cc_fuzz = _iree_compiler_cc_fuzz +iree_runtime_cc_fuzz = _iree_runtime_cc_fuzz def defaulting_select(selector): """Pass through to select() with special semantics when converting to CMake. 
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc index 60bfb010cc7e..2369df478c8e 100644 --- a/build_tools/bazel/iree.bazelrc +++ b/build_tools/bazel/iree.bazelrc @@ -178,6 +178,17 @@ build:msvc_release --compilation_mode=opt # https://github.com/google/sanitizers/wiki/AddressSanitizer ############################################################################### +# Don't strip debug info +build:sanitizer --strip=never +# Ignore settings of `linkopts = ["-static"]` which can screw up the sanitizer. +# We don't use this in IREE (that's what linkstatic is for), but it could show +# up in dependencies. +build:sanitizer --force_ignore_dash_static +# sanitizer tests tend to take longer, so increase the timeouts +build:sanitizer --test_timeout=120,600,1800,-1 +# Get better stack traces +build:sanitizer --copt=-fno-omit-frame-pointer + # ASAN (address sanitizer) # https://clang.llvm.org/docs/AddressSanitizer.html build:asan --config=sanitizer @@ -216,16 +227,14 @@ build:ubsan --linkopt=-fsanitize=undefined build:ubsan --linkopt=-lubsan build:ubsan --cc_output_directory_tag=ubsan -# Don't strip debug info -build:sanitizer --strip=never -# Ignore settings of `linkopts = ["-static"]` which can screw up the sanitizer. -# We don't use this in IREE (that's what linkstatic is for), but it could show -# up in dependencies. -build:sanitizer --force_ignore_dash_static -# sanitizer tests tend to take longer, so increase the timeouts -build:sanitizer --test_timeout=120,600,1800,-1 -# Get better stack traces -build:sanitizer --copt=-fno-omit-frame-pointer +# Fuzzer (libFuzzer) configuration +# https://llvm.org/docs/LibFuzzer.html +# Includes ASAN by default - there's no reason to fuzz without memory sanitization. 
+build:fuzzer --config=asan +build:fuzzer --copt=-fsanitize=fuzzer-no-link +build:fuzzer --linkopt=-fsanitize=fuzzer-no-link +build:fuzzer --cc_output_directory_tag=asan-fuzzer +build:fuzzer --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION ############################################################################### # Architecture specific options diff --git a/build_tools/bazel/iree_cc_fuzz.bzl b/build_tools/bazel/iree_cc_fuzz.bzl new file mode 100644 index 000000000000..d50e65033207 --- /dev/null +++ b/build_tools/bazel/iree_cc_fuzz.bzl @@ -0,0 +1,109 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Macros for defining libFuzzer-based fuzz targets. + +Fuzz targets require --config=fuzzer to build properly. The config instruments +all code for coverage feedback and adds appropriate compile/link flags. + +Example usage: + load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz") + + iree_runtime_cc_fuzz( + name = "unicode_fuzz", + srcs = ["unicode_fuzz.cc"], + deps = [":unicode"], + ) + +Building and running: + bazel build --config=fuzzer //path/to:unicode_fuzz + ./bazel-bin/path/to/unicode_fuzz corpus/ -max_total_time=60 +""" + +def iree_cc_fuzz( + name, + srcs, + deps = None, + data = None, + copts = None, + defines = None, + linkopts = None, + tags = None, + **kwargs): + """Creates a libFuzzer-based fuzz target. + + Args: + name: Target name (e.g., "unicode_fuzz"). + srcs: Source files containing LLVMFuzzerTestOneInput(). + deps: Library dependencies. + data: Data file dependencies. + copts: Additional compile options. + defines: Preprocessor definitions. + linkopts: Additional link options. + tags: Target tags. "fuzz" tag is added automatically. + **kwargs: Additional cc_binary attributes. 
+ """ + if deps == None: + deps = [] + if data == None: + data = [] + if copts == None: + copts = [] + if defines == None: + defines = [] + if linkopts == None: + linkopts = [] + if tags == None: + tags = [] + + # Add "fuzz" tag if not present. + if "fuzz" not in tags: + tags = tags + ["fuzz"] + + native.cc_binary( + name = name, + srcs = srcs, + deps = deps, + data = data, + copts = copts, + defines = defines + ["FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION"], + linkopts = linkopts + ["-fsanitize=fuzzer"], + tags = tags, + testonly = True, + **kwargs + ) + +def iree_runtime_cc_fuzz(deps = None, **kwargs): + """Fuzz target for runtime code using libFuzzer. + + Wraps iree_cc_fuzz and adds //runtime/src:runtime_defines dependency. + + Args: + deps: Library dependencies (runtime_defines added automatically). + **kwargs: Additional arguments passed to iree_cc_fuzz. + """ + if deps == None: + deps = [] + iree_cc_fuzz( + deps = deps + ["//runtime/src:runtime_defines"], + **kwargs + ) + +def iree_compiler_cc_fuzz(deps = None, **kwargs): + """Fuzz target for compiler code using libFuzzer. + + Wraps iree_cc_fuzz and adds //compiler/src:defs dependency. + + Args: + deps: Library dependencies (compiler defs added automatically). + **kwargs: Additional arguments passed to iree_cc_fuzz. 
+ """ + if deps == None: + deps = [] + iree_cc_fuzz( + deps = deps + ["//compiler/src:defs"], + **kwargs + ) diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index 39d46b384829..9ef6efdc223d 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py @@ -542,6 +542,48 @@ def cc_binary( f")\n\n" ) + def iree_cc_fuzz( + self, + name, + srcs=None, + data=None, + deps=None, + copts=None, + defines=None, + linkopts=None, + tags=None, + **kwargs, + ): + if self._should_skip_target(tags=tags, **kwargs): + return + name_block = self._convert_string_arg_block("NAME", name, quote=False) + srcs_block = self._convert_srcs_block(srcs) + data_block = self._convert_target_list_block("DATA", data) + deps_block = self._convert_target_list_block("DEPS", deps) + copts_block = self._convert_string_list_block("COPTS", copts, sort=False) + defines_block = self._convert_string_list_block("DEFINES", defines) + linkopts_block = self._convert_string_list_block("LINKOPTS", linkopts) + labels_block = self._convert_string_list_block("LABELS", tags) + + self._converter.body += ( + f"iree_cc_fuzz(\n" + f"{name_block}" + f"{srcs_block}" + f"{data_block}" + f"{deps_block}" + f"{copts_block}" + f"{defines_block}" + f"{linkopts_block}" + f"{labels_block}" + f")\n\n" + ) + + def iree_runtime_cc_fuzz(self, **kwargs): + self.iree_cc_fuzz(**kwargs) + + def iree_compiler_cc_fuzz(self, **kwargs): + self.iree_cc_fuzz(**kwargs) + def iree_c_embed_data( self, name, diff --git a/build_tools/cmake/iree_cc_fuzz.cmake b/build_tools/cmake/iree_cc_fuzz.cmake new file mode 100644 index 000000000000..296f4f26edc4 --- /dev/null +++ b/build_tools/cmake/iree_cc_fuzz.cmake @@ -0,0 +1,115 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# iree_cc_fuzz() +# +# CMake function to create a libFuzzer-based fuzz target. +# +# Parameters: +# NAME: name of target. This name is used for the generated executable. +# SRCS: List of source files for the fuzzer (must define LLVMFuzzerTestOneInput). +# DATA: List of other targets and files required for this binary. +# DEPS: List of other libraries to be linked in to the binary targets. +# COPTS: List of private compile options. +# DEFINES: List of public defines. +# LINKOPTS: List of link options. +# LABELS: Additional labels to apply to the target. +# +# Note: +# - Fuzz targets require IREE_BUILD_TESTS=ON AND IREE_ENABLE_FUZZING=ON. +# - Fuzz targets are NOT added to CTest (they run differently than tests). +# - Fuzz targets are excluded from the default 'all' target (build explicitly). +# - Binary name is ${NAME} in the bin directory. +# +# Usage: +# iree_cc_fuzz( +# NAME +# unicode_fuzz +# SRCS +# "unicode_fuzz.cc" +# DEPS +# iree::base::internal::unicode +# ) +function(iree_cc_fuzz) + # Fuzz targets require both tests enabled AND fuzzing enabled. + if(NOT IREE_BUILD_TESTS) + return() + endif() + if(NOT IREE_ENABLE_FUZZING) + return() + endif() + + cmake_parse_arguments( + _RULE + "" + "NAME" + "SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;LABELS" + ${ARGN} + ) + + # Prefix the library with the package name, so we get: iree_package_name + iree_package_name(_PACKAGE_NAME) + iree_package_ns(_PACKAGE_NS) + set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}") + + add_executable(${_NAME} "") + # Alias the iree_package_name fuzz binary to iree::package::name. 
+ add_executable(${_PACKAGE_NS}::${_RULE_NAME} ALIAS ${_NAME}) + + set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_RULE_NAME}") + target_sources(${_NAME} + PRIVATE + ${_RULE_SRCS} + ) + target_include_directories(${_NAME} SYSTEM + PUBLIC + "$" + "$" + ) + target_compile_definitions(${_NAME} + PUBLIC + ${_RULE_DEFINES} + ) + target_compile_options(${_NAME} + PRIVATE + ${IREE_DEFAULT_COPTS} + ${_RULE_COPTS} + ) + + # Link with libFuzzer runtime. The -fsanitize=fuzzer flag provides the main() + # function and fuzzing driver. All other code is compiled with + # -fsanitize=fuzzer-no-link (set in iree_setup_toolchain.cmake) for coverage + # instrumentation without linking the fuzzer runtime. + target_link_options(${_NAME} + PRIVATE + ${IREE_DEFAULT_LINKOPTS} + ${_RULE_LINKOPTS} + "-fsanitize=fuzzer" + ) + + # Replace dependencies passed by ::name with iree::package::name + list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::") + + # Implicit deps. + if(IREE_IMPLICIT_DEFS_CC_DEPS) + list(APPEND _RULE_DEPS ${IREE_IMPLICIT_DEFS_CC_DEPS}) + endif() + + target_link_libraries(${_NAME} + PUBLIC + ${_RULE_DEPS} + ) + iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA}) + + # Add all IREE fuzz targets to a folder in the IDE for organization. + set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/fuzz) + + set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD}) + set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON) + + # Exclude from 'all' target - fuzz targets must be built explicitly. + set_property(TARGET ${_NAME} PROPERTY EXCLUDE_FROM_ALL ON) +endfunction() diff --git a/build_tools/cmake/iree_setup_toolchain.cmake b/build_tools/cmake/iree_setup_toolchain.cmake index 756d9b073409..5ff74afbf296 100644 --- a/build_tools/cmake/iree_setup_toolchain.cmake +++ b/build_tools/cmake/iree_setup_toolchain.cmake @@ -144,6 +144,12 @@ macro(iree_setup_toolchain) # defined with the same sanitizer flags, including e.g. 
standard library # symbols that might be used by both IREE and non-IREE (e.g. LLVM) code. + # Fuzzing requires ASan - enable it automatically if not already set. + if(IREE_ENABLE_FUZZING AND NOT IREE_ENABLE_ASAN) + message(STATUS "Fuzzing enabled: automatically enabling ASan") + set(IREE_ENABLE_ASAN ON) + endif() + if(IREE_ENABLE_ASAN) string(APPEND CMAKE_CXX_FLAGS " -fsanitize=address") string(APPEND CMAKE_C_FLAGS " -fsanitize=address") @@ -187,6 +193,13 @@ macro(iree_setup_toolchain) string(APPEND CMAKE_CXX_FLAGS " -fsanitize=undefined") string(APPEND CMAKE_C_FLAGS " -fsanitize=undefined") endif() + if(IREE_ENABLE_FUZZING) + # Instrument all code for libFuzzer coverage feedback without linking the + # fuzzer runtime. Fuzz targets link with -fsanitize=fuzzer separately to + # get the main() function and driver. + string(APPEND CMAKE_CXX_FLAGS " -fsanitize=fuzzer-no-link") + string(APPEND CMAKE_C_FLAGS " -fsanitize=fuzzer-no-link") + endif() #----------------------------------------------------------------------------- # Build performance optimizations diff --git a/docs/website/docs/developers/debugging/fuzzing.md b/docs/website/docs/developers/debugging/fuzzing.md new file mode 100644 index 000000000000..4d533aecc914 --- /dev/null +++ b/docs/website/docs/developers/debugging/fuzzing.md @@ -0,0 +1,182 @@ +--- +icon: material/bug +--- + +# Fuzzing with libFuzzer + +[libFuzzer](https://llvm.org/docs/LibFuzzer.html) is a coverage-guided fuzzing +engine provided by LLVM. It generates random inputs and mutates them based on +code coverage feedback to find crashes, hangs, and memory errors. + +IREE provides build infrastructure for creating libFuzzer-based fuzz targets +that integrate with the existing build system. + +## When to use fuzzing + +Fuzzing is most effective for: + +- Parsers and decoders (UTF-8, binary formats, etc.) 
+- Serialization/deserialization code +- Input validation logic +- Any code that processes untrusted or external data + +## Enabling fuzzing builds + +### Bazel + +```shell +bazel build --config=fuzzer //runtime/src/iree/base/internal:unicode_fuzz +``` + +The `--config=fuzzer` flag enables coverage instrumentation and ASan. + +### CMake + +```shell +cmake -B build -DIREE_ENABLE_FUZZING=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo +cmake --build build --target unicode_fuzz +``` + +Fuzzing automatically enables ASan. Fuzz targets are excluded from the default +`all` target and must be built explicitly. + +## Running fuzz targets + +Fuzz targets are standalone executables that accept libFuzzer arguments: + +```shell +# Run indefinitely (Ctrl+C to stop) +./build/runtime/src/iree/base/internal/unicode_fuzz + +# Run for 60 seconds +./build/runtime/src/iree/base/internal/unicode_fuzz -max_total_time=60 + +# Use a corpus directory (recommended) +mkdir -p corpus/unicode +./build/runtime/src/iree/base/internal/unicode_fuzz corpus/unicode/ +``` + +### Common options + +Option | Description +------ | ----------- +`-max_total_time=N` | Stop after N seconds +`-max_len=N` | Maximum input size in bytes +`-timeout=N` | Per-input timeout in seconds (0 to disable) +`-jobs=N` | Run N parallel fuzzing jobs +`-workers=N` | Number of worker processes for parallel fuzzing +`-dict=file` | Use a dictionary file for structured inputs +`-seed=N` | Use specific random seed for reproducibility + +See [libFuzzer documentation](https://llvm.org/docs/LibFuzzer.html) for all +options. 
+ +## Writing fuzz targets + +Fuzz targets implement the `LLVMFuzzerTestOneInput` function: + +```cpp +// my_fuzz.cc +#include +#include + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + // Process the fuzzer-generated input + my_function_under_test(data, size); + return 0; // Always return 0 +} +``` + +### Adding to the build system + +In `BUILD.bazel`: + +```python +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz") + +iree_runtime_cc_fuzz( + name = "my_fuzz", + srcs = ["my_fuzz.cc"], + deps = [ + ":my_library", + ], +) +``` + +Then run `python build_tools/bazel_to_cmake/bazel_to_cmake.py` to generate the +CMake equivalent. + +## Best practices + +### Maintain a corpus + +Store interesting inputs in a corpus directory. The fuzzer uses existing corpus +entries as seeds for mutation: + +```shell +mkdir -p corpus/my_fuzz +./my_fuzz corpus/my_fuzz/ -max_total_time=3600 +``` + +After finding bugs, minimize the corpus to remove redundant entries: + +```shell +mkdir corpus/my_fuzz_minimized +./my_fuzz -merge=1 corpus/my_fuzz_minimized/ corpus/my_fuzz/ +``` + +### Add unit tests for found bugs + +When fuzzing discovers a crash: + +1. Minimize the reproducer: `./my_fuzz -minimize_crash=1 crash-xxx` +2. Add the minimized input as a unit test case +3. Fix the bug +4. Verify the fix with the original crash input + +This prevents regressions and documents the bug. 
+ +### Use dictionaries for structured formats + +For inputs with specific syntax (protocols, file formats), provide a dictionary: + +```text +# my_dict.txt +"keyword1" +"keyword2" +"\x00\x01\x02" +``` + +```shell +./my_fuzz -dict=my_dict.txt corpus/ +``` + +## Troubleshooting + +### Fuzzer runs slowly + +- Ensure `CMAKE_BUILD_TYPE=RelWithDebInfo` or `Release` (Debug is very slow) +- Check that the target doesn't do excessive I/O or allocations per iteration +- Use `-jobs=N` for parallel fuzzing on multi-core machines + +### Out of memory + +- Limit input size with `-max_len=N` +- Add early returns for oversized inputs in your fuzz target +- Use `-rss_limit_mb=N` to set memory limits + +### No new coverage + +- Verify the target actually processes the input +- Check that coverage instrumentation is enabled (`-fsanitize=fuzzer-no-link`) +- Try seeding with representative inputs in the corpus + +### Timeout errors + +libFuzzer kills inputs that take too long (default 1200 seconds). If you see +`ALARM: working on the last Unit for N seconds` followed by a timeout: + +- Use `-timeout=N` to adjust the per-input timeout (in seconds) +- Use `-timeout=0` to disable timeouts entirely (useful for debugging) +- Check if certain inputs cause algorithmic complexity issues (e.g., pathological + regex patterns, deeply nested structures) diff --git a/docs/website/mkdocs.yml b/docs/website/mkdocs.yml index 115bbd6c8e3d..fb3507fdee7b 100644 --- a/docs/website/mkdocs.yml +++ b/docs/website/mkdocs.yml @@ -245,6 +245,7 @@ nav: - "developers/debugging/model-development.md" - "developers/debugging/releases.md" - "developers/debugging/sanitizers.md" + - "developers/debugging/fuzzing.md" - "Performance": - "developers/performance/benchmarking.md" - "developers/performance/profiling.md" diff --git a/runtime/src/iree/base/BUILD.bazel b/runtime/src/iree/base/BUILD.bazel index 5f83948d351d..1691f95047cc 100644 --- a/runtime/src/iree/base/BUILD.bazel +++ b/runtime/src/iree/base/BUILD.bazel @@ 
-6,7 +6,7 @@ # Common types and utilities used in the IREE codebase. -load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_fuzz", "iree_runtime_cc_library", "iree_runtime_cc_test") package( default_visibility = ["//visibility:public"], @@ -112,6 +112,14 @@ iree_runtime_cc_test( ], ) +iree_runtime_cc_fuzz( + name = "string_view_fuzz", + srcs = ["string_view_fuzz.cc"], + deps = [ + ":base", + ], +) + iree_runtime_cc_test( name = "string_view_test", srcs = ["string_view_test.cc"], diff --git a/runtime/src/iree/base/CMakeLists.txt b/runtime/src/iree/base/CMakeLists.txt index 8d0b1e8123fa..56144d624c5c 100644 --- a/runtime/src/iree/base/CMakeLists.txt +++ b/runtime/src/iree/base/CMakeLists.txt @@ -148,6 +148,15 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_fuzz( + NAME + string_view_fuzz + SRCS + "string_view_fuzz.cc" + DEPS + ::base +) + iree_cc_test( NAME string_view_test diff --git a/runtime/src/iree/base/internal/BUILD.bazel b/runtime/src/iree/base/internal/BUILD.bazel index 5d0b2a47c068..fde8859615a7 100644 --- a/runtime/src/iree/base/internal/BUILD.bazel +++ b/runtime/src/iree/base/internal/BUILD.bazel @@ -8,7 +8,7 @@ # These are not part of the IREE API. Though they may be used by external # projects their API may change at any time. 
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_binary", "iree_runtime_cc_library", "iree_runtime_cc_test") +load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_binary", "iree_runtime_cc_fuzz", "iree_runtime_cc_library", "iree_runtime_cc_test") load("//build_tools/bazel:cc_binary_benchmark.bzl", "cc_binary_benchmark") load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") @@ -367,6 +367,14 @@ iree_runtime_cc_library( ], ) +iree_runtime_cc_fuzz( + name = "unicode_fuzz", + srcs = ["unicode_fuzz.cc"], + deps = [ + ":unicode", + ], +) + iree_runtime_cc_test( name = "unicode_test", srcs = ["unicode_test.cc"], diff --git a/runtime/src/iree/base/internal/CMakeLists.txt b/runtime/src/iree/base/internal/CMakeLists.txt index 1d2e5ae15790..41daa557021b 100644 --- a/runtime/src/iree/base/internal/CMakeLists.txt +++ b/runtime/src/iree/base/internal/CMakeLists.txt @@ -395,6 +395,15 @@ iree_cc_library( PUBLIC ) +iree_cc_fuzz( + NAME + unicode_fuzz + SRCS + "unicode_fuzz.cc" + DEPS + ::unicode +) + iree_cc_test( NAME unicode_test diff --git a/runtime/src/iree/base/internal/unicode_fuzz.cc b/runtime/src/iree/base/internal/unicode_fuzz.cc new file mode 100644 index 000000000000..69e02147fd04 --- /dev/null +++ b/runtime/src/iree/base/internal/unicode_fuzz.cc @@ -0,0 +1,217 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Fuzz target for Unicode utilities: UTF-8 decoding/validation, category +// classification, case folding, and composition. Includes invariant assertions +// that crash on consistency violations. +// +// See https://iree.dev/developers/debugging/fuzzing/ for build and run info. 
+ +#include +#include + +#include "iree/base/internal/unicode.h" + +// Invariant assertion that crashes on failure. +// We use __builtin_trap() to get a clean crash for the fuzzer to detect. +#define FUZZ_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + __builtin_trap(); \ + } \ + } while (0) + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + iree_string_view_t input = + iree_make_string_view(reinterpret_cast(data), size); + + // Test UTF-8 validation and counting. + (void)iree_unicode_utf8_validate(input); + (void)iree_unicode_utf8_codepoint_count(input); + + // Test incomplete tail detection. + (void)iree_unicode_utf8_incomplete_tail_length(input.data, input.size); + + // Decode all codepoints and test classification/transformation functions. + iree_host_size_t position = 0; + while (position < input.size) { + uint32_t codepoint = iree_unicode_utf8_decode(input, &position); + + // Category classification. + (void)iree_unicode_category(codepoint); + (void)iree_unicode_is_letter(codepoint); + (void)iree_unicode_is_mark(codepoint); + (void)iree_unicode_is_number(codepoint); + (void)iree_unicode_is_punctuation(codepoint); + (void)iree_unicode_is_symbol(codepoint); + (void)iree_unicode_is_separator(codepoint); + (void)iree_unicode_is_other(codepoint); + (void)iree_unicode_is_whitespace(codepoint); + (void)iree_unicode_is_control(codepoint); + (void)iree_unicode_is_cjk(codepoint); + (void)iree_unicode_is_hiragana(codepoint); + (void)iree_unicode_is_katakana(codepoint); + (void)iree_unicode_is_hangul(codepoint); + + // Case folding. + (void)iree_unicode_to_lower(codepoint); + (void)iree_unicode_to_upper(codepoint); + + // NFD decomposition. + (void)iree_unicode_nfd_base(codepoint); + + // Canonical Combining Class. + (void)iree_unicode_ccc(codepoint); + + // UTF-8 encoding (roundtrip test). 
+ char encode_buffer[4]; + (void)iree_unicode_utf8_encode(codepoint, encode_buffer); + (void)iree_unicode_utf8_encoded_length(codepoint); + } + + //===--------------------------------------------------------------------===// + // Direct codepoint testing (raw byte interpretation) + //===--------------------------------------------------------------------===// + // Interpret every 4 bytes as a raw uint32_t codepoint to test the full + // codepoint space including invalid ranges (>0x10FFFF, surrogates). + // This exercises table lookup binary search with boundary values. + for (size_t i = 0; i + 4 <= size; i += 4) { + uint32_t codepoint = (static_cast(data[i]) << 24) | + (static_cast(data[i + 1]) << 16) | + (static_cast(data[i + 2]) << 8) | + static_cast(data[i + 3]); + + // Test all classification functions on arbitrary codepoint values. + iree_unicode_category_t category = iree_unicode_category(codepoint); + bool is_letter = iree_unicode_is_letter(codepoint); + bool is_mark = iree_unicode_is_mark(codepoint); + bool is_number = iree_unicode_is_number(codepoint); + bool is_punctuation = iree_unicode_is_punctuation(codepoint); + bool is_symbol = iree_unicode_is_symbol(codepoint); + bool is_separator = iree_unicode_is_separator(codepoint); + bool is_other = iree_unicode_is_other(codepoint); + (void)iree_unicode_is_whitespace(codepoint); + (void)iree_unicode_is_control(codepoint); + (void)iree_unicode_is_cjk(codepoint); + (void)iree_unicode_is_hiragana(codepoint); + (void)iree_unicode_is_katakana(codepoint); + (void)iree_unicode_is_hangul(codepoint); + + // Invariant: category classification consistency. + // If is_X returns true, the corresponding category bit must be set. 
+ if (is_letter) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_LETTER) != 0); + } + if (is_mark) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_MARK) != 0); + } + if (is_number) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_NUMBER) != 0); + } + if (is_punctuation) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_PUNCTUATION) != 0); + } + if (is_symbol) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_SYMBOL) != 0); + } + if (is_separator) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_SEPARATOR) != 0); + } + if (is_other) { + FUZZ_ASSERT((category & IREE_UNICODE_CATEGORY_OTHER) != 0); + } + + // Test case folding and NFD. + uint32_t lower = iree_unicode_to_lower(codepoint); + uint32_t upper = iree_unicode_to_upper(codepoint); + uint32_t nfd = iree_unicode_nfd_base(codepoint); + (void)iree_unicode_ccc(codepoint); + + // Invariant: case folding idempotency. + // Applying the same case operation twice should yield the same result. + FUZZ_ASSERT(iree_unicode_to_lower(lower) == lower); + FUZZ_ASSERT(iree_unicode_to_upper(upper) == upper); + + // Note: NFD decomposition may be multi-level (e.g., ẳ → ạ → a), so + // nfd_base is NOT necessarily idempotent. Instead, verify it converges + // to a fixed point within a reasonable number of steps. + uint32_t nfd_current = nfd; + for (int depth = 0; depth < 10; ++depth) { + uint32_t nfd_next = iree_unicode_nfd_base(nfd_current); + if (nfd_next == nfd_current) break; // Reached fixed point. + nfd_current = nfd_next; + } + // After at most 10 iterations, we must have reached a fixed point. + FUZZ_ASSERT(iree_unicode_nfd_base(nfd_current) == nfd_current); + + // Invariant: encode/decode roundtrip for valid codepoints. + int encoded_length = iree_unicode_utf8_encoded_length(codepoint); + if (encoded_length > 0) { + char encode_buffer[4]; + int actual_length = iree_unicode_utf8_encode(codepoint, encode_buffer); + + // Invariant: encoded_length and encode must agree. 
+ FUZZ_ASSERT(encoded_length == actual_length); + + // Decode what we just encoded and verify roundtrip. + iree_string_view_t encoded = iree_make_string_view( + encode_buffer, static_cast(actual_length)); + iree_host_size_t decode_position = 0; + uint32_t decoded = iree_unicode_utf8_decode(encoded, &decode_position); + + // Invariant: roundtrip must recover the original codepoint. + FUZZ_ASSERT(decoded == codepoint); + FUZZ_ASSERT(decode_position == + static_cast(actual_length)); + } + } + + //===--------------------------------------------------------------------===// + // Composition testing with status verification + //===--------------------------------------------------------------------===// + // Test composition on valid UTF-8 sequences, verifying status codes. + if (iree_unicode_utf8_validate(input)) { + // Allocate output buffer (composition can only shrink). + char* compose_buffer = new char[size + 1]; + iree_host_size_t out_length = 0; + iree_status_t status = + iree_unicode_compose(input, compose_buffer, size + 1, &out_length); + + // Status must be OK or RESOURCE_EXHAUSTED (for very long combining seqs). + // Any other status indicates a bug in the compose function. + FUZZ_ASSERT(iree_status_is_ok(status) || + iree_status_code(status) == IREE_STATUS_RESOURCE_EXHAUSTED); + + if (iree_status_is_ok(status)) { + // Invariant: output length must not exceed input length. + // Composition can only shrink (combining base + mark -> precomposed). + FUZZ_ASSERT(out_length <= input.size); + + // The output should also be valid UTF-8. + iree_string_view_t output = + iree_make_string_view(compose_buffer, out_length); + FUZZ_ASSERT(iree_unicode_utf8_validate(output)); + } else { + iree_status_ignore(status); + } + delete[] compose_buffer; + } + + // Test pairwise composition with interpreted codepoints. 
+ if (size >= 8) { + uint32_t base = (static_cast(data[0]) << 24) | + (static_cast(data[1]) << 16) | + (static_cast(data[2]) << 8) | + static_cast(data[3]); + uint32_t combining = (static_cast(data[4]) << 24) | + (static_cast(data[5]) << 16) | + (static_cast(data[6]) << 8) | + static_cast(data[7]); + (void)iree_unicode_compose_pair(base, combining); + } + + return 0; +} diff --git a/runtime/src/iree/base/string_view.c b/runtime/src/iree/base/string_view.c index 583b58b0b53b..349196c125bf 100644 --- a/runtime/src/iree/base/string_view.c +++ b/runtime/src/iree/base/string_view.c @@ -96,7 +96,7 @@ IREE_API_EXPORT iree_host_size_t iree_string_view_find_last_of( for (iree_host_size_t i = 0; i < s.size; ++i) { lookup_table[(uint8_t)s.data[i]] = true; } - pos = iree_min(pos, value.size) + 1; + pos = iree_min(pos, value.size - 1) + 1; iree_host_size_t i = pos; while (i != 0) { --i; @@ -261,28 +261,81 @@ static bool iree_string_view_match_pattern_impl(iree_string_view_t value, return true; } char pattern_char = pattern.data[0]; - if (pattern_char == '*' && pattern.size > 1 && - iree_string_view_is_empty(value)) { + + // Normalize wildcard sequences to avoid exponential backtracking. + // A sequence like *?*?* is equivalent to "match 2+ chars then match rest". + // We coalesce all * and ? into a single * with a minimum char requirement. + if (pattern_char == '*' || pattern_char == '?') { + iree_host_size_t min_chars = 0; + iree_host_size_t skip = 0; + bool has_star = false; + while (skip < pattern.size) { + char c = pattern.data[skip]; + if (c == '*') { + has_star = true; + ++skip; + } else if (c == '?') { + ++min_chars; + ++skip; + } else { + break; + } + } + + // Remaining pattern after wildcards. + iree_string_view_t rest = + iree_string_view_substr(pattern, skip, IREE_STRING_VIEW_NPOS); + + if (!has_star) { + // Only ? wildcards - must match exactly min_chars characters. 
+ if (value.size < min_chars) return false; + return iree_string_view_match_pattern_impl( + iree_string_view_substr(value, min_chars, IREE_STRING_VIEW_NPOS), + rest); + } + + // Has * - must match at least min_chars, possibly more. + if (value.size < min_chars) return false; + + // Empty rest means * matches everything remaining. + if (iree_string_view_is_empty(rest)) return true; + + // Try matching rest at each position from min_chars to end. + for (iree_host_size_t i = min_chars; i <= value.size; ++i) { + if (iree_string_view_match_pattern_impl( + iree_string_view_substr(value, i, IREE_STRING_VIEW_NPOS), rest)) { + return true; + } + } return false; - } else if (pattern_char == '*' && pattern.size == 1) { - return true; - } else if (pattern_char == '?' || value.data[0] == pattern_char) { - return iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)); - } else if (pattern_char == '*') { - return iree_string_view_match_pattern_impl( - value, - iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)) || - iree_string_view_match_pattern_impl( - iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), - pattern); } - return false; + + // Literal character - must match exactly. + if (iree_string_view_is_empty(value) || value.data[0] != pattern_char) { + return false; + } + return iree_string_view_match_pattern_impl( + iree_string_view_substr(value, 1, IREE_STRING_VIEW_NPOS), + iree_string_view_substr(pattern, 1, IREE_STRING_VIEW_NPOS)); } +// Maximum wildcards allowed in a pattern to prevent pathological matching. +// 16 is enough for any reasonable glob (e.g., "*foo*bar*baz*") while avoiding +// O(n^2) blowup on patterns like "?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*?*". 
+#define IREE_STRING_VIEW_MAX_PATTERN_WILDCARDS 16 + IREE_API_EXPORT bool iree_string_view_match_pattern( iree_string_view_t value, iree_string_view_t pattern) { + // Count wildcards and reject patterns with too many. + iree_host_size_t wildcard_count = 0; + for (iree_host_size_t i = 0; i < pattern.size; ++i) { + if (pattern.data[i] == '*' || pattern.data[i] == '?') { + ++wildcard_count; + } + } + if (wildcard_count > IREE_STRING_VIEW_MAX_PATTERN_WILDCARDS) { + return false; + } return iree_string_view_match_pattern_impl(value, pattern); } diff --git a/runtime/src/iree/base/string_view_fuzz.cc b/runtime/src/iree/base/string_view_fuzz.cc new file mode 100644 index 000000000000..43841b633d27 --- /dev/null +++ b/runtime/src/iree/base/string_view_fuzz.cc @@ -0,0 +1,185 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Fuzz target for string parsing utilities: integer/float parsing, device size +// parsing with units, bitfield parsing, pattern matching, and hex byte parsing. +// +// See https://iree.dev/developers/debugging/fuzzing/ for build and run info. + +#include +#include + +#include "iree/base/api.h" + +// Sample bitfield mapping table for fuzzing iree_bitfield_parse. +// Uses realistic flag names similar to actual IREE usage. +static const iree_bitfield_string_mapping_t kTestBitfieldMappings[] = { + {0x7, IREE_SVL("ALL")}, // Combined flag (A|B|C). + {0x1, IREE_SVL("READ")}, // Bit 0. + {0x2, IREE_SVL("WRITE")}, // Bit 1. + {0x4, IREE_SVL("EXECUTE")}, // Bit 2. + {0x8, IREE_SVL("DISCARD")}, // Bit 3. + {0x10, IREE_SVL("MAPPABLE")}, // Bit 4. + {0x20, IREE_SVL("COHERENT")}, // Bit 5. 
+}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + iree_string_view_t input = + iree_make_string_view(reinterpret_cast(data), size); + + //===--------------------------------------------------------------------===// + // Integer parsing (signed and unsigned, various bases) + //===--------------------------------------------------------------------===// + + { + int32_t value_i32 = 0; + (void)iree_string_view_atoi_int32(input, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 10, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 16, &value_i32); + (void)iree_string_view_atoi_int32_base(input, 2, &value_i32); + } + + { + uint32_t value_u32 = 0; + (void)iree_string_view_atoi_uint32(input, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 10, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 16, &value_u32); + (void)iree_string_view_atoi_uint32_base(input, 2, &value_u32); + } + + { + int64_t value_i64 = 0; + (void)iree_string_view_atoi_int64(input, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 10, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 16, &value_i64); + (void)iree_string_view_atoi_int64_base(input, 2, &value_i64); + } + + { + uint64_t value_u64 = 0; + (void)iree_string_view_atoi_uint64(input, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 10, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 16, &value_u64); + (void)iree_string_view_atoi_uint64_base(input, 2, &value_u64); + } + + //===--------------------------------------------------------------------===// + // Floating point parsing + //===--------------------------------------------------------------------===// + + { + float value_f32 = 0.0f; + (void)iree_string_view_atof(input, &value_f32); + } + + { + double value_f64 = 0.0; + (void)iree_string_view_atod(input, &value_f64); + } + + //===--------------------------------------------------------------------===// + // Device 
size parsing with units (e.g., "1kb", "2mib", "3gb") + //===--------------------------------------------------------------------===// + + { + iree_device_size_t device_size = 0; + iree_status_t status = + iree_string_view_parse_device_size(input, &device_size); + iree_status_ignore(status); + } + + //===--------------------------------------------------------------------===// + // Bitfield parsing + //===--------------------------------------------------------------------===// + + { + uint32_t bitfield_value = 0; + iree_status_t status = + iree_bitfield_parse(input, IREE_ARRAYSIZE(kTestBitfieldMappings), + kTestBitfieldMappings, &bitfield_value); + iree_status_ignore(status); + } + + //===--------------------------------------------------------------------===// + // Pattern matching (wildcard patterns with * and ?) + //===--------------------------------------------------------------------===// + + // Use the first half as value and second half as pattern. + if (size >= 2) { + size_t mid = size / 2; + iree_string_view_t value = + iree_make_string_view(reinterpret_cast(data), mid); + iree_string_view_t pattern = iree_make_string_view( + reinterpret_cast(data + mid), size - mid); + (void)iree_string_view_match_pattern(value, pattern); + } + + // Also test pattern matching with specific patterns that stress recursion. + (void)iree_string_view_match_pattern(input, IREE_SV("*")); + (void)iree_string_view_match_pattern(input, IREE_SV("?*?")); + (void)iree_string_view_match_pattern(input, IREE_SV("***")); + + //===--------------------------------------------------------------------===// + // Hex byte parsing + //===--------------------------------------------------------------------===// + + // Parse up to 64 bytes of hex data. + { + uint8_t hex_buffer[64] = {0}; + (void)iree_string_view_parse_hex_bytes(input, sizeof(hex_buffer), + hex_buffer); + } + + // Try parsing various sizes to test boundary conditions. 
+ for (size_t parse_size = 1; parse_size <= 8; ++parse_size) { + uint8_t small_buffer[8] = {0}; + (void)iree_string_view_parse_hex_bytes(input, parse_size, small_buffer); + } + + //===--------------------------------------------------------------------===// + // String view operations that process the data + //===--------------------------------------------------------------------===// + + (void)iree_string_view_trim(input); + + // Split operations with various split characters. + { + iree_string_view_t lhs, rhs; + (void)iree_string_view_split(input, '|', &lhs, &rhs); + (void)iree_string_view_split(input, '=', &lhs, &rhs); + (void)iree_string_view_split(input, ',', &lhs, &rhs); + (void)iree_string_view_split(input, ':', &lhs, &rhs); + } + + // Find operations. + if (size > 0) { + char search_char = static_cast(data[0]); + (void)iree_string_view_find_char(input, search_char, 0); + + if (size > 1) { + iree_string_view_t search_set = + iree_make_string_view(reinterpret_cast(data), size / 2); + (void)iree_string_view_find_first_of(input, search_set, 0); + (void)iree_string_view_find_last_of(input, search_set, SIZE_MAX); + } + } + + // Comparison operations. 
+ if (size >= 2) { + size_t mid = size / 2; + iree_string_view_t left = + iree_make_string_view(reinterpret_cast(data), mid); + iree_string_view_t right = iree_make_string_view( + reinterpret_cast(data + mid), size - mid); + (void)iree_string_view_equal(left, right); + (void)iree_string_view_equal_case(left, right); + (void)iree_string_view_compare(left, right); + (void)iree_string_view_starts_with(left, right); + (void)iree_string_view_ends_with(left, right); + } + + return 0; +} diff --git a/runtime/src/iree/base/string_view_test.cc b/runtime/src/iree/base/string_view_test.cc index de623aa7495f..890a90101f9c 100644 --- a/runtime/src/iree/base/string_view_test.cc +++ b/runtime/src/iree/base/string_view_test.cc @@ -670,4 +670,80 @@ TEST(StringViewTest, ParseDeviceSizeInvalid) { EXPECT_THAT(ParseDeviceSize("abc"), StatusIs(StatusCode::kInvalidArgument)); } +TEST(StringViewTest, MatchPattern) { + auto match = [](const char* value, const char* pattern) -> bool { + return iree_string_view_match_pattern(iree_make_cstring_view(value), + iree_make_cstring_view(pattern)); + }; + + // Empty patterns and values. + EXPECT_TRUE(match("", "")); + EXPECT_FALSE(match("a", "")); + EXPECT_FALSE(match("", "a")); + + // Exact matches. + EXPECT_TRUE(match("abc", "abc")); + EXPECT_FALSE(match("abc", "abd")); + EXPECT_FALSE(match("abc", "ab")); + EXPECT_FALSE(match("ab", "abc")); + + // Single character wildcard (?). + EXPECT_TRUE(match("a", "?")); + EXPECT_TRUE(match("abc", "a?c")); + EXPECT_TRUE(match("abc", "???")); + EXPECT_FALSE(match("ab", "???")); + EXPECT_FALSE(match("abcd", "???")); + + // Multi-character wildcard (*). + EXPECT_TRUE(match("", "*")); + EXPECT_TRUE(match("a", "*")); + EXPECT_TRUE(match("abc", "*")); + EXPECT_TRUE(match("abc", "a*")); + EXPECT_TRUE(match("abc", "*c")); + EXPECT_TRUE(match("abc", "a*c")); + EXPECT_TRUE(match("abxyzc", "a*c")); + EXPECT_FALSE(match("abc", "a*d")); + + // Combined wildcards. 
+ EXPECT_TRUE(match("abc", "?*")); + EXPECT_TRUE(match("abc", "*?")); + EXPECT_TRUE(match("abc", "?*?")); + EXPECT_TRUE(match("abcdef", "a?c*f")); + + // Consecutive wildcards (tests coalescing to avoid exponential backtracking). + EXPECT_TRUE(match("abc", "**")); + EXPECT_TRUE(match("abc", "***")); + EXPECT_TRUE(match("abc", "a**c")); + EXPECT_TRUE(match("abc", "**c")); + EXPECT_TRUE(match("abc", "a**")); + + // Pathological pattern that would cause exponential backtracking without + // coalescing: many wildcards followed by a non-matching suffix. + // This must complete in reasonable time (milliseconds, not seconds). + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "**************c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "**************b")); + + // Alternating ?* patterns - also pathological without normalization. + // ?* means "1 or more chars", ?*?* means "2 or more chars", etc. + EXPECT_TRUE(match("ab", "?*")); + EXPECT_TRUE(match("abc", "?*?")); + EXPECT_TRUE(match("abc", "?*?*")); + EXPECT_TRUE(match("abcd", "?*?*")); + EXPECT_FALSE(match("a", "?*?*")); // Need at least 2 chars. + + // Pathological alternating patterns - must complete quickly. + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "?*?*?*?*?*?*?*c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "?*?*?*?*?*?*?*b")); + EXPECT_FALSE(match("aaaaaaaaaaaaaaaaaaaab", "*?*?*?*?*?*?*?c")); + EXPECT_TRUE(match("aaaaaaaaaaaaaaaaaaaab", "*?*?*?*?*?*?*?b")); + + // Patterns with too many wildcards are rejected (returns false). + // Limit is 16 wildcards to prevent O(n^2) blowup. 
+ EXPECT_TRUE(match("abcdefghijklmnop", "????????????????")); // 16 - ok + EXPECT_FALSE( + match("abcdefghijklmnopq", "?????????????????")); // 17 - rejected + EXPECT_FALSE( + match("anything", "?*?*?*?*?*?*?*?*?*")); // 18 wildcards - rejected +} + } // namespace From bde140bc30f75819632df591989b22307f875a92 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 14 Jan 2026 11:17:47 -0800 Subject: [PATCH 38/71] [Docs] Add `iree-fusilli-write` team to contributing doc (#23107) As titled. https://github.com/orgs/iree-org/teams/iree-fusilli-write Signed-off-by: Sambhav Jain --- docs/website/docs/developers/general/contributing.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/website/docs/developers/general/contributing.md b/docs/website/docs/developers/general/contributing.md index 74b82047832c..069b5b76a5c3 100644 --- a/docs/website/docs/developers/general/contributing.md +++ b/docs/website/docs/developers/general/contributing.md @@ -278,7 +278,7 @@ Access to repositories is divided into tiers following the | Tier | Description | Team links | | ---- | ----------- | --------- | Triage | **New project members should typically start here**
:material-check: Can be [assigned issues](https://docs.github.com/en/issues/tracking-your-work-with-issues/assigning-issues-and-pull-requests-to-other-github-users)
:material-check: Can apply labels to issues / PRs
:material-check: Can run workflows [without approval](https://docs.github.com/en/actions/managing-workflow-runs/approving-workflow-runs-from-public-forks) |
  • [iree-triage](https://github.com/orgs/iree-org/teams/iree-triage)
    (access to most repositories)
-Write | **Established contributors can request this access**
:material-check: Can [merge approved pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request)
:material-check: Can create branches
:material-check: Can [re-run workflows](https://docs.github.com/en/actions/managing-workflow-runs-and-deployments/managing-workflow-runs/re-running-workflows-and-jobs) |
+Write | **Established contributors can request this access**
:material-check: Can [merge approved pull requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/merging-a-pull-request)
:material-check: Can create branches
:material-check: Can [re-run workflows](https://docs.github.com/en/actions/managing-workflow-runs-and-deployments/managing-workflow-runs/re-running-workflows-and-jobs) |
  • [iree-write](https://github.com/orgs/iree-org/teams/iree-write)
    (access to most repositories)
  • [iree-turbine-write](https://github.com/orgs/iree-org/teams/iree-turbine-write)
    (access to iree-turbine)
  • [iree-fusilli-write](https://github.com/orgs/iree-org/teams/iree-fusilli-write)
    (access to fusilli)
Maintain/Admin | :material-check: Can [edit repository settings](https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features)
:material-check: Can push to [protected branches](https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/managing-protected-branches/about-protected-branches) | Added case-by-case All access tiers first require joining the From 5a686811db2eeadaff5573c37e1e8fa12f819aa7 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:35:31 -0600 Subject: [PATCH 39/71] [GPU] Remove slice guard when doing pad producer fusion in FuseAndHoist (#23126) Since we are doing structured code generation we are able to handle zero slices appropriately and don't need slice guards. Fixes : https://github.com/iree-org/iree/issues/23028 Signed-off-by: Nirvedh Meshram --- .../GPU/GPUFuseAndHoistParallelLoops.cpp | 10 +++++- .../GPU/test/gpu_fuse_and_hoist_forall.mlir | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp index 4a8dfffa10c5..c5f3e9c5c2de 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp @@ -266,7 +266,9 @@ struct FuseTilableSliceProducers final return failure(); } auto tilableProducer = sliceOp.getSource().getDefiningOp(); - if (!tilableProducer) { + // Pad fusion is handled separately as we dont want zero slice guards that + // happen by default. + if (!tilableProducer || isa(tilableProducer)) { return failure(); } @@ -394,6 +396,12 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { patterns.add(context); tensor::populateFoldTensorEmptyPatterns(patterns); scf::ForallOp::getCanonicalizationPatterns(patterns, context); + auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional { + // Do not use zero slice gaurd. 
+ return false; + }; + patterns.add(context, + zeroSliceGuard); if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 3)"); return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir index deee7df12d3b..768ed232df8d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir @@ -875,3 +875,37 @@ func.func @fuse_warp_and_lane_foralls_with_coalesced_dma(%src: tensor<2x2x64xf32 // CHECK: } // CHECK: } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} // CHECK: return %[[THREAD_FORALL]] + + +// ----- +// Check that we dont make a zeroslice guard when fusing pad. +#map = affine_map<(d0) -> (d0 * 64)> +func.func @fuse_pad(%arg0: tensor, %arg1: index) -> tensor<128xf16> { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<128xf16> + %padded = tensor.pad %arg0 low[0] high[%arg1] { + ^bb0(%arg10: index): + tensor.yield %cst : f16 + } : tensor to tensor<128xf16> + %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %0) -> (tensor<128xf16>) { + %2 = affine.apply #map(%arg2) + %extracted_slice = tensor.extract_slice %padded[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> + } + } {mapping = [#gpu.thread]} + return %1 : tensor<128xf16> +} + +// CHECK-LABEL: func @fuse_pad 
+// CHECK: scf.forall +// CHECK-NOT: scf.if +// CHECK: tensor.pad +// CHECK: linalg.copy +// CHECK: scf.forall.in_parallel +// CHECK: return From ff055cce6ba33d38015a4832e96c2710fb39a183 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 15 Jan 2026 13:09:55 +0530 Subject: [PATCH 40/71] [LinalgExt] Add map_gather e2e tests for CPU/VMVX/ROCM (#23124) It attaches PartitionableLoops interfaces to the op and adds e2e tests for map_gather op on CPU/VMVX/ROCM. No additional lit tests are added because they are templated and identical to other LinalgExt ops. Adding lit tests do not provide additional value because all the code paths are covered via other LinalgExt op tests. Signed-off-by: Abhishek Varma --- .../PartitionableLoopsInterface.cpp | 2 + tests/e2e/linalg_ext_ops/BUILD.bazel | 6 +++ tests/e2e/linalg_ext_ops/CMakeLists.txt | 3 ++ tests/e2e/linalg_ext_ops/map_gather.mlir | 54 +++++++++++++++++++ 4 files changed, 65 insertions(+) create mode 100644 tests/e2e/linalg_ext_ops/map_gather.mlir diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp index a9f9c229e76a..cbb37d3c4d2e 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp @@ -300,6 +300,8 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry ®istry) { *ctx); IREE::LinalgExt::MapScatterOp::attachInterface< AllParallelAsPartitionableLoops>(*ctx); + IREE::LinalgExt::MapGatherOp::attachInterface< + AllParallelAsPartitionableLoops>(*ctx); }); registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { tensor::PadOp::attachInterface< diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel index 0fa227cc29c1..38442ae241e3 100644 --- a/tests/e2e/linalg_ext_ops/BUILD.bazel +++ b/tests/e2e/linalg_ext_ops/BUILD.bazel @@ 
-19,6 +19,7 @@ ALL_SRCS = enforce_glob( "attention.mlir", "attention_i1_mask_encoding.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -69,6 +70,7 @@ VMVX_SRCS = enforce_glob( [ "arg_compare.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -109,6 +111,7 @@ LLVM_GPU_SRCS = enforce_glob( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", + "map_gather.mlir", "map_scatter.mlir", ], ) @@ -134,6 +137,7 @@ ROCM_HIP_SRCS = enforce_glob( [ "arg_compare.mlir", "gather.mlir", + "map_gather.mlir", "map_scatter.mlir", "scan.mlir", "scatter.mlir", @@ -175,6 +179,7 @@ iree_check_single_backend_test_suite( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", + "map_gather.mlir", "map_scatter.mlir", "top-k.mlir", ], @@ -201,6 +206,7 @@ iree_check_single_backend_test_suite( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", + "map_gather.mlir", "map_scatter.mlir", "top-k.mlir", ], diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt index 8dcf9032dd93..ce5279321bc7 100644 --- a/tests/e2e/linalg_ext_ops/CMakeLists.txt +++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt @@ -18,6 +18,7 @@ iree_check_single_backend_test_suite( "attention.mlir" "attention_i1_mask_encoding.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" @@ -57,6 +58,7 @@ iree_check_single_backend_test_suite( SRCS "arg_compare.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" @@ -100,6 +102,7 @@ iree_check_single_backend_test_suite( SRCS "arg_compare.mlir" "gather.mlir" + "map_gather.mlir" "map_scatter.mlir" "scan.mlir" "scatter.mlir" diff --git a/tests/e2e/linalg_ext_ops/map_gather.mlir b/tests/e2e/linalg_ext_ops/map_gather.mlir new file mode 100644 index 000000000000..96a7b27d09de --- /dev/null +++ b/tests/e2e/linalg_ext_ops/map_gather.mlir 
@@ -0,0 +1,54 @@ +func.func @copy_like() { + %source = util.unfoldable_constant dense<123.0> : tensor<4x16x64xf32> + %output = tensor.empty() : tensor<4x16x64xf32> + %padding = arith.constant 0.0 : f32 + %0 = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index, %idx1: index, %idx2: index): + iree_linalg_ext.yield %idx0, %idx1, %idx2, %padding : index, index, index, f32 + } : tensor<4x16x64xf32> into tensor<4x16x64xf32> -> tensor<4x16x64xf32> + check.expect_almost_eq(%0, %source) : tensor<4x16x64xf32> + return +} + +func.func @expand_shape_like() { + %source = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : tensor<16xf32> + %padding = arith.constant 0.0 : f32 + %output = tensor.empty() : tensor<4x4xf32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index, %idx1: index): + %linear = affine.linearize_index disjoint [%idx0, %idx1] by (4, 4) : index + iree_linalg_ext.yield %linear, %padding : index, f32 + } : tensor<16xf32> into tensor<4x4xf32> -> tensor<4x4xf32> + %expected = tensor.expand_shape %source [[0, 1]] output_shape [4, 4] : tensor<16xf32> into tensor<4x4xf32> + check.expect_almost_eq(%result, %expected) : tensor<4x4xf32> + return +} + +func.func @collapse_shape_like() { + %source = util.unfoldable_constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi32> + %padding = arith.constant 0 : i32 + %output = tensor.empty() : tensor<16xi32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + %2:2 = affine.delinearize_index %idx0 into (4, 4) : index, index + iree_linalg_ext.yield %2#0, %2#1, %padding : index, index, i32 + } : tensor<4x4xi32> into tensor<16xi32> -> tensor<16xi32> + %expected = tensor.collapse_shape %source [[0, 1]] : tensor<4x4xi32> into tensor<16xi32> + check.expect_eq(%result, %expected) : tensor<16xi32> + return +} + +func.func @pad_slice_like() { + // Source 
is 4 elements, output is 8 elements (with padding for out-of-bounds) + %source = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %padding = arith.constant 0.0 : f32 + %output = tensor.empty() : tensor<8xf32> + %result = iree_linalg_ext.map_gather %source into %output { + ^bb0(%idx0: index): + // Identity mapping - indices 0-3 are in-bounds, 4-7 get padding + iree_linalg_ext.yield %idx0, %padding : index, f32 + } : tensor<4xf32> into tensor<8xf32> -> tensor<8xf32> + %expected = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]> : tensor<8xf32> + check.expect_almost_eq(%result, %expected) : tensor<8xf32> + return +} From 511ce88d60166040565be5d76c22f8310a01beaf Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Thu, 15 Jan 2026 13:43:39 +0000 Subject: [PATCH 41/71] Revert "[GPU] Remove slice guard when doing pad producer fusion in FuseAndHoist" (#23132) Reverts iree-org/iree#23126 Fails post submit CI ci-extra: test_torch --- .../GPU/GPUFuseAndHoistParallelLoops.cpp | 10 +----- .../GPU/test/gpu_fuse_and_hoist_forall.mlir | 34 ------------------- 2 files changed, 1 insertion(+), 43 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp index c5f3e9c5c2de..4a8dfffa10c5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp @@ -266,9 +266,7 @@ struct FuseTilableSliceProducers final return failure(); } auto tilableProducer = sliceOp.getSource().getDefiningOp(); - // Pad fusion is handled separately as we dont want zero slice guards that - // happen by default. 
- if (!tilableProducer || isa(tilableProducer)) { + if (!tilableProducer) { return failure(); } @@ -396,12 +394,6 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { patterns.add(context); tensor::populateFoldTensorEmptyPatterns(patterns); scf::ForallOp::getCanonicalizationPatterns(patterns, context); - auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional { - // Do not use zero slice gaurd. - return false; - }; - patterns.add(context, - zeroSliceGuard); if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 3)"); return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir index 768ed232df8d..deee7df12d3b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir @@ -875,37 +875,3 @@ func.func @fuse_warp_and_lane_foralls_with_coalesced_dma(%src: tensor<2x2x64xf32 // CHECK: } // CHECK: } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} // CHECK: return %[[THREAD_FORALL]] - - -// ----- -// Check that we dont make a zeroslice guard when fusing pad. 
-#map = affine_map<(d0) -> (d0 * 64)> -func.func @fuse_pad(%arg0: tensor, %arg1: index) -> tensor<128xf16> { - %c4 = arith.constant 4 : index - %c128 = arith.constant 128 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f16 - %0 = tensor.empty() : tensor<128xf16> - %padded = tensor.pad %arg0 low[0] high[%arg1] { - ^bb0(%arg10: index): - tensor.yield %cst : f16 - } : tensor to tensor<128xf16> - %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %0) -> (tensor<128xf16>) { - %2 = affine.apply #map(%arg2) - %extracted_slice = tensor.extract_slice %padded[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> - %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> - %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> - scf.forall.in_parallel { - tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> - } - } {mapping = [#gpu.thread]} - return %1 : tensor<128xf16> -} - -// CHECK-LABEL: func @fuse_pad -// CHECK: scf.forall -// CHECK-NOT: scf.if -// CHECK: tensor.pad -// CHECK: linalg.copy -// CHECK: scf.forall.in_parallel -// CHECK: return From 9e2efd452f7446c8c6d2a8c612eb6fd7986ea8fe Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Thu, 15 Jan 2026 11:03:07 -0500 Subject: [PATCH 42/71] [Codegen] Implement reshape propagation by expansion for inner_tiled (#23118) Adds support for propagating tensor.collapse_shape and tensor.expand_shape operations through iree_codegen.inner_tiled ops. This enables reshape fusion to work with GPU MMA operations that use the inner_tiled abstraction. Two patterns are introduced: - FoldProducerCollapseShapeWithInnerTiled: Propagates collapse_shape through inner_tiled by expanding the operation and inserting a collapse on the result. 
- FoldConsumerExpandShapeWithInnerTiled: Propagates expand_shape back through inner_tiled by expanding all operands. Only outer (iteration) dimensions can be reshaped; inner dimensions that depend on the MMA layout are preserved. Also adds the patterns to the BlockDynamicDimensions pass and the PropagateReshapesByExpansion pass. --------- Signed-off-by: Max Dawkins --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../Codegen/Common/BlockDynamicDimensions.cpp | 5 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Common/PropagateReshapesByExpansion.cpp | 3 + .../test/propagate_reshapes_by_expansion.mlir | 202 +++++++++++ .../Dialect/Codegen/Transforms/BUILD.bazel | 35 ++ .../Dialect/Codegen/Transforms/CMakeLists.txt | 33 ++ .../Codegen/Transforms/ReshapeFusion.cpp | 313 ++++++++++++++++++ .../Dialect/Codegen/Transforms/Transforms.h | 28 ++ 9 files changed, 621 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index e5e5806b78f9..050485d60e78 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -195,6 +195,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen", "//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms:IREECodegenTransforms", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils", 
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets", diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp index e8b66770b0bc..a4e8a0f8606b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" #include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" @@ -318,6 +319,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFusionFn); IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFusionFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns(patterns, + controlFusionFn); // Add patterns to fold `tensor.empty` operations with its consumers. tensor::populateFoldTensorEmptyPatterns(patterns); // Add some additional patterns that can simplify the IR. @@ -367,6 +370,8 @@ void BlockDynamicDimensionsPass::runOnOperation() { controlFn); IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( bubbleExpandShapePatterns, controlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operation and // "pushed-down" `tensor.collapse_shape` operation with their interface // bindings or `tensor.empty` operations. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 128b254f6a70..5ec758dba5d0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -229,6 +229,7 @@ iree_cc_library( iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + iree::compiler::Codegen::Dialect::Codegen::Transforms::IREECodegenTransforms iree::compiler::Codegen::Dialect::Codegen::Utils iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index febda9283240..20a7a3c8700b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -412,6 +413,8 @@ void PropagateReshapesByExpansionPass::runOnOperation() { }; linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, bubbleUpExpansionControlFn); + IREE::Codegen::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, bubbleUpExpansionControlFn); // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. 
tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index 70ccb68955d8..467d72156786 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -459,3 +459,205 @@ func.func @no_swap_rank_reducing_slice(%arg0: tensor<3x6xi8>) -> tensor<3xi16> { // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8> // CHECK-NEXT: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] // CHECK-NEXT: iree_tensor_ext.bitcast %[[SLICE]] + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op. +// Using proper 2D matmul indexing maps with MFMA_F32_16x16x16_F16 layout. +// Tensor shapes: LHS[outer_m, outer_k, 16, 16], RHS[outer_k, outer_n, 16, 16], ACC[outer_m, outer_n, 16, 16] +#contraction_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled( + %src: tensor<2x3x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapse the first two outer dims of LHS: [2,3] -> [6] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor<2x3x4x16x16xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled +// 
CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor<2x3x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<2x3x2x16x16xf32> into tensor<6x2x16x16xf32> +// CHECK: return %[[COLLAPSED]] + +// ----- + +// Test propagating expand_shape consumer through inner_tiled op. 
+#contraction_accesses2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled( + %lhs: tensor<6x4x16x16xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<2x3x2x16x16xf32> { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [2, 3, 2, 16, 16] : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> + return %expanded : tensor<2x3x2x16x16xf32> +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<6x4x16x16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor<6x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x2x16x16xf32> into tensor<2x3x2x16x16xf32> +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor<6x4x16x16xf16> into tensor<2x3x4x16x16xf16> +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor<2x3x4x16x16xf16>, 
tensor<4x2x16x16xf16> into tensor<2x3x2x16x16xf32> +// CHECK: return %[[INNER_TILED]] + +// ----- + +// Test that reshape touching inner dimensions is NOT propagated. +#contraction_accesses3 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @no_propagate_inner_dim_reshape( + %src: tensor<6x4x16x2x8xf16>, %rhs: tensor<4x2x16x16xf16>, %out: tensor<6x2x16x16xf32>) + -> tensor<6x2x16x16xf32> { + // Collapsing inner dims [3,4] which are part of inner tile - should NOT propagate. + %collapsed = tensor.collapse_shape %src [[0], [1], [2], [3, 4]] + : tensor<6x4x16x2x8xf16> into tensor<6x4x16x16xf16> + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses3, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor<6x4x16x16xf16>, tensor<4x2x16x16xf16> into tensor<6x2x16x16xf32> + return %result : tensor<6x2x16x16xf32> +} + +// CHECK-LABEL: func @no_propagate_inner_dim_reshape +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape +// CHECK: iree_codegen.inner_tiled ins(%[[COLLAPSED]], + +// ----- + +// Test propagating collapse_shape producer through inner_tiled op with dynamic outer shapes. 
+#contraction_accesses_dyn1 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_collapse_through_inner_tiled_dynamic( + %src: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor) + -> tensor { + // Collapse the first two outer dims of LHS: [?, 3] -> [?*3] + %collapsed = tensor.collapse_shape %src [[0, 1], [2], [3], [4]] + : tensor into tensor + %result = iree_codegen.inner_tiled ins(%collapsed, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn1, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + return %result : tensor +} + +// CHECK-LABEL: func @propagate_collapse_through_inner_tiled_dynamic +// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[SRC]], %c0 +// CHECK: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[SRC]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d2, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INNER_TILED]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: : tensor into tensor +// CHECK: return %[[COLLAPSED]] + +// ----- + +// 
Test propagating expand_shape consumer through inner_tiled op with dynamic outer shapes. +#contraction_accesses_dyn2 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +func.func @propagate_expand_through_inner_tiled_dynamic( + %lhs: tensor, %rhs: tensor<4x2x16x16xf16>, %out: tensor, + %dyn_dim: index) + -> tensor { + %result = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%out) { + indexing_maps = #contraction_accesses_dyn2, + iterator_types = [#linalg.iterator_type, #linalg.iterator_type, #linalg.iterator_type], + kind = #iree_gpu.mma_layout, + permutations = [array, array, array], + semantics = #iree_gpu.mma_semantics + } : tensor, tensor<4x2x16x16xf16> into tensor + %expanded = tensor.expand_shape %result [[0, 1], [2], [3], [4]] + output_shape [%dyn_dim, 3, 2, 16, 16] : tensor into tensor + return %expanded : tensor +} + +// CHECK-LABEL: func @propagate_expand_through_inner_tiled_dynamic +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x2x16x16xf16> +// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[DYN_DIM:[A-Za-z0-9]+]]: index +// CHECK-DAG: %[[EXPANDED_OUT:.+]] = tensor.expand_shape %[[OUT]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 2, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK-DAG: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[LHS]] {{\[}}[0, 1], [2], [3], [4]{{\]}} +// CHECK-SAME: output_shape [%[[DYN_DIM]], 3, 4, 16, 16] +// CHECK-SAME: : tensor into tensor +// CHECK: %[[INNER_TILED:.+]] = iree_codegen.inner_tiled +// CHECK-SAME: ins(%[[EXPANDED_LHS]], %[[RHS]]) +// CHECK-SAME: outs(%[[EXPANDED_OUT]]) +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d3, d2)>, +// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>] +// CHECK-SAME: iterator_types = [#linalg.iterator_type, #linalg.iterator_type, 
#linalg.iterator_type, #linalg.iterator_type] +// CHECK-SAME: : tensor, tensor<4x2x16x16xf16> into tensor +// CHECK: return %[[INNER_TILED]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel new file mode 100644 index 000000000000..69c2f3cf089d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel @@ -0,0 +1,35 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "IREECodegenTransforms", + srcs = [ + "ReshapeFusion.cpp", + ], + hdrs = [ + "Transforms.h", + ], + deps = [ + "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..05bdbb4f53f4 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/CMakeLists.txt @@ -0,0 +1,33 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# 
compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + IREECodegenTransforms + HDRS + "Transforms.h" + SRCS + "ReshapeFusion.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRSupport + MLIRTensorDialect + MLIRTransformUtils + MLIRTransforms + iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp new file mode 100644 index 000000000000..f5b99f7ce9fc --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp @@ -0,0 +1,313 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +namespace { + +/// Check if an InnerTiledOp can be expanded by propagating a reshape through +/// it. The main real condition is that the inner dimensions of the op are not +/// expanded. Otherwise, we artificially restrict to single result inner_tiled +/// ops for now. +static LogicalResult +canExpandInnerTiledOp(InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation) { + // Only single result inner_tiled ops are tested or used anywhere, so restrict + // to single result for now. + if (op->getNumResults() != 1) + return failure(); + + // Only outer dims can be expanded because inner dims depend on the `kind` + // attribute's implementation. + int64_t outerRank = + op.getIndexingMapsArray()[fusedOperand->getOperandNumber()] + .getNumResults(); + if (llvm::any_of(reassociation.drop_front(outerRank), + [](ArrayRef group) { return group.size() != 1; })) { + return failure(); + } + return success(); +} + +/// Expand an InnerTiledOp by propagating a reshape through it. +/// `fusedOperand` is the operand connected to the reshape. +/// `reassociation` describes how the collapsed dims map to expanded dims. +/// `expandedShape` is the full expanded shape (outer + inner dims). +/// `expandedValue` is the expanded value to replace the fused operand. 
+/// `outputReassociations` will be cleared and filled with the reassociation +/// indices for each output, to be used for collapsing the result back to its +/// original shape. +/// The inner dimensions of the InnerTiledOp are expected to not be expanded, +/// which is enforced by the canExpandInnerTiledOp precondition. +static InnerTiledOp expandInnerTiledOp( + InnerTiledOp op, OpOperand *fusedOperand, + ArrayRef reassociation, + ArrayRef expandedShape, Value expandedValue, + SmallVectorImpl> &outputReassociations, + PatternRewriter &rewriter) { + assert(reassociation.size() == + cast(fusedOperand->get().getType()).getRank() && + "expected reassociation rank to match fused operand rank"); + + // Build mapping: iterDim -> list of (expandedIterDim, size). + SmallVector indexingMaps = op.getIndexingMapsArray(); + AffineMap fusedMap = indexingMaps[fusedOperand->getOperandNumber()]; + int64_t numIterDims = fusedMap.getNumDims(); + SmallVector>> iterDimExpansion( + numIterDims); + int64_t expandedDimCounter = 0; + for (auto [resultIdx, expr] : llvm::enumerate(fusedMap.getResults())) { + int64_t iterDim = cast(expr).getPosition(); + for (int64_t expandedOperandIdx : reassociation[resultIdx]) { + iterDimExpansion[iterDim].push_back( + {expandedDimCounter++, expandedShape[expandedOperandIdx]}); + } + } + // Iteration dims outside the fused map's results are independent from the + // expansion, but update their dim position to account for earlier expanded + // dims. Get iteration domain to query sizes of dims not in the fused operand.
+ SmallVector iterationDomain = op.getIterationDomain(rewriter); + for (int64_t i = 0; i < numIterDims; ++i) { + if (iterDimExpansion[i].empty()) + iterDimExpansion[i].push_back( + {expandedDimCounter++, iterationDomain[i].size}); + } + + SmallVector newIndexingMaps; + SmallVector newOperands; + outputReassociations.clear(); + Location loc = op.getLoc(); + for (OpOperand &operand : op->getOpOperands()) { + AffineMap origMap = indexingMaps[operand.getOperandNumber()]; + auto operandType = cast(operand.get().getType()); + int64_t operandOuterRank = origMap.getNumResults(); + int64_t innerRank = operandType.getRank() - operandOuterRank; + SmallVector newMapResults; + SmallVector operandReassoc; + SmallVector expandedOperandSizes; + int64_t dimCounter = 0; + for (AffineExpr expr : origMap.getResults()) { + int64_t iterDim = cast(expr).getPosition(); + ReassociationIndices group; + for (auto [expandedDim, size] : iterDimExpansion[iterDim]) { + newMapResults.push_back(getAffineDimExpr(expandedDim, op.getContext())); + group.push_back(dimCounter++); + expandedOperandSizes.push_back(size); + } + operandReassoc.push_back(group); + } + // Inner dims are never expanded. + for (int64_t i = 0; i < innerRank; ++i) { + operandReassoc.push_back({dimCounter++}); + expandedOperandSizes.push_back(tensor::getMixedSize( + rewriter, loc, operand.get(), operandOuterRank + i)); + } + newIndexingMaps.push_back( + AffineMap::get(expandedDimCounter, 0, newMapResults, op.getContext())); + + // Store output reassociations for later use. 
+ if (operand.getOperandNumber() >= op.getNumInputs()) { + outputReassociations.push_back(operandReassoc); + } + + if (&operand == fusedOperand) { + newOperands.push_back(expandedValue); + continue; + } + + if (llvm::all_of(operandReassoc, [](ArrayRef group) { + return group.size() == 1; + })) { + newOperands.push_back(operand.get()); + continue; + } + + SmallVector staticShape; + std::tie(staticShape, std::ignore) = + decomposeMixedValues(expandedOperandSizes); + auto expandedType = + RankedTensorType::get(staticShape, operandType.getElementType()); + newOperands.push_back(tensor::ExpandShapeOp::create( + rewriter, loc, expandedType, operand.get(), operandReassoc, + expandedOperandSizes)); + } + + // Expand iterator types. + SmallVector newIterTypes; + for (auto [idx, iterType] : llvm::enumerate(op.getIteratorTypesArray())) { + newIterTypes.append(iterDimExpansion[idx].size(), iterType); + } + + int64_t numInputs = op.getNumInputs(); + SmallVector newInputs(newOperands.begin(), + newOperands.begin() + numInputs); + SmallVector newOutputs(newOperands.begin() + numInputs, + newOperands.end()); + + // Permutations are unchanged, since they are for inner dims, but we need to + // convert from ArrayAttr to SmallVector>. + std::optional>> newPermutations; + if (auto permAttr = op.getPermutations()) { + newPermutations = llvm::map_to_vector( + permAttr->getAsRange(), [](DenseI64ArrayAttr perm) { + return SmallVector(perm.asArrayRef()); + }); + } + + return InnerTiledOp::create(rewriter, loc, newInputs, newOutputs, + newIndexingMaps, newIterTypes, op.getKind(), + op.getSemantics(), newPermutations); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +/// Pattern to propagate a tensor::CollapseShapeOp through a consumer +/// InnerTiledOp. The collapsed dimensions must not include any inner dimensions +/// of the InnerTiledOp. 
+/// +/// Example: +/// %collapsed = tensor.collapse_shape %src [[0, 1], ...] +/// %result = inner_tiled ins(%collapsed, ...) outs(%out) +/// => +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%src, ...) outs(%expanded_out) +/// %collapsed_result = tensor.collapse_shape %result [[0, 1], ...] +struct FoldProducerCollapseShapeWithInnerTiled + : public OpRewritePattern { + FoldProducerCollapseShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp, + PatternRewriter &rewriter) const override { + if (!collapseOp->hasOneUse()) { + return failure(); + } + OpOperand &use = *collapseOp->use_begin(); + auto innerTiledOp = dyn_cast(use.getOwner()); + if (!innerTiledOp || !controlFn(&use)) { + return failure(); + } + if (failed(canExpandInnerTiledOp(innerTiledOp, &use, + collapseOp.getReassociationIndices()))) { + return failure(); + } + + SmallVector expandedShape = tensor::getMixedSizes( + rewriter, collapseOp.getLoc(), collapseOp.getSrc()); + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, &use, collapseOp.getReassociationIndices(), expandedShape, + collapseOp.getSrc(), outputReassociations, rewriter); + + SmallVector results; + for (auto [idx, result] : llvm::enumerate(expandedOp.getResults())) { + auto resultType = + cast(innerTiledOp.getResultTypes()[idx]); + results.push_back(tensor::CollapseShapeOp::create( + rewriter, innerTiledOp.getLoc(), resultType, result, + outputReassociations[idx])); + } + rewriter.replaceOp(innerTiledOp, results); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +/// Pattern to propagate a tensor::ExpandShapeOp consumer back through an +/// InnerTiledOp producer. 
The expanded dimensions must not include any inner +/// dimensions of the InnerTiledOp. +/// +/// Example: +/// %result = inner_tiled ins(%lhs, ...) outs(%out) +/// %expanded = tensor.expand_shape %result [[0, 1], ...] +/// => +/// %expanded_lhs = tensor.expand_shape %lhs [[0, 1], ...] +/// %expanded_out = tensor.expand_shape %out [[0, 1], ...] +/// %result = inner_tiled ins(%expanded_lhs, ...) outs(%expanded_out) +struct FoldConsumerExpandShapeWithInnerTiled + : public OpRewritePattern { + FoldConsumerExpandShapeWithInnerTiled(MLIRContext *context, + linalg::ControlFusionFn controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto producerResult = dyn_cast(expandOp.getSrc()); + if (!producerResult) { + return failure(); + } + auto innerTiledOp = dyn_cast(producerResult.getOwner()); + if (!innerTiledOp || !controlFn(&expandOp.getSrcMutable())) { + return failure(); + } + + int64_t resultIdx = producerResult.getResultNumber(); + OpOperand *outputOperand = innerTiledOp.getDpsInitOperand(resultIdx); + if (failed(canExpandInnerTiledOp(innerTiledOp, outputOperand, + expandOp.getReassociationIndices()))) { + return failure(); + } + + // The DPS init will be expanded in the same way as the result, so insert + // the expand_shape on the init first in order to reuse the + // expandInnerTiledOp transformation utility. 
+ SmallVector expandedShape = expandOp.getMixedOutputShape(); + SmallVector staticShape; + std::tie(staticShape, std::ignore) = decomposeMixedValues(expandedShape); + auto sourceType = cast(outputOperand->get().getType()); + auto expandedType = + RankedTensorType::get(staticShape, sourceType.getElementType()); + auto expandedInit = tensor::ExpandShapeOp::create( + rewriter, expandOp.getLoc(), expandedType, outputOperand->get(), + expandOp.getReassociationIndices(), expandedShape); + + SmallVector> outputReassociations; + InnerTiledOp expandedOp = expandInnerTiledOp( + innerTiledOp, outputOperand, expandOp.getReassociationIndices(), + expandedShape, expandedInit, outputReassociations, rewriter); + rewriter.replaceOp(expandOp, expandedOp.getResult(resultIdx)); + return success(); + } + +private: + linalg::ControlFusionFn controlFn; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Populate Functions +//===----------------------------------------------------------------------===// + +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes) { + patterns.add(patterns.getContext(), + controlFoldingReshapes); +} + +} // namespace mlir::iree_compiler::IREE::Codegen diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h new file mode 100644 index 000000000000..ef83658add96 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/Transforms.h @@ -0,0 +1,28 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ +#define IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::Codegen { + +//===----------------------------------------------------------------------===// +// Populate functions. +//===----------------------------------------------------------------------===// + +/// Populate patterns to propagate reshapes by expansion. This folds +/// tensor.expand_shape and tensor.collapse_shape ops with their producer +/// and consumer operations respectively. +void populateFoldReshapeOpsByExpansionPatterns( + RewritePatternSet &patterns, + const linalg::ControlFusionFn &controlFoldingReshapes); + +} // namespace mlir::iree_compiler::IREE::Codegen + +#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_TRANSFORMS_TRANSFORMS_H_ From 06ecabd76ece996f0deb0ef805535e500bdcfbd8 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Thu, 15 Jan 2026 13:01:16 -0600 Subject: [PATCH 43/71] Reapply [GPU] Remove slice guard when doing pad producer fusion in FuseAndHoist (#23140) In the original PR https://github.com/iree-org/iree/pull/23126, we allowed the `FuseTilableDestinationProducers` pattern to still tile pad ops, however it is best to leave that to `ExtractSliceOfPadTensorSwapPattern` now, here is the status of the new test `fuse_pad_dest` added by the reapplication Before https://github.com/iree-org/iree/pull/23126 fusion didn't happen After https://github.com/iree-org/iree/pull/23126 hang because it was creating new ops and making the old ops dead and got stuck in a loop After this PR we are able to fold the pad op Fixes : https://github.com/iree-org/iree/issues/23028 --------- Signed-off-by: Nirvedh Meshram --- .../GPU/GPUFuseAndHoistParallelLoops.cpp | 15 +++-
.../GPU/test/gpu_fuse_and_hoist_forall.mlir | 69 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp index 4a8dfffa10c5..48f4e11ba04c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp @@ -175,6 +175,11 @@ struct FuseTilableDestinationProducers final : OpRewritePattern { tileableProducer = forallOp.getTiedLoopInit(iterArg) ->get() .getDefiningOp(); + // Pad fusion is handled separately as we don't want zero slice guards that + // happen by default. + if (tileableProducer && isa(tileableProducer)) { + tileableProducer = nullptr; + } if (tileableProducer) { break; } @@ -266,7 +271,9 @@ struct FuseTilableSliceProducers final return failure(); } auto tilableProducer = sliceOp.getSource().getDefiningOp(); - if (!tilableProducer) { + // Pad fusion is handled separately as we don't want zero slice guards that + // happen by default. + if (!tilableProducer || isa(tilableProducer)) { return failure(); } @@ -394,6 +401,12 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { patterns.add(context); tensor::populateFoldTensorEmptyPatterns(patterns); scf::ForallOp::getCanonicalizationPatterns(patterns, context); + auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional { + // Do not use zero slice guard.
+ return false; + }; + patterns.add(context, + zeroSliceGuard); if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 3)"); return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir index deee7df12d3b..278dcdb89474 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir @@ -875,3 +875,72 @@ func.func @fuse_warp_and_lane_foralls_with_coalesced_dma(%src: tensor<2x2x64xf32 // CHECK: } // CHECK: } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} // CHECK: return %[[THREAD_FORALL]] + +// ----- + +// Check that we dont make a zeroslice guard when fusing pad. +#map = affine_map<(d0) -> (d0 * 64)> +func.func @fuse_pad(%arg0: tensor, %arg1: index) -> tensor<128xf16> { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<128xf16> + %padded = tensor.pad %arg0 low[0] high[%arg1] { + ^bb0(%arg10: index): + tensor.yield %cst : f16 + } : tensor to tensor<128xf16> + %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %0) -> (tensor<128xf16>) { + %2 = affine.apply #map(%arg2) + %extracted_slice = tensor.extract_slice %padded[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> + } + } {mapping = [#gpu.thread]} + return %1 : tensor<128xf16> +} + +// CHECK-LABEL: func @fuse_pad 
+// CHECK: scf.forall +// CHECK-NOT: scf.if +// CHECK: tensor.pad +// CHECK: linalg.copy +// CHECK: scf.forall.in_parallel +// CHECK: return + +// ----- + +// Check that we can fuse padded destinations. +#map = affine_map<(d0) -> (d0 * 64)> +func.func @fuse_pad_dest(%arg0: tensor<128xf16>, %arg1: index) -> tensor<128xf16> { + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<128xf16> + %padded = tensor.pad %arg0 low[0] high[1] { + ^bb0(%arg10: index): + tensor.yield %cst : f16 + } : tensor<128xf16> to tensor<129xf16> + %extracted_slice_dest = tensor.extract_slice %padded[0] [128] [1] : tensor<129xf16> to tensor<128xf16> + %1 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %extracted_slice_dest) -> (tensor<128xf16>) { + %2 = affine.apply #map(%arg2) + %extracted_slice = tensor.extract_slice %0[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %extracted_slice_0 = tensor.extract_slice %arg3[%2] [64] [1] : tensor<128xf16> to tensor<64xf16> + %3 = linalg.copy ins(%extracted_slice : tensor<64xf16>) outs(%extracted_slice_0 : tensor<64xf16>) -> tensor<64xf16> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [64] [1] : tensor<64xf16> into tensor<128xf16> + } + } {mapping = [#gpu.thread]} + return %1 : tensor<128xf16> +} + +// CHECK-LABEL: func @fuse_pad_dest +// CHECK-NOT: tensor.pad +// CHECK: scf.forall +// CHECK-NOT: tensor.pad +// CHECK: linalg.copy +// CHECK: scf.forall.in_parallel +// CHECK: return From 7dddeb290dc04b41e45df5f2374fb89c51e71aa0 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 15:39:15 -0500 Subject: [PATCH 44/71] Enforce braces in plugins and bindings. NFC. 1/n (#23143) clang-format learned to insert these. 
Clean up missing braces before changing `.clang-format` --- .../c/iree/compiler/loader/loader.cpp | 9 +- .../python/IREECompilerDialectsModule.cpp | 18 +- .../Conversion/ConvertCollectives.cpp | 9 +- .../Conversion/LegalizeControlFlow.cpp | 9 +- .../Conversion/LegalizeShapeComputations.cpp | 12 +- .../Conversion/MapStableHLOToScalarOp.h | 24 ++- .../Preprocessing/Canonicalization.cpp | 78 +++++--- .../Preprocessing/DotGeneralToDot.cpp | 21 ++- .../Preprocessing/FlattenTuplesInCFG.cpp | 21 ++- .../Preprocessing/FlattenTuplesInSCF.cpp | 12 +- .../Conversion/Preprocessing/LowerComplex.cpp | 6 +- .../Preprocessing/StableHLOToStableHLO.cpp | 166 ++++++++++++------ .../Conversion/StableHLOCustomCalls.cpp | 3 +- .../StableHLOToIREEInputDialects.cpp | 45 +++-- .../Conversion/StableHLOToLinalgExt.cpp | 54 ++++-- .../TOSA/InputConversion/Converti48Toi64.cpp | 12 +- .../TOSA/InputConversion/StripSignedness.cpp | 15 +- .../TOSA/InputConversion/TosaToLinalgExt.cpp | 12 +- .../InputConversion/BindSymbolicShapes.cpp | 42 +++-- .../Torch/InputConversion/BitCastTensor.cpp | 29 +-- .../ConvertTMTensorToLinalgExt.cpp | 9 +- .../Torch/InputConversion/FuncConversion.cpp | 42 +++-- .../input/Torch/InputConversion/Passes.cpp | 3 +- compiler/plugins/target/CUDA/CUDATarget.cpp | 21 ++- .../target/LLVMCPU/Builtins/Device.cpp | 9 +- .../plugins/target/LLVMCPU/Builtins/Musl.cpp | 3 +- .../plugins/target/LLVMCPU/LLVMCPUTarget.cpp | 9 +- .../plugins/target/LLVMCPU/LLVMIRPasses.cpp | 3 +- .../target/LLVMCPU/LLVMTargetOptions.cpp | 46 +++-- .../plugins/target/LLVMCPU/LibraryBuilder.cpp | 12 +- .../plugins/target/LLVMCPU/LinkerTool.cpp | 6 +- .../LLVMCPU/internal/AndroidLinkerTool.cpp | 6 +- .../LLVMCPU/internal/EmbeddedLinkerTool.cpp | 9 +- .../LLVMCPU/internal/UnixLinkerTool.cpp | 9 +- .../LLVMCPU/internal/WasmLinkerTool.cpp | 6 +- .../LLVMCPU/internal/WindowsLinkerTool.cpp | 9 +- .../target/MetalSPIRV/MSLToMetalLib.cpp | 6 +- .../plugins/target/MetalSPIRV/SPIRVToMSL.cpp | 6 +- 
compiler/plugins/target/ROCM/ROCMTarget.cpp | 9 +- .../plugins/target/ROCM/ROCMTargetUtils.cpp | 14 +- 40 files changed, 553 insertions(+), 281 deletions(-) diff --git a/compiler/bindings/c/iree/compiler/loader/loader.cpp b/compiler/bindings/c/iree/compiler/loader/loader.cpp index f3c04646bb1e..73b6647e0d14 100644 --- a/compiler/bindings/c/iree/compiler/loader/loader.cpp +++ b/compiler/bindings/c/iree/compiler/loader/loader.cpp @@ -19,8 +19,9 @@ namespace { using DlHandle = HMODULE; DlHandle loadLibrary(const char *libraryPath) { HMODULE lib = LoadLibraryExA(libraryPath, nullptr, 0); - if (lib) + if (lib) { return lib; + } DWORD errorMessageID = GetLastError(); LPSTR messageBuffer = nullptr; size_t size = FormatMessageA( @@ -48,8 +49,9 @@ DlHandle loadLibrary(const char *libraryPath) { DlHandle lib = dlopen(libraryPath, RTLD_NOW | RTLD_LOCAL); if (!lib) { const char *reason = dlerror(); - if (!reason) + if (!reason) { reason = ""; + } fprintf(stderr, "IREE COMPILER ERROR: Could not open compiler library %s : %s\n", libraryPath, reason); @@ -73,8 +75,9 @@ DlHandle libraryHandle = nullptr; #undef HANDLE_VERSIONED_SYMBOL void assertLoaded() { - if (libraryHandle) + if (libraryHandle) { return; + } fprintf(stderr, "FATAL ERROR: Attempt to call IREE compiler stub methods before " "library loaded\n"); diff --git a/compiler/bindings/python/IREECompilerDialectsModule.cpp b/compiler/bindings/python/IREECompilerDialectsModule.cpp index 93ef30749db3..d20e58afafd4 100644 --- a/compiler/bindings/python/IREECompilerDialectsModule.cpp +++ b/compiler/bindings/python/IREECompilerDialectsModule.cpp @@ -51,8 +51,9 @@ ireeCodegenGetTunerRootOpsBinding(MlirModule module) { } static std::vector getIntArrayAttrValues(MlirAttribute attr) { - if (mlirAttributeIsNull(attr) || !mlirAttributeIsAArray(attr)) + if (mlirAttributeIsNull(attr) || !mlirAttributeIsAArray(attr)) { return {}; + } std::vector result; size_t n = mlirArrayAttrGetNumElements(attr); @@ -261,8 +262,9 @@ 
NB_MODULE(_ireeCompilerDialects, m) { "prefetch_num_stages", [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetPrefetchNumStages(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirIntegerAttrGetValueInt(attr); + } return std::nullopt; }) .def_property_readonly( @@ -271,16 +273,18 @@ NB_MODULE(_ireeCompilerDialects, m) { auto attr = ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirBoolAttrGetValue(attr); + } return std::nullopt; }) .def_property_readonly( "use_igemm_convolution", [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetUseIgemmConvolution(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return mlirBoolAttrGetValue(attr); + } return std::nullopt; }) .def_property_readonly( @@ -288,8 +292,9 @@ NB_MODULE(_ireeCompilerDialects, m) { [](MlirAttribute self) -> std::optional { auto attr = ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return attr; + } return std::nullopt; }); @@ -485,8 +490,9 @@ NB_MODULE(_ireeCompilerDialects, m) { .def_property_readonly( "mma_kind", [](MlirAttribute self) -> std::optional { auto attr = ireeGPULoweringConfigAttrGetMmaKind(self); - if (!mlirAttributeIsNull(attr)) + if (!mlirAttributeIsNull(attr)) { return attr; + } return std::nullopt; }); diff --git a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp index e6175c2af2c5..aff8e78f0c2b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp @@ -97,8 +97,9 @@ static std::pair makeSplitColorAndKey(Location loc, OpBuilder &builder) { IndexSet indexSet(loc, builder); Value noColor = 
indexSet.get(-1); - if (!groups) + if (!groups) { return std::make_pair(noColor, noColor); + } auto groupsType = cast(groups.getType()); assert(groupsType.getRank() == 2); @@ -311,8 +312,9 @@ static Value createChannelWithGroupInfo( DenseIntElementsAttr replicaGroups, std::optional useGlobalDeviceIds, OpBuilder &builder) { // Set numPartitions to 1 if not set by the user. - if (numPartitions == -1) + if (numPartitions == -1) { numPartitions = 1; + } // Base channel that may be split by the group info. Value baseChannel = IREE::Flow::ChannelDefaultOp::create( @@ -854,8 +856,9 @@ struct CollectivePermuteOpConversion int64_t numParticipants = mode == CollectiveOpGroupMode::CrossReplica ? numReplicas : numPartitions; - if (numParticipants == -1) + if (numParticipants == -1) { numParticipants = 1; + } SmallVector replicaGroups; for (int64_t i = 0; i < numParticipants; ++i) { replicaGroups.push_back(rewriter.getI64IntegerAttr(i)); diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp index dd1d9c9434be..f2506ecafef0 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeControlFlow.cpp @@ -65,12 +65,14 @@ struct ScfForBounds { std::optional extractForBounds(mlir::stablehlo::WhileOp op) { Block &cond = op.getCond().front(); Block &body = op.getBody().front(); - if (cond.getOperations().size() != 2) + if (cond.getOperations().size() != 2) { return std::nullopt; + } auto matchBbArg = [](Value v, Block &block) -> std::optional { - if (!isa(v) || v.getParentBlock() != &block) + if (!isa(v) || v.getParentBlock() != &block) { return std::nullopt; + } return cast(v).getArgNumber(); }; @@ -87,8 +89,9 @@ std::optional extractForBounds(mlir::stablehlo::WhileOp op) { } std::optional iterArg = matchBbArg(compare.getLhs(), cond); - if (!iterArg) + if (!iterArg) { return std::nullopt; + } auto add = 
dyn_cast_if_present( body.getTerminator()->getOperand(*iterArg).getDefiningOp()); diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp index d80d4ed123bb..070352296c86 100644 --- a/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeShapeComputations.cpp @@ -47,8 +47,9 @@ struct HloElementwiseConverter : OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final { - if (!opIsShapeComputation(op)) + if (!opIsShapeComputation(op)) { return failure(); + } auto resultTy = cast(op.getType()); @@ -86,8 +87,9 @@ struct ConcatenateConverter final LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { - if (!opIsShapeComputation(op)) + if (!opIsShapeComputation(op)) { return failure(); + } Location loc = op.getLoc(); auto resultTy = cast(op.getType()); @@ -144,14 +146,16 @@ struct ReshapeConverter : OpRewritePattern { PatternRewriter &rewriter) const override { Value operand = op.getOperand(); auto shapedTy = cast(operand.getType()); - if (!shapedTy.hasRank() || shapedTy.getRank() > 1) + if (!shapedTy.hasRank() || shapedTy.getRank() > 1) { return failure(); + } auto resultTy = cast(op.getType()); auto fromElements = op.getOperand().getDefiningOp(); - if (!fromElements) + if (!fromElements) { return failure(); + } rewriter.replaceOpWithNewOp( op, resultTy, fromElements.getOperands()); diff --git a/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h b/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h index 06f30deb65a1..0c91d76954a1 100644 --- a/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h +++ b/compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h @@ -456,8 +456,9 @@ inline Value mapStableHloOpToStdScalarOp( return 
ScalarFOp::create(*b, loc, predicate.value(), lhs, rhs); } - if (auto complexType = dyn_cast(elementType)) + if (auto complexType = dyn_cast(elementType)) { return cmpComplex(loc, lhs, rhs, comparisonDirection, b); + } return nullptr; } @@ -602,11 +603,12 @@ inline Value mapStableHloOpToStdScalarOp( Value lhs = operands.front(); Type complexTy = lhs.getType(); - if (!isa(complexTy)) + if (!isa(complexTy)) { return MapStableHloOpToScalarOpImpl< IsFloatType, arith::MaximumFOp, IsSignedIntegerType, arith::MaxSIOp, IsUnsignedIntegerType, arith::MaxUIOp>{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); + } assert(resultTypes.size() == 1 && "MaxOp should return a single result"); assert(operands.size() == 2 && "MaxOp should take exactly two arguments"); @@ -626,11 +628,12 @@ inline Value mapStableHloOpToStdScalarOp( Value lhs = operands.front(); Type complexTy = lhs.getType(); - if (!isa(complexTy)) + if (!isa(complexTy)) { return MapStableHloOpToScalarOpImpl< IsFloatType, arith::MinimumFOp, IsSignedIntegerType, arith::MinSIOp, IsUnsignedIntegerType, arith::MinUIOp>{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); + } assert(resultTypes.size() == 1 && "MinOp should return a single result"); assert(operands.size() == 2 && "MinOp should take exactly two arguments"); @@ -646,8 +649,9 @@ template <> inline Value mapStableHloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, stablehlo::RealOp::Adaptor adaptor, OpBuilder *b) { - if (!isa(adaptor.getOperand().getType())) + if (!isa(adaptor.getOperand().getType())) { return adaptor.getOperand(); + } return MapStableHloOpToScalarOpImpl{}( loc, resultTypes, argTypes, adaptor.getOperands(), b); } @@ -656,9 +660,10 @@ template <> inline Value mapStableHloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, stablehlo::ImagOp::Adaptor adaptor, OpBuilder *b) { - if (!isa(adaptor.getOperand().getType())) + if (!isa(adaptor.getOperand().getType())) { return 
arith::ConstantOp::create( *b, loc, b->getZeroAttr(adaptor.getOperand().getType())); + } return MapStableHloOpToScalarOpImpl{}( loc, resultTypes, argTypes, adaptor.getOperands(), b); } @@ -813,15 +818,18 @@ inline Value mapStableHloOpToStdScalarOp( Type resultType = getElementTypeOrSelf(resultTypes.front()); // Skip needless casts. - if (argType == resultType) + if (argType == resultType) { return adaptor.getOperand(); + } if (!isa(resultType) || - !isa(argType)) + !isa(argType)) { return nullptr; + } - if (resultType.getIntOrFloatBitWidth() != argType.getIntOrFloatBitWidth()) + if (resultType.getIntOrFloatBitWidth() != argType.getIntOrFloatBitWidth()) { return nullptr; + } return mlir::arith::BitcastOp::create(*b, loc, resultTypes, adaptor.getOperands()); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp index 8a996a19d6b5..6518c9653572 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/Canonicalization.cpp @@ -54,8 +54,9 @@ static bool isIotaRange(ArrayRef dims) { static bool isIotaRange(ElementsAttr attr) { auto elems = attr.tryGetValues(); - if (!elems) + if (!elems) { return false; + } for (auto [idx, value] : llvm::enumerate(*elems)) { if (idx != value) { @@ -119,8 +120,9 @@ struct AddOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -166,8 +168,9 @@ struct SubtractOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); 
Value rhs = op.getRhs(); @@ -208,8 +211,9 @@ struct MulOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); @@ -334,8 +338,9 @@ struct CompareOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } // Bail out on non-integer comparison. // TODO: Support more comparison types. @@ -410,8 +415,9 @@ struct SelectOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value trueVal = op.getOnTrue(); Value falseVal = op.getOnFalse(); @@ -437,16 +443,19 @@ struct SelectOpCanon final : OpRewritePattern { // Handle elementwise selection when both outcomes are also constants. This // will create a new, likely non-splat constant. 
- if (cond.getNumElements() > kFoldOpEltLimit) + if (cond.getNumElements() > kFoldOpEltLimit) { return failure(); + } ElementsAttr trueAttr; - if (!matchPattern(trueVal, m_Constant(&trueAttr))) + if (!matchPattern(trueVal, m_Constant(&trueAttr))) { return failure(); + } ElementsAttr falseAttr; - if (!matchPattern(falseVal, m_Constant(&falseAttr))) + if (!matchPattern(falseVal, m_Constant(&falseAttr))) { return failure(); + } SmallVector newValues; newValues.reserve(cond.getNumElements()); @@ -469,13 +478,15 @@ struct BroadcastInDimOpCanon final LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type) + if (!type) { return failure(); + } Value operand = op.getOperand(); auto operandTy = dyn_cast(operand.getType()); - if (!operandTy) + if (!operandTy) { return failure(); + } // Fold when broadcast is a noop. auto dims = op.getBroadcastDimensions(); @@ -534,12 +545,14 @@ struct ConcatenateOpCanon final LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } size_t numElems = type.getNumElements(); - if (numElems > kFoldOpEltLimit) + if (numElems > kFoldOpEltLimit) { return failure(); + } // Fold concatenate when all inputs are constants. OperandRange inputs = op.getInputs(); @@ -578,8 +591,9 @@ struct ConvertOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op, PatternRewriter &rewriter) const override { // Check if this convert is a noop. 
- if (op.getOperand().getType() != op.getType()) + if (op.getOperand().getType() != op.getType()) { return failure(); + } rewriter.replaceOp(op, op.getOperand()); return success(); @@ -673,8 +687,9 @@ struct ChainedDynamicBroadcastInDimCanonicalization final auto precedingBcast = bcast.getOperand() .getDefiningOp(); - if (!precedingBcast) + if (!precedingBcast) { return failure(); + } // Compose broadcast dimensions. SmallVector composition; @@ -759,8 +774,9 @@ struct EmptyReduceOpCanon final : OpRewritePattern { "unranked input unsupported"); } - if (!llvm::is_contained(elemTy.getShape(), 0)) + if (!llvm::is_contained(elemTy.getShape(), 0)) { return failure(); + } Location loc = op.getLoc(); DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({}); @@ -799,8 +815,9 @@ struct DynamicReshapeOpCanon final PatternRewriter &rewriter) const override { // This is a noop when the output type is already a static shape. auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } rewriter.replaceOpWithNewOp(op, type, op.getOperand()); @@ -816,8 +833,9 @@ struct GetTupleElementOpCanon final PatternRewriter &rewriter) const override { auto constructor = op.getOperand().getDefiningOp(); - if (!constructor) + if (!constructor) { return failure(); + } Value result = constructor.getOperand(op.getIndex()); rewriter.replaceOp(op, result); @@ -831,8 +849,9 @@ struct RealOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::RealOp op, PatternRewriter &rewriter) const override { auto complex = op.getOperand().getDefiningOp(); - if (!complex) + if (!complex) { return failure(); + } rewriter.replaceOp(op, complex.getLhs()); return success(); @@ -845,8 +864,9 @@ struct ImagOpCanon final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ImagOp op, PatternRewriter &rewriter) const override { auto complex = op.getOperand().getDefiningOp(); - if (!complex) + if 
(!complex) { return failure(); + } rewriter.replaceOp(op, complex.getRhs()); return success(); @@ -861,12 +881,14 @@ struct GetDimensionSizeOpCanon final PatternRewriter &rewriter) const override { // Fold get_dimension_size when the queried dim is statically known. auto tensorTy = dyn_cast(op.getOperand().getType()); - if (!tensorTy) + if (!tensorTy) { return failure(); + } int64_t dimSize = tensorTy.getDimSize(op.getDimension()); - if (dimSize < 0) + if (dimSize < 0) { return failure(); + } auto elemTy = cast(op.getType().getElementType()); IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); @@ -903,8 +925,9 @@ struct GatherOpCanon final : OpRewritePattern { auto operandType = dyn_cast(gather->getOperand(0).getType()); - if (!operandType || !operandType.hasStaticShape()) + if (!operandType || !operandType.hasStaticShape()) { return failure(); + } auto sliceEnd = llvm::to_vector(gather.getSliceSizes()); SmallVector sliceStart(sliceEnd.size(), 0); @@ -1044,13 +1067,16 @@ struct TransposeIsReshape final nonZeroPerms.reserve(permValues.size()); for (auto idx : permValues) { auto sz = inputTy.getDimSize(idx); - if (sz != 1) + if (sz != 1) { nonZeroPerms.push_back(idx); + } } - for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) - if (nonZeroPerms[i - 1] > nonZeroPerms[i]) + for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) { + if (nonZeroPerms[i - 1] > nonZeroPerms[i]) { return rewriter.notifyMatchFailure(op, "memory layout change"); + } + } rewriter.replaceOpWithNewOp(op, op.getType(), input); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp index 744a8c523d8b..6f080dfd4e2b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/DotGeneralToDot.cpp @@ -65,8 +65,9 @@ Value transposeReshape(Value arg, Location loc, auto transposeType = 
RankedTensorType::get(transposedShape, elementType); Value transposeResult = mlir::stablehlo::TransposeOp::create( rewriter, loc, transposeType, arg, transposePermutationAttr); - if (noReshape) + if (noReshape) { return transposeResult; + } // Return the final result. auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType); @@ -176,12 +177,14 @@ struct GeneralDotRemoveBatch final // We no longer include the batch dimension of 1. llvm::SmallVector newLhsContractingDims; - for (auto dim : dimNumbers.getLhsContractingDimensions()) + for (auto dim : dimNumbers.getLhsContractingDimensions()) { newLhsContractingDims.push_back(dim - 1); + } llvm::SmallVector newRhsContractingDims; - for (auto dim : dimNumbers.getRhsContractingDimensions()) + for (auto dim : dimNumbers.getRhsContractingDimensions()) { newRhsContractingDims.push_back(dim - 1); + } auto lhs = mlir::stablehlo::ReshapeOp::create( rewriter, op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()), @@ -231,8 +234,9 @@ struct GeneralDotConvert final ArrayAttr precisionConfig; auto opPrecisionConfig = op.getPrecisionConfig(); - if (opPrecisionConfig.has_value()) + if (opPrecisionConfig.has_value()) { precisionConfig = *opPrecisionConfig; + } auto resultTy = cast(op.getType()); @@ -246,8 +250,9 @@ struct GeneralDotConvert final RankedTensorType lhsTy = dyn_cast(lhs.getType()); RankedTensorType rhsTy = dyn_cast(rhs.getType()); - if (!lhsTy || !rhsTy) + if (!lhsTy || !rhsTy) { return failure(); + } // The StableHLO dot operator directly supports a vector dot product // (two vectors reduce into a scalar) as well as a matrix vector @@ -295,8 +300,9 @@ struct GeneralDotConvert final // For any sparse situation, don't use any of the following rules, since // transposing and reshaping is not without cost. Instead, rely on the // default linalg lowering that follows later in the pipeline. 
- if (sparse_tensor::hasAnySparseOperandOrResult(op)) + if (sparse_tensor::hasAnySparseOperandOrResult(op)) { return failure(); + } // Compute the, possibly, transposed-reshaped operands. lhs = cast>(processDotArg( @@ -307,8 +313,9 @@ struct GeneralDotConvert final // Accept only static shaped types. auto lhsShapeType = dyn_cast_if_present(lhs.getType()); auto rhsShapeType = dyn_cast_if_present(rhs.getType()); - if (!lhsShapeType || !rhsShapeType) + if (!lhsShapeType || !rhsShapeType) { return failure(); + } // Generate new dot operator on expanded types. ShapedType newTy = RankedTensorType::get( diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp index 931a43fd8b42..6ce7b58454f3 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInCFG.cpp @@ -69,8 +69,9 @@ void copyOperationAttrs(Operation *oldOp, Operation *newOp) { // Don't copy segment attributes as these correspond to the number operands, // which may be different. 
if (oldAttr.getName() == "operandSegmentSizes" || - oldAttr.getName() == "resultSegmentSizes") + oldAttr.getName() == "resultSegmentSizes") { continue; + } newOp->setAttr(oldAttr.getName(), oldAttr.getValue()); } @@ -127,8 +128,9 @@ class DetupleReturnOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::ReturnOp op, PatternRewriter &builder) const override { - if (!hasTuples(op.getOperands())) + if (!hasTuples(op.getOperands())) { return builder.notifyMatchFailure(op, "No detupling required"); + } llvm::SmallVector newOperands; if (failed(untupleAndLookupValues(op.getOperands(), newOperands, builder, @@ -147,8 +149,9 @@ class DetupleCallOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::CallOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) + if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, @@ -180,8 +183,9 @@ class DetupleIndirectCallOp : public OpRewritePattern { LogicalResult matchAndRewrite(func::CallIndirectOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) + if (!hasTuples(oldOp.getOperands()) && !hasTuples(oldOp.getResults())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), newArgs, builder, @@ -202,8 +206,9 @@ class DetupleBranchOp : public OpRewritePattern { LogicalResult matchAndRewrite(cf::BranchOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands())) + if (!hasTuples(oldOp.getOperands())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector newArgs; if (failed(untupleAndLookupValues(oldOp.getOperands(), 
newArgs, builder, @@ -225,8 +230,9 @@ class DetupleConditionOp : public OpRewritePattern { LogicalResult matchAndRewrite(cf::CondBranchOp oldOp, PatternRewriter &builder) const override { - if (!hasTuples(oldOp.getOperands())) + if (!hasTuples(oldOp.getOperands())) { return builder.notifyMatchFailure(oldOp, "No detupling required"); + } llvm::SmallVector trueArgs; if (failed(untupleAndLookupValues(oldOp.getTrueOperands(), trueArgs, @@ -279,8 +285,9 @@ LogicalResult convertFunction(func::FuncOp oldFunction, // existing ones along path that produces tuples are used further, so just // remove instead of flattening. if (hasTupleSig && (attr.getName() == oldFunction.getArgAttrsAttrName() || - attr.getName() == oldFunction.getResAttrsAttrName())) + attr.getName() == oldFunction.getResAttrsAttrName())) { continue; + } newFunction->setAttr(attr.getName(), attr.getValue()); } diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp index 5be0ac2a8338..145a2270996e 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/FlattenTuplesInSCF.cpp @@ -114,8 +114,9 @@ class DetupleYieldOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } rewriter.replaceOpWithNewOp(op, operands); return success(); @@ -137,8 +138,9 @@ class DetupleConditionOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } rewriter.replaceOpWithNewOp(op, op.getCondition(), operands); @@ -159,8 +161,9 @@ class DetupleIfOp : public OpRewritePattern { hasTuples |= isa(type); } - if (!hasTuples) + if (!hasTuples) { return 
rewriter.notifyMatchFailure(op, "no tupled arguments"); + } llvm::SmallVector types; untupleTypes(op.getResultTypes(), types); @@ -204,8 +207,9 @@ class DetupleWhileOp : public OpRewritePattern { recursiveUntuple(operand, b, mapping, operands); } - if (!hasTuples) + if (!hasTuples) { return rewriter.notifyMatchFailure(op, "no tupled arguments"); + } llvm::SmallVector types; untupleTypes(op.getResultTypes(), types); diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp index ff767d26b95b..11ec46a9887b 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/LowerComplex.cpp @@ -96,12 +96,14 @@ ElementsAttr getSplat(Builder *b, RankedTensorType ty, T constant) { if (auto complexTy = dyn_cast(elementTy)) { auto complexElementTy = complexTy.getElementType(); - if (complexElementTy.isF32()) + if (complexElementTy.isF32()) { return DenseElementsAttr::get(ty, static_cast>(constant)); - if (complexElementTy.isF64()) + } + if (complexElementTy.isF64()) { return DenseElementsAttr::get( ty, static_cast>(constant)); + } } llvm_unreachable("unhandled element type"); } diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp index ba77c12a9bf4..7251188e0b0a 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp @@ -34,8 +34,9 @@ namespace { bool isIota(ArrayRef array) { for (auto [idx, value] : llvm::enumerate(array)) { - if (static_cast(idx) != value) + if (static_cast(idx) != value) { return false; + } } return true; } @@ -122,8 +123,9 @@ struct ReorderConvOpKernelDimensions final PatternRewriter &rewriter) const override { 
auto kernel = op.getRhs(); auto kernelType = cast(kernel.getType()); - if (!kernelType.hasRank()) + if (!kernelType.hasRank()) { return failure(); + } auto kernelShape = kernelType.getShape(); auto dimensionNumbers = op.getDimensionNumbers(); @@ -142,8 +144,9 @@ struct ReorderConvOpKernelDimensions final permutation.push_back(outputFeatureDimension); // If the permutation is iota, then no transpose is required. - if (isIota(permutation)) + if (isIota(permutation)) { return failure(); + } llvm::SmallVector transposeShape; for (int64_t perm : permutation) { @@ -253,8 +256,9 @@ struct ReorderConvOpOutputDimensions final bool isConsecutive(ArrayRef array) { for (size_t i = 1, e = array.size(); i < e; ++i) { - if (array[i] - array[i - 1] != 1) + if (array[i] - array[i - 1] != 1) { return false; + } } return true; } @@ -274,8 +278,9 @@ struct TransposeReshapeGenericDotGeneral final Value TransposeIfNonConsecutive(OpBuilder &b, Location loc, Value src, ArrayRef targetOrder) const { - if (isConsecutive(targetOrder)) + if (isConsecutive(targetOrder)) { return src; + } auto type = cast(src.getType()); SmallVector transposeShape; @@ -292,8 +297,9 @@ struct TransposeReshapeGenericDotGeneral final auto type = cast(src.getType()); ArrayRef shape = type.getShape(); if (dimsBorder0 <= 1 && dimsBorder1 - dimsBorder0 <= 1 && - shape.size() - dimsBorder1 <= 1) + shape.size() - dimsBorder1 <= 1) { return src; + } int64_t resultShape[] = { llvm::product_of(shape.take_front(dimsBorder0)), @@ -308,15 +314,17 @@ struct TransposeReshapeGenericDotGeneral final auto lhsShapeType = dyn_cast(op.getLhs().getType()); auto rhsShapeType = dyn_cast(op.getRhs().getType()); auto resultType = dyn_cast(op.getResult().getType()); - if (!lhsShapeType || !rhsShapeType || !resultType) + if (!lhsShapeType || !rhsShapeType || !resultType) { return failure(); + } // TODO(jpienaar): This pattern is not safe for dynamic shapes and seems to // be (now) redundant with later pass that does handle them. 
To decouple // fixing and verifying redundant, this just limits to static shapes and // then will remove this in follow up. - if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape()) + if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape()) { return failure(); + } SmallVector lhsTargetOrder, rhsTargetOrder; mlir::stablehlo::DotDimensionNumbersAttr dimNumbers = @@ -394,8 +402,9 @@ struct TransposeReshapeGenericDotGeneral final rhs = ReshapeIfNonStandard(rewriter, op.getLoc(), rhs, rhsBatchingDims.size(), numRhsContractionDims); - if (lhs == op.getLhs() && rhs == op.getRhs()) + if (lhs == op.getLhs() && rhs == op.getRhs()) { return rewriter.notifyMatchFailure(op, "already in canonical form"); + } auto dimensionNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/0, @@ -409,11 +418,13 @@ struct TransposeReshapeGenericDotGeneral final // batching、lhs parallel、rhs parallel this order is a conversion SmallVector newShape = {lhsNewType.getShape()[0]}; - if (lhsNewType.getRank() > 2) + if (lhsNewType.getRank() > 2) { newShape.push_back(lhsNewType.getDimSize(1)); + } - if (rhsNewType.getRank() > 2) + if (rhsNewType.getRank() > 2) { newShape.push_back(rhsNewType.getDimSize(2)); + } TensorType newResultType = RankedTensorType::get(newShape, resultType.getElementType()); @@ -537,8 +548,9 @@ struct ScatterImplicitBatch final static Value addUnitBatchDim(Location loc, Value value, PatternRewriter &rewriter) { auto valueTy = cast(value.getType()); - if (!valueTy.hasRank()) + if (!valueTy.hasRank()) { return nullptr; + } // Materialize the implicit indices dim. SmallVector reassociationMap(valueTy.getRank()); @@ -565,8 +577,9 @@ struct ScatterImplicitBatch final auto indicesTy = dyn_cast(indices.getType()); // Check whether indices has no batch dimension. 
- if (!indicesTy) + if (!indicesTy) { return failure(); + } if (indicesTy.getRank() != 1 || indexVectorDim != 0) { return rewriter.notifyMatchFailure(op, "no implicit batch dimension to add."); @@ -620,8 +633,9 @@ struct ScatterCollapseBatch final static Value collapseBatchDims(Location loc, Value value, int64_t batchCount, PatternRewriter &rewriter) { auto valueTy = dyn_cast(value.getType()); - if (!valueTy) + if (!valueTy) { return nullptr; + } SmallVector reassociationMap(1); reassociationMap.reserve(valueTy.getRank() - batchCount + 1); @@ -733,12 +747,14 @@ struct ScatterBatchFirst final : OpRewritePattern { llvm::SmallVector perm; perm.reserve(indicesTy.getRank()); for (int i = 0, s = indicesTy.getRank(); i < s; ++i) { - if (i != indexVectorDim) + if (i != indexVectorDim) { perm.push_back(i); + } } - if (perm.size() < indicesTy.getRank()) + if (perm.size() < indicesTy.getRank()) { perm.push_back(indexVectorDim); + } llvm::SmallVector newShape; for (int i = 0, s = perm.size(); i < s; ++i) { @@ -761,21 +777,25 @@ struct ScatterBatchFirst final : OpRewritePattern { // Determine which dimensions are batch dimensions. llvm::SmallVector isBatch(updates0Ty.getRank(), true); - for (int i = 0, s = updatedWindowDims.size(); i < s; ++i) + for (int i = 0, s = updatedWindowDims.size(); i < s; ++i) { isBatch[updatedWindowDims[i]] = false; + } // Permute batch dimensions to the start of the update tensor. 
llvm::SmallVector updatePerm; updatePerm.reserve(updates0Ty.getRank()); - for (int i = 0, s = isBatch.size(); i < s; ++i) - if (isBatch[i]) + for (int i = 0, s = isBatch.size(); i < s; ++i) { + if (isBatch[i]) { updatePerm.push_back(i); + } + } updatePerm.append(updatedWindowDims.begin(), updatedWindowDims.end()); llvm::SmallVector newUpdatedWindowDims; int64_t batchCount = updates0Ty.getRank() - updatedWindowDims.size(); - for (int i = batchCount, s = updates0Ty.getRank(); i < s; i++) + for (int i = batchCount, s = updates0Ty.getRank(); i < s; i++) { newUpdatedWindowDims.push_back(i); + } bool indicesChanged = indices != op.getScatterIndices(); bool updatesChanged = @@ -787,17 +807,19 @@ struct ScatterBatchFirst final : OpRewritePattern { auto updateTy = cast(update.getType()); llvm::SmallVector newShape; newShape.reserve(updateTy.getRank()); - for (int i = 0, s = updatePerm.size(); i < s; i++) + for (int i = 0, s = updatePerm.size(); i < s; i++) { newShape.push_back(updateTy.getDimSize(updatePerm[i])); + } update = mlir::stablehlo::TransposeOp::create( builder, updateTy.clone(newShape), update, builder.getDenseI64ArrayAttr(updatePerm)); } } - if (!indicesChanged && !updatesChanged) + if (!indicesChanged && !updatesChanged) { return rewriter.notifyMatchFailure( op, "batch dimensions are already leading"); + } auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get( op.getContext(), newUpdatedWindowDims, @@ -882,8 +904,9 @@ struct ScatterMaterializeInsertedDim final int64_t firstNonIndex = 0; for (int64_t s = scatterDimsToOperandDims.size(); firstNonIndex < s; ++firstNonIndex) { - if (!isIndexDim[firstNonIndex]) + if (!isIndexDim[firstNonIndex]) { break; + } } llvm::SmallVector isInsertDims(operandTy.getRank(), false); @@ -909,8 +932,9 @@ struct ScatterMaterializeInsertedDim final reassociationMap.push_back({rewriter.getAffineDimExpr(0)}); for (auto it : llvm::enumerate(llvm::ArrayRef(toInsertDims))) { - if (!it.value()) + if (!it.value()) { 
reassociationMap.push_back({}); + } reassociationMap.back().push_back( rewriter.getAffineDimExpr(it.index() + 1)); } @@ -962,8 +986,9 @@ struct ScatterMaterializeInsertedDim final bool isFromBool(Value val) { while (true) { Operation *op = val.getDefiningOp(); - if (!op) + if (!op) { return false; + } if (auto convertOp = dyn_cast(op)) { auto inTy = cast(convertOp.getOperand().getType()); @@ -993,17 +1018,20 @@ struct MulCastOfBool final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter &rewriter) const override { auto resultTy = cast(op.getType()); - if (!isa(resultTy.getElementType())) + if (!isa(resultTy.getElementType())) { return failure(); + } Value lhs = op.getLhs(); Value rhs = op.getRhs(); bool lhsIsBool = isFromBool(lhs); bool rhsIsBool = isFromBool(rhs); - if (lhsIsBool == rhsIsBool) + if (lhsIsBool == rhsIsBool) { return failure(); - if (rhsIsBool) + } + if (rhsIsBool) { std::swap(lhs, rhs); + } Type eType = resultTy.getElementType(); auto lhsTy = cast(lhs.getType()); @@ -1023,8 +1051,9 @@ struct MulCastOfBool final : OpRewritePattern { auto valueTy = cast(value.getType()); auto newTy = RankedTensorType::get(resultTy.getShape(), valueTy.getElementType()); - if (valueTy == newTy) + if (valueTy == newTy) { return value; + } auto dimensions = llvm::to_vector( llvm::seq(resultRank - valueTy.getRank(), resultRank)); return mlir::stablehlo::DynamicBroadcastInDimOp::create( @@ -1047,19 +1076,22 @@ struct ExpandRngNormal final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::RngOp op, PatternRewriter &rewriter) const override { - if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::NORMAL) + if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::NORMAL) { return failure(); + } auto resTy = dyn_cast(op.getType()); // We can support static shapes, but it's easier to implement Box-Muller // transform if we know the number of elements. 
- if (!resTy || !resTy.hasStaticShape()) + if (!resTy || !resTy.hasStaticShape()) { return failure(); + } // The algorithm requires even numbers and will generate pairs. auto numElems = resTy.getNumElements(); - if (numElems & 1) + if (numElems & 1) { numElems++; + } auto halfNumElems = numElems / 2; ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -1193,11 +1225,13 @@ struct ReorderBroadcastInDimOpAndElementwiseOp final // NOTE: bcastOps may contain duplicates. SetVector deadOps; for (auto bcastOp : bcastOps) { - if (bcastOp.getOperation()->use_empty()) + if (bcastOp.getOperation()->use_empty()) { deadOps.insert(bcastOp); + } } - for (auto *deadOp : deadOps) + for (auto *deadOp : deadOps) { rewriter.eraseOp(deadOp); + } return success(); } @@ -1238,8 +1272,9 @@ struct FuseWidenOperands final : OpRewritePattern { if (llvm::all_of( llvm::zip_equal(operands, op->getOperands()), - [](auto pair) { return std::get<0>(pair) == std::get<1>(pair); })) + [](auto pair) { return std::get<0>(pair) == std::get<1>(pair); })) { return failure(); + } rewriter.replaceOpWithNewOp(op, op->getResultTypes(), operands, op->getAttrs()); @@ -1266,8 +1301,9 @@ struct DotToMul final : OpRewritePattern { return rewriter.notifyMatchFailure(op, "lhs and rhs must be rank-2"); } - if (lhsTy.getDimSize(1) != 1) + if (lhsTy.getDimSize(1) != 1) { return failure(); + } // Dynamically compute the shape of the result of the DotOp by querying // the 0-th dimensions, of the left, and the 1st dimension of the right. 
@@ -1298,10 +1334,13 @@ struct DotToMul final : OpRewritePattern { outSize, rewriter.getDenseI64ArrayAttr({0, 1})); auto computeETy = lhsTy.getElementType(); - if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth()) + if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth()) { computeETy = rhsTy.getElementType(); - if (computeETy.getIntOrFloatBitWidth() < resultTy.getElementTypeBitWidth()) + } + if (computeETy.getIntOrFloatBitWidth() < + resultTy.getElementTypeBitWidth()) { computeETy = resultTy.getElementType(); + } auto computeTy = resultTy.clone(computeETy); @@ -1362,8 +1401,9 @@ struct ZeroConcat final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter &rewriter) const override { auto type = dyn_cast(op.getType()); - if (!type || !type.hasStaticShape()) + if (!type || !type.hasStaticShape()) { return failure(); + } uint64_t axis = op.getDimension(); OperandRange origInputs = op.getInputs(); @@ -1371,15 +1411,18 @@ struct ZeroConcat final : OpRewritePattern { for (auto input : origInputs) { auto type = dyn_cast(input.getType()); ArrayRef shape = type.getShape(); - if (axis > shape.size()) + if (axis > shape.size()) { return failure(); + } - if (shape[axis] != 0) + if (shape[axis] != 0) { nonzeroInputs.push_back(input); + } } - if (nonzeroInputs.size() == origInputs.size()) + if (nonzeroInputs.size() == origInputs.size()) { return failure(); + } rewriter.replaceOpWithNewOp( op, nonzeroInputs, /*dimension=*/axis); @@ -1402,8 +1445,9 @@ struct DotGeneralIsMul final : OpRewritePattern { auto resultTy = dyn_cast(op.getType()); ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - if (!lhsTy || !rhsTy || !resultTy) + if (!lhsTy || !rhsTy || !resultTy) { return failure(); + } auto dNums = op.getDotDimensionNumbers(); auto batchDimsL = dNums.getLhsBatchingDimensions(); @@ -1414,14 +1458,18 @@ struct DotGeneralIsMul final : OpRewritePattern { llvm::SmallVector 
isLhsParallelDim(lhsTy.getRank(), true); llvm::SmallVector isRhsParallelDim(rhsTy.getRank(), true); - for (auto dim : batchDimsL) + for (auto dim : batchDimsL) { isLhsParallelDim[dim] = false; - for (auto dim : batchDimsR) + } + for (auto dim : batchDimsR) { isRhsParallelDim[dim] = false; - for (auto dim : contractDimsL) + } + for (auto dim : contractDimsL) { isLhsParallelDim[dim] = false; - for (auto dim : contractDimsR) + } + for (auto dim : contractDimsR) { isRhsParallelDim[dim] = false; + } for (auto dim : contractDimsL) { if (lhsTy.getDimSize(dim) != 1) { @@ -1437,13 +1485,15 @@ struct DotGeneralIsMul final : OpRewritePattern { permRhs.append(batchDimsR.begin(), batchDimsR.end()); for (auto [idx, value] : llvm::enumerate(isLhsParallelDim)) { - if (value) + if (value) { permLhs.push_back(idx); + } } for (auto [idx, value] : llvm::enumerate(isRhsParallelDim)) { - if (value) + if (value) { permRhs.push_back(idx); + } } llvm::append_range(permLhs, contractDimsL); @@ -1452,10 +1502,12 @@ struct DotGeneralIsMul final : OpRewritePattern { // Determine the transpose shape based on the generate permutations. llvm::SmallVector lhsTransposeShape; llvm::SmallVector rhsTransposeShape; - for (auto dim : permLhs) + for (auto dim : permLhs) { lhsTransposeShape.push_back(lhsTy.getDimSize(dim)); - for (auto dim : permRhs) + } + for (auto dim : permRhs) { rhsTransposeShape.push_back(rhsTy.getDimSize(dim)); + } // Transpose the left hand side and the right hand side. lhs = mlir::stablehlo::TransposeOp::create( @@ -1733,9 +1785,10 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { int64_t k; // Check that the output of the sort op gets fed into a slice. 
for (auto [idx, result] : llvm::enumerate(opResults)) { - if (result.getUsers().empty()) + if (result.getUsers().empty()) { return rewriter.notifyMatchFailure( op, "sort isn't calling into a slice op"); + } auto sliceOp = dyn_cast(*result.getUsers().begin()); if (!sliceOp) { @@ -1774,8 +1827,9 @@ struct ApproxTopK final : OpRewritePattern { LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, PatternRewriter &rewriter) const override { - if (op.getCallTargetName() != "ApproxTopK") + if (op.getCallTargetName() != "ApproxTopK") { return rewriter.notifyMatchFailure(op, "not ApproxTopK operation."); + } auto computationName = dyn_cast(op.getCalledComputationsAttr()[0]); @@ -1784,11 +1838,13 @@ struct ApproxTopK final : OpRewritePattern { parent = parent->getParentOp()) { funcOp = SymbolTable::lookupNearestSymbolFrom( parent, computationName); - if (funcOp) + if (funcOp) { break; + } } - if (!funcOp) + if (!funcOp) { return rewriter.notifyMatchFailure(op, "computation function not found."); + } int64_t k = cast(op.getType(0)).getShape().back(); auto input = op.getOperand(0); diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp index c26220cdf5b6..e7314c129b3a 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOCustomCalls.cpp @@ -193,8 +193,9 @@ struct HouseholderReflectorRewriter final Value householder = computeHouseholderSlice(matrix, tau, iv, b); std::vector batch(rank - 2); - for (int i = 0; i < rank - 2; ++i) + for (int i = 0; i < rank - 2; ++i) { batch[i] = i; + } std::vector lhsContract = {rank - 1}; std::vector rhsContract = {rank - 2}; diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp index 75ba4b3817b8..3b264d2603a9 100644 --- 
a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp @@ -91,8 +91,9 @@ struct ConcatenateOpConversion final auto toOpFoldResult = [](Value v) -> OpFoldResult { auto op = v.getDefiningOp(); - if (!op) + if (!op) { return v; + } return op.getValue(); }; @@ -233,8 +234,9 @@ static bool isValidFuncAttr(DictionaryAttr attrs) { // TODO: switch to using a dialect-based exclusion list or some other way that // is not a big string table. for (auto attr : attrs) { - if (attr.getName() == "tf.aliasing_output") + if (attr.getName() == "tf.aliasing_output") { return false; + } } return true; } @@ -246,13 +248,15 @@ static void setFuncEncodings(func::FuncOp funcOp, FunctionType oldFuncType, auto encodingName = StringAttr::get(funcOp.getContext(), "iree.abi.encoding"); for (auto [i, oldType, newType] : llvm::enumerate(oldFuncType.getInputs(), newFuncType.getInputs())) { - if (oldType != newType) + if (oldType != newType) { funcOp.setArgAttr(i, encodingName, TypeAttr::get(oldType)); + } } for (auto [i, oldType, newType] : llvm::enumerate(oldFuncType.getResults(), newFuncType.getResults())) { - if (oldType != newType) + if (oldType != newType) { funcOp.setResultAttr(i, encodingName, TypeAttr::get(oldType)); + } } } @@ -347,11 +351,13 @@ struct TensorEmptyPattern final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto oldType = cast(op.getType()); auto newType = getTypeConverter()->convertType(oldType); - if (newType == oldType) + if (newType == oldType) { return failure(); + } - if (!newType) + if (!newType) { return rewriter.notifyMatchFailure(op, "result type conversion failed"); + } rewriter.replaceOpWithNewOp( op, oldType.getShape(), @@ -369,8 +375,9 @@ struct GlobalOpPattern final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type oldType = globalOp.getType(); Type newType = 
getTypeConverter()->convertType(oldType); - if (newType == oldType) + if (newType == oldType) { return failure(); + } if (!newType) { return rewriter.notifyMatchFailure(globalOp, "result type conversion failed"); @@ -452,21 +459,24 @@ static void stripFrontendAttrs(mlir::ModuleOp moduleOp) { auto filterOpAttrs = [&](Operation *op) { SmallVector newAttrs; for (auto attr : op->getDialectAttrs()) { - if (!isAttrFiltered(attr)) + if (!isAttrFiltered(attr)) { newAttrs.push_back(attr); + } } op->setDialectAttrs(newAttrs); }; auto filterAttrDicts = [&](ArrayAttr allOldAttrs, SmallVectorImpl &newAttrs) { - if (!allOldAttrs) + if (!allOldAttrs) { return false; + } for (auto oldAttrs : allOldAttrs.getAsRange()) { SmallVector preservedAttrs; preservedAttrs.reserve(oldAttrs.size()); for (auto attr : oldAttrs) { - if (!isAttrFiltered(attr)) + if (!isAttrFiltered(attr)) { preservedAttrs.push_back(attr); + } } newAttrs.push_back( DictionaryAttr::get(allOldAttrs.getContext(), preservedAttrs)); @@ -554,12 +564,14 @@ struct ConvertStableHloToIreeInputDialects final auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); }; auto isLegallyTypedOp = [&](Operation *op) -> bool { for (Type type : op->getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } return true; }; @@ -582,17 +594,20 @@ struct ConvertStableHloToIreeInputDialects final } } for (Type type : funcOp.getFunctionType().getInputs()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getFunctionType().getResults()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Block &block : funcOp.getFunctionBody()) { for (Type type : block.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } return true; diff --git 
a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp index 8c3fab7e1a39..9bcdaf7f4962 100644 --- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp @@ -51,8 +51,9 @@ Type convertIntegerToSignless(IntegerType intType) { } std::optional convertRank0TensorToScalar(RankedTensorType tensorType) { - if (tensorType.getRank() != 0) + if (tensorType.getRank() != 0) { return std::nullopt; + } Type elementType = tensorType.getElementType(); if (auto intType = dyn_cast(elementType)) { elementType = convertIntegerToSignless(intType); @@ -72,8 +73,9 @@ Value materializeCast(OpBuilder &builder, Type toType, ValueRange inputs, assert(inputs.size() == 1 && "too many inputs to type conversion"); Value fromValue = inputs[0]; auto fromType = dyn_cast(fromValue.getType()); - if (!fromType) + if (!fromType) { return Value(); + } if (auto intFromType = dyn_cast(fromType.getElementType())) { Type castType = getElementTypeOrSelf(toType); @@ -88,8 +90,9 @@ Value materializeCast(OpBuilder &builder, Type toType, ValueRange inputs, } } - if (fromType.getRank() != 0) + if (fromType.getRank() != 0) { return fromValue; + } Type extractType = getElementTypeOrSelf(toType); return builder.createOrFold(loc, extractType, fromValue); @@ -131,11 +134,13 @@ struct LinalgExtRegionHLOOpConversion final : OpConversionPattern { LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isInBodyOfLinalgExtOps(op)) + if (!isInBodyOfLinalgExtOps(op)) { return failure(); + } TensorType origRetType = dyn_cast(op.getType()); - if (!origRetType) + if (!origRetType) { return failure(); + } SmallVector scalarArgs; Type newRetType = getElementTypeOrSelf( this->typeConverter->convertType(origRetType.getElementType())); @@ -152,8 +157,9 @@ struct 
LinalgExtRegionReturnOpConversion final LogicalResult matchAndRewrite(mlir::stablehlo::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isInBodyOfLinalgExtOps(op)) + if (!isInBodyOfLinalgExtOps(op)) { return failure(); + } rewriter.replaceOpWithNewOp( op, adaptor.getOperands()); return success(); @@ -222,23 +228,28 @@ struct ScatterOpConversion final auto indexDepth = indicesType.getShape().back(); auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims(); - if (indicesRank != 2) + if (indicesRank != 2) { return false; - if (indexVectorDim != indicesRank - 1) + } + if (indexVectorDim != indicesRank - 1) { return false; - if (scatterDimsToOperandDims.size() != indexDepth) + } + if (scatterDimsToOperandDims.size() != indexDepth) { return false; + } auto insertedWindowDims = dimNumbers.getInsertedWindowDims(); for (auto [idx, dim] : llvm::enumerate(insertedWindowDims)) { - if (idx != dim) + if (idx != dim) { return false; + } } // Check that there is only one batch dimension in the updates. 
for (auto [idx, dim] : llvm::enumerate(dimNumbers.getUpdateWindowDims())) { - if (idx + 1 != dim) + if (idx + 1 != dim) { return false; + } } return true; @@ -247,12 +258,15 @@ struct ScatterOpConversion final LogicalResult matchAndRewrite(mlir::stablehlo::ScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!hasCanonicalDimensionNumbers(op)) + if (!hasCanonicalDimensionNumbers(op)) { return failure(); - if (llvm::size(op.getInputs()) != 1) + } + if (llvm::size(op.getInputs()) != 1) { return op.emitError("NYI variadic operands scatter"); - if (llvm::size(op.getUpdates()) != 1) + } + if (llvm::size(op.getUpdates()) != 1) { return op.emitError("NYI variadic updates scatter"); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -335,8 +349,9 @@ struct ReverseOpConversion final matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto ty = dyn_cast(adaptor.getOperands()[0].getType()); - if (!ty) + if (!ty) { return failure(); + } Value input = op.getOperand(); auto inputTy = cast(input.getType()); @@ -426,8 +441,9 @@ struct ScanOpConversion final auto window = llvm::to_vector(op.getWindowDimensions()); llvm::SmallVector reduceAxes; for (int i = 0, s = window.size(); i < s; ++i) { - if (window[i] == 1) + if (window[i] == 1) { continue; + } if (window[i] == input0Ty.getDimSize(i)) { reduceAxes.push_back(i); continue; @@ -454,8 +470,9 @@ struct ScanOpConversion final } for (int i = 0, s = padding.size(); i < s; i += 2) { - if (i == reduceAxis * 2) + if (i == reduceAxis * 2) { continue; + } if (padding[i] != 0 || padding[i + 1] != 0) { return rewriter.notifyMatchFailure(op, "padding along non-reduction axis"); @@ -484,8 +501,9 @@ struct ScanOpConversion final llvm::SmallVector initDims; llvm::SmallVector initDynDims; for (int i = 0; i < input0Ty.getRank(); ++i) { - if (i == reduceAxis) + if (i == reduceAxis) { continue; + } 
initDims.push_back(input0Ty.getDimSize(i)); if (ShapedType::isDynamic(initDims.back())) { initDynDims.push_back( diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp index 6acdea63a8d6..2899ce23ea08 100644 --- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp @@ -139,21 +139,25 @@ void Converti48Toi64Pass::runOnOperation() { target.markUnknownOpDynamicallyLegal([](Operation *op) { if (auto funcOp = dyn_cast(op)) { for (Type type : funcOp.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } for (Type type : op->getResultTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } for (auto attr : op->getAttrs()) { if (auto typedAttr = dyn_cast(attr.getValue())) { diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp index dc544ff4921c..7763ff95ee79 100644 --- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp @@ -78,8 +78,9 @@ class GenericTypeConvert : public ConversionPattern { }; static bool isIllegalType(Type type) { - if (IntegerType ity = dyn_cast(type)) + if (IntegerType ity = dyn_cast(type)) { return !ity.isSignless(); + } if (auto shapedType = dyn_cast(type)) { return isIllegalType(shapedType.getElementType()); } @@ -94,21 +95,25 @@ void StripSignednessPass::runOnOperation() { target.markUnknownOpDynamicallyLegal([](Operation *op) { if (auto funcOp = dyn_cast(op)) { for (Type 
type : funcOp.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } for (Type type : op->getResultTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (type && isIllegalType(type)) + if (type && isIllegalType(type)) { return false; + } } return true; }); diff --git a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp index 8e7561953a6f..500d309c5b31 100644 --- a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp +++ b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp @@ -42,9 +42,10 @@ class ScatterConversion : public OpRewritePattern { auto updatesTy = dyn_cast(updates.getType()); ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - if (!valuesTy || !indicesTy || !updatesTy) + if (!valuesTy || !indicesTy || !updatesTy) { return rewriter.notifyMatchFailure(op, "tosa.gather has unknown input rank"); + } // TOSA's scatter does not include a index dimension, instead it implicitly // supports an index depth of one. We materialize that implicit index of @@ -68,9 +69,11 @@ class ScatterConversion : public OpRewritePattern { // Materialize the batch indice as LinalgExt scatter is not batched. 
{ llvm::SmallVector dynDims; - for (int i = 0, s = indicesTy.getRank(); i < s; ++i) - if (indicesTy.isDynamicDim(i)) + for (int i = 0, s = indicesTy.getRank(); i < s; ++i) { + if (indicesTy.isDynamicDim(i)) { dynDims.push_back(tensor::DimOp::create(builder, indices, i)); + } + } Value empty = tensor::EmptyOp::create( builder, indicesTy.getShape(), indicesTy.getElementType(), dynDims); @@ -159,8 +162,9 @@ class TosaToLinalgExtPass final mlir::FunctionOpInterface funcOp = getOperation(); mlir::iree_compiler::populateTosaToLinalgExtPatterns(&patterns); - if (failed(applyFullConversion(funcOp, target, std::move(patterns)))) + if (failed(applyFullConversion(funcOp, target, std::move(patterns)))) { signalPassFailure(); + } } }; diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp index 5fbef105cd71..b46fce9904d5 100644 --- a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp @@ -86,13 +86,15 @@ class BindSymbolicShapesPass final auto operand = bindOp.getOperand(); // Torch programs are single block and use structured control flow, so // presume this is an entrypoint. - if (isa(operand)) + if (isa(operand)) { return true; + } // Mutable tensors can exist at the boundary and must be "copied" to a // vtensor prior to use. Therefore, we anchor on the point of copy. - if (operand.getDefiningOp()) + if (operand.getDefiningOp()) { return true; + } return false; } @@ -117,10 +119,12 @@ class BindSymbolicShapesPass final // Gets the canonical dim for this symbol, returning {} if there // is no canonical dim. 
Value getCanonicalDimValue(OpBuilder &builder) { - if (canonicalDimValue) + if (canonicalDimValue) { return canonicalDimValue; - if (equalityDimInfos.empty()) + } + if (equalityDimInfos.empty()) { return {}; + } canonicalDimValue = getEqualityDimValue(builder, 0); return canonicalDimValue; } @@ -213,8 +217,9 @@ class BindSymbolicShapesPass final std::optional> evaluateExprBounds(AffineExpr expr, llvm::DenseMap &symbolInfos) { - if (!expr.isSymbolicOrConstant()) + if (!expr.isSymbolicOrConstant()) { return {}; + } llvm::SmallVector> lowerBounds; llvm::SmallVector> upperBounds; lowerBounds.reserve(symbols.size()); @@ -233,14 +238,16 @@ class BindSymbolicShapesPass final auto upperBound = getBoundForAffineExpr( expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, upperBounds, /*isUpper=*/true); - if (!upperBound) + if (!upperBound) { return {}; + } auto lowerBound = getBoundForAffineExpr( expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds, upperBounds, /*isUpper=*/false); - if (!lowerBound) + if (!lowerBound) { return {}; + } return std::make_pair(*lowerBound, *upperBound); } @@ -250,8 +257,9 @@ class BindSymbolicShapesPass final void associateEqualityDims(llvm::DenseMap &symbolInfos) { OpBuilder builder(anchorOp); for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { - if (expr.getKind() != AffineExprKind::SymbolId) + if (expr.getKind() != AffineExprKind::SymbolId) { continue; + } auto symbolPos = cast(expr).getPosition(); Value symbol = symbols[symbolPos]; auto symbolInfoIt = symbolInfos.find(symbol); @@ -268,12 +276,14 @@ class BindSymbolicShapesPass final if (auto binaryExpr = dyn_cast(genericExpr)) { auto lhs = materializeDimExpr(loc, builder, binaryExpr.getLHS(), symbolInfos); - if (!lhs) + if (!lhs) { return {}; + } auto rhs = materializeDimExpr(loc, builder, binaryExpr.getRHS(), symbolInfos); - if (!rhs) + if (!rhs) { return {}; + } switch (binaryExpr.getKind()) { case AffineExprKind::Add: @@ -303,12 +313,14 @@ class 
BindSymbolicShapesPass final case AffineExprKind::SymbolId: { auto symExpr = cast(genericExpr); auto pos = symExpr.getPosition(); - if (pos >= symbols.size()) + if (pos >= symbols.size()) { break; + } Value symbolValue = symbols[pos]; auto foundIt = symbolInfos.find(symbolValue); - if (foundIt == symbolInfos.end()) + if (foundIt == symbolInfos.end()) { break; + } SymbolInfo &info = foundIt->second; return info.getCanonicalDimValue(builder); // May legally return {} } @@ -327,8 +339,9 @@ class BindSymbolicShapesPass final void materializeDims(llvm::DenseMap &symbolInfos) { OpBuilder builder(anchorOp); for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) { - if (!builtinTensorType.isDynamicDim(index)) + if (!builtinTensorType.isDynamicDim(index)) { continue; + } Value dimValue = materializeDimExpr(anchorOp->getLoc(), builder, expr, symbolInfos); @@ -412,8 +425,9 @@ class BindSymbolicShapesPass final SymbolInfo(symbolOp)); } else if (auto bindOp = dyn_cast(childOp)) { cleanupOpList.push_back(bindOp); - if (!isEligibleBinding(bindOp)) + if (!isEligibleBinding(bindOp)) { return; + } auto torchType = cast(bindOp.getOperand().getType()); auto builtinType = dyn_cast_if_present( diff --git a/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp b/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp index 1b3cfacb9dcc..a2a761adc592 100644 --- a/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BitCastTensor.cpp @@ -142,30 +142,37 @@ class BitCastMatmul : public OpRewritePattern { return success(); }; int unpackedBitWidth; - if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) + if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) { return failure(); + } auto rhsType = dyn_cast(rhs.getType()); - if (!rhsType) + if (!rhsType) { return failure(); + } - if (!rhsType.hasDtype()) + if (!rhsType.hasDtype()) { return failure(); + } Type dType = 
rhsType.getDtype(); int dTypeWidth = dType.getIntOrFloatBitWidth(); // If the dtype width already matches the target width, nothing to do. - if (dTypeWidth == unpackedBitWidth) + if (dTypeWidth == unpackedBitWidth) { return failure(); + } - if (!rhsType.hasSizes()) + if (!rhsType.hasSizes()) { return failure(); + } SmallVector tensorShape(rhsType.getSizes()); // Constants should have constant shape. - if (llvm::any_of(tensorShape, - [](int64_t s) { return s == torch::Torch::kUnknownSize; })) + if (llvm::any_of(tensorShape, [](int64_t s) { + return s == torch::Torch::kUnknownSize; + })) { return failure(); + } int packRatio = dTypeWidth / unpackedBitWidth; tensorShape[tensorShape.size() - 1] *= packRatio; @@ -185,10 +192,11 @@ class BitCastMatmul : public OpRewritePattern { // Cast back to the (un)signed torch tensor type to inform later lowerings. Type unpackedElementType; - if (dType.isSignedInteger()) + if (dType.isSignedInteger()) { unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true); - else + } else { unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false); + } torch::Torch::ValueTensorType newRhsType = torch::Torch::ValueTensorType::get(rewriter.getContext(), tensorShape, unpackedElementType); @@ -215,8 +223,9 @@ class BitCastTensorPass final patterns.add(context); patterns.add, BitCastViewComplex>(context); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); + } } }; } // namespace diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp index e5ce5246f614..9aeee4d25549 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp @@ -51,8 +51,9 @@ struct ScatterOpConversion LogicalResult 
matchAndRewrite(mlir::torch::TMTensor::ScatterOp op, PatternRewriter &rewriter) const override { auto indicesTy = op.getIndicesType(); - if (!indicesTy.hasRank()) + if (!indicesTy.hasRank()) { return failure(); + } if (indicesTy.isDynamicDim(indicesTy.getRank() - 1)) { return rewriter.notifyMatchFailure(op, "number of indices is unknown"); @@ -60,8 +61,9 @@ struct ScatterOpConversion auto numIndices = indicesTy.getShape().back(); llvm::SmallVector dimMap(numIndices); - for (int i = 0; i < numIndices; i++) + for (int i = 0; i < numIndices; i++) { dimMap[i] = i; + } auto updatesTy = op.getUpdateType(); @@ -182,8 +184,9 @@ struct AttentionOpConversion int64_t numBatches = op.getQueryType().getRank() - 2; for (AffineMap &map : indexingMaps) { map = map.shiftDims(numBatches); - if (map.getNumResults() == 0) + if (map.getNumResults() == 0) { continue; + } for (int batch : llvm::seq(numBatches)) { map = map.insertResult(rewriter.getAffineDimExpr(batch), batch); } diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 85f436ded555..5b6d944c136b 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -82,8 +82,9 @@ getEnclosingWaitSignalFences(Operation *op) { auto parentFuncOp = dyn_cast(op); if (!parentFuncOp) { parentFuncOp = parentFuncOp->getParentOfType(); - if (!parentFuncOp) + if (!parentFuncOp) { return {}; + } } Block *entryBlock = &parentFuncOp.front(); auto numArguments = entryBlock->getNumArguments(); @@ -99,8 +100,9 @@ getEnclosingWaitSignalFences(Value value) { Value convertToBuiltinTensor(OpBuilder &builder, Value possibleTorchTensor) { Type ty = possibleTorchTensor.getType(); - if (isa(ty)) + if (isa(ty)) { return possibleTorchTensor; + } if (auto defining = dyn_cast_if_present( possibleTorchTensor.getDefiningOp())) { @@ -177,8 +179,9 @@ struct ConvertedAsyncFunctionInfo { 
}; LogicalResult ConvertedAsyncFunctionInfo::postProcess() { - if (funcOp.isExternal()) + if (funcOp.isExternal()) { return success(); + } if (returnOps.size() != 1) { // Multi-exit/CFG could be supported but requires more complicated dominance @@ -197,14 +200,17 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { llvm::zip_equal(inputDispositions, entryArgs, torchInputTypes)) { switch (disp) { case TypeDisposition::IMMUTABLE_TENSOR: { - if (failed( - convertImmutableTensorArg(argValue, torchType, preambleBuilder))) + if (failed(convertImmutableTensorArg(argValue, torchType, + preambleBuilder))) { return failure(); + } break; } case TypeDisposition::MUTABLE_TENSOR: { - if (failed(convertMutableTensorArg(argValue, torchType, preambleBuilder))) + if (failed( + convertMutableTensorArg(argValue, torchType, preambleBuilder))) { return failure(); + } break; } case TypeDisposition::TORCH_PRIMITIVE: { @@ -374,12 +380,14 @@ LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg( // it. bool hasNonTrivialUse = false; for (auto *userOp : argValue.getUsers()) { - if (isa(userOp)) + if (isa(userOp)) { continue; + } hasNonTrivialUse = true; } - if (!hasNonTrivialUse) + if (!hasNonTrivialUse) { return success(); + } // Remember original uses so we can redirect them. 
OriginalUses originalUses(argValue); @@ -481,8 +489,9 @@ void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) { for (auto retainAttrName : retainedAttributes) { StringRef attrName(retainAttrName); Attribute attr = srcOp->getAttr(attrName); - if (attr) + if (attr) { destOp->setAttr(attrName, attr); + } } } @@ -566,8 +575,9 @@ class FuncConversionPass final SmallVector eraseFuncOps; std::vector convertedFuncInfos; for (auto funcOp : moduleOp.getOps()) { - if (!shouldConvertFunc(funcOp)) + if (!shouldConvertFunc(funcOp)) { continue; + } ConvertedAsyncFunctionInfo &convertedFuncInfo = convertedFuncInfos.emplace_back(); if (failed(convertFuncOp(funcOp, convertedFuncInfo))) { @@ -594,12 +604,14 @@ class FuncConversionPass final // calling convention. In the future, we may support "torch externals" // which we convert to mate up with a torch module. We can remove/adapt // this when that is elaborated. - if (torchFunc.isExternal()) + if (torchFunc.isExternal()) { return false; + } // Something has already converted this and told us not to touch it. - if (torchFunc->hasAttr("iree.abi.stub")) + if (torchFunc->hasAttr("iree.abi.stub")) { return false; + } return true; } @@ -640,14 +652,16 @@ class FuncConversionPass final for (size_t i = 0; i < convertedFuncInfo.torchInputTypes.size(); ++i) { if (failed(convertType(loc, convertedFuncInfo.torchInputTypes[i], ireeInputTypes[i], - convertedFuncInfo.inputDispositions[i]))) + convertedFuncInfo.inputDispositions[i]))) { return failure(); + } } for (size_t i = 0; i < convertedFuncInfo.torchResultTypes.size(); ++i) { if (failed(convertType(loc, convertedFuncInfo.torchResultTypes[i], ireeResultTypes[i], - convertedFuncInfo.resultDispositions[i]))) + convertedFuncInfo.resultDispositions[i]))) { return failure(); + } } // Build tied operands index mapping results back to operands. 
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index 2d04729c0065..4ed603cb4444 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -50,9 +50,10 @@ void createTorchToIREEPipeline( torch::Torch::createReduceOpVariantsPass(llvm::StringRef())); pm.addNestedPass( mlir::torch::TorchConversion::createConvertCustomQuantOpPass()); - if (options.decompose) + if (options.decompose) { pm.addNestedPass( torch::Torch::createDecomposeComplexOpsPass(BackendLegalOps::get())); + } pm.addNestedPass(torch::Torch::createFuseQuantizedOpsPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(torch::Torch::createScalarizeShapesPass()); diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 87422ee4b6ba..b4fe9c265bd9 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -122,14 +122,17 @@ static constexpr char kPtxasCompilerName[] = "ptxas"; static FailureOr findPtxasCompiler(const CUDAOptions &options, std::string *message) { std::string ptxasCompiler; - if (!options.clUsePtxasFrom.empty()) + if (!options.clUsePtxasFrom.empty()) { ptxasCompiler = options.clUsePtxasFrom; - if (llvm::sys::fs::exists(ptxasCompiler)) + } + if (llvm::sys::fs::exists(ptxasCompiler)) { return ptxasCompiler; + } ptxasCompiler = findTool(kPtxasCompilerName); - if (llvm::sys::fs::exists(ptxasCompiler)) + if (llvm::sys::fs::exists(ptxasCompiler)) { return ptxasCompiler; + } *message = std::string( "Could not find ptxas compiler. 
Try passing it explicitly with " @@ -181,8 +184,9 @@ static FailureOr compileWithPtxas(StringRef ptxasCompiler, llvm::StringSaver stringSaver(scratchAllocator); SmallVector rawArgs; Tokenize(ptxasParams, stringSaver, rawArgs, /*MarkEOLs=*/false); - for (auto rawArg : rawArgs) + for (auto rawArg : rawArgs) { ArgVector.push_back(StringRef(rawArg)); + } std::optional redirects[] = { stdinFile.str(), @@ -233,8 +237,9 @@ static FailureOr compileWithPtxas(StringRef ptxasCompiler, static std::string produceGpuImage(const CUDAOptions &options, StringRef targetArch, std::string &ptxImage) { - if (!options.clUsePtxas) + if (!options.clUsePtxas) { return ptxImage; + } std::string message; FailureOr ptxasCompiler = findPtxasCompiler(options, &message); @@ -243,8 +248,9 @@ static std::string produceGpuImage(const CUDAOptions &options, FailureOr maybeCubinImage = compileWithPtxas(ptxasCompiler.value(), targetArch, options.clUsePtxasParams, ptxImage, &message); - if (succeeded(maybeCubinImage)) + if (succeeded(maybeCubinImage)) { return maybeCubinImage.value(); + } } llvm::WithColor::warning() @@ -414,8 +420,9 @@ class CUDATargetBackend final : public TargetBackend { getExecutableTarget(MLIRContext *context) const { Builder b(context); SmallVector configItems; - if (failed(options.verify(b))) + if (failed(options.verify(b))) { return nullptr; + } if (auto target = GPU::getCUDATargetDetails( options.clTarget, options.clTargetFeatures, context)) { diff --git a/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp b/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp index 9ae3527d078d..0689fbd07a9a 100644 --- a/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp +++ b/compiler/plugins/target/LLVMCPU/Builtins/Device.cpp @@ -17,8 +17,9 @@ namespace mlir::iree_compiler::IREE::HAL { static const iree_file_toc_t *lookupDeviceFile(StringRef filename) { for (size_t i = 0; i < iree_builtins_libdevice_bitcode_size(); ++i) { const auto &file_toc = iree_builtins_libdevice_bitcode_create()[i]; - 
if (filename == file_toc.name) + if (filename == file_toc.name) { return &file_toc; + } } return nullptr; } @@ -67,8 +68,9 @@ loadDeviceBitcode(llvm::TargetMachine *targetMachine, llvm::MemoryBufferRef bitcodeBufferRef( llvm::StringRef(file->data, file->size), file->name); auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context); - if (!bitcodeModuleValue) + if (!bitcodeModuleValue) { return bitcodeModuleValue; + } auto bitcodeModule = std::move(bitcodeModuleValue.get()); // Clang adds its own per-function attributes that we need to strip so that @@ -86,8 +88,9 @@ static void overridePlatformGlobal(llvm::Module &module, StringRef globalName, uint32_t newValue) { // NOTE: the global will not be defined if it is not used in the module. auto *globalValue = module.getNamedGlobal(globalName); - if (!globalValue) + if (!globalValue) { return; + } globalValue->setLinkage(llvm::GlobalValue::PrivateLinkage); globalValue->setDSOLocal(true); globalValue->setConstant(true); diff --git a/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp b/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp index 337526155678..0189438d2bf0 100644 --- a/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp +++ b/compiler/plugins/target/LLVMCPU/Builtins/Musl.cpp @@ -16,8 +16,9 @@ namespace mlir::iree_compiler::IREE::HAL { static const iree_file_toc_t *lookupMuslFile(StringRef filename) { for (size_t i = 0; i < iree_builtins_libmusl_size(); ++i) { const auto &file_toc = iree_builtins_libmusl_create()[i]; - if (filename == file_toc.name) + if (filename == file_toc.name) { return &file_toc; + } } return nullptr; } diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index bdc5b16fc346..69e066ec94a6 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp @@ -272,8 +272,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { // multi-threading issues. 
llvm::LLVMContext context; auto maybeTarget = getVariantTarget(variantOp); - if (!maybeTarget) + if (!maybeTarget) { return failure(); + } const LLVMTarget &target = *maybeTarget; LLVM_DEBUG(dbgs() << "LLVM-CPU SerializeExecutable:\n" << "-----------------------------\n"; @@ -384,8 +385,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { for (auto exportOp : variantOp.getBlock().getOps()) { // Find the matching function in the LLVM module. auto *llvmFunc = llvmModule->getFunction(exportOp.getName()); - if (!llvmFunc) + if (!llvmFunc) { continue; + } llvmFunc->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage); llvmFunc->setDSOLocal(true); @@ -595,8 +597,9 @@ class LLVMCPUTargetBackend final : public TargetBackend { // Strip any compiler identifiers that may have snuck in. We let the linker // tag the module. auto *llvmIdent = llvmModule->getNamedMetadata("llvm.ident"); - if (llvmIdent) + if (llvmIdent) { llvmIdent->clearOperands(); + } // Dump all linked bitcode prior to optimization. 
if (!options.dumpIntermediatesPath.empty()) { diff --git a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp index c045c061ce2f..2b482ef98b54 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMIRPasses.cpp @@ -89,8 +89,9 @@ LogicalResult runLLVMIRPasses(const LLVMTarget &target, modulePassManager.run(*module, moduleAnalysisManager); } - if (llvm::verifyModule(*module)) + if (llvm::verifyModule(*module)) { return failure(); + } return success(); } diff --git a/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp b/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp index 8ead1b21a39c..d4ba1d652e52 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp @@ -85,8 +85,9 @@ std::optional LLVMTarget::createForHost() { << getMessage(status, triple) << "\n"; return std::nullopt; } - if (target) + if (target) { target->populateDefaultsFromTargetMachine(); + } return target; } @@ -163,16 +164,21 @@ void LLVMTarget::storeToConfigAttrs(MLIRContext *context, if (!staticLibraryOutput.empty()) { addString("static_library_output", staticLibraryOutput); } - if (pipelineTuningOptions.LoopInterleaving != DEFAULT_LOOP_INTERLEAVING) + if (pipelineTuningOptions.LoopInterleaving != DEFAULT_LOOP_INTERLEAVING) { addBool("loop_interleaving", pipelineTuningOptions.LoopInterleaving); - if (pipelineTuningOptions.LoopVectorization != DEFAULT_LOOP_VECTORIZATION) + } + if (pipelineTuningOptions.LoopVectorization != DEFAULT_LOOP_VECTORIZATION) { addBool("loop_vectorization", pipelineTuningOptions.LoopVectorization); - if (pipelineTuningOptions.LoopUnrolling != DEFAULT_LOOP_UNROLLING) + } + if (pipelineTuningOptions.LoopUnrolling != DEFAULT_LOOP_UNROLLING) { addBool("loop_unrolling", pipelineTuningOptions.LoopUnrolling); - if (pipelineTuningOptions.SLPVectorization != DEFAULT_SLP_VECTORIZATION) + } + if 
(pipelineTuningOptions.SLPVectorization != DEFAULT_SLP_VECTORIZATION) { addBool("slp_vectorization", pipelineTuningOptions.SLPVectorization); - if (!llvmTargetOptions.MCOptions.ABIName.empty()) + } + if (!llvmTargetOptions.MCOptions.ABIName.empty()) { addString("target_abi", llvmTargetOptions.MCOptions.ABIName); + } if (llvmTargetOptions.FloatABIType != DEFAULT_FLOAT_ABI) { switch (llvmTargetOptions.FloatABIType) { case llvm::FloatABI::Default: @@ -186,10 +192,12 @@ void LLVMTarget::storeToConfigAttrs(MLIRContext *context, break; } } - if (ukernels.compare(DEFAULT_ENABLE_UKERNELS) != 0) + if (ukernels.compare(DEFAULT_ENABLE_UKERNELS) != 0) { addString("ukernels", ukernels); - if (linkUkernelBitcode != DEFAULT_LINK_UKERNEL_BITCODE) + } + if (linkUkernelBitcode != DEFAULT_LINK_UKERNEL_BITCODE) { addBool("link_ukernel_bitcode", linkUkernelBitcode); + } } std::optional @@ -274,13 +282,13 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config, target.linkStatic = getBool("link_static", DEFAULT_LINK_STATIC); auto sanitizer = getOptionalString("sanitizer"); if (sanitizer) { - if (sanitizer == "none") + if (sanitizer == "none") { target.sanitizerKind = SanitizerKind::kNone; - else if (sanitizer == "address") + } else if (sanitizer == "address") { target.sanitizerKind = SanitizerKind::kAddress; - else if (sanitizer == "thread") + } else if (sanitizer == "thread") { target.sanitizerKind = SanitizerKind::kThread; - else { + } else { emitError(loc) << "executable config unexpected value for 'sanitizer': " << *sanitizer; return {}; @@ -297,17 +305,18 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config, target.pipelineTuningOptions.SLPVectorization = getBool( "slp_vectorization", target.pipelineTuningOptions.SLPVectorization); auto targetAbi = getOptionalString("target_abi"); - if (targetAbi) + if (targetAbi) { target.llvmTargetOptions.MCOptions.ABIName = *targetAbi; + } auto floatAbi = getOptionalString("float_abi"); if (floatAbi) { - if 
(floatAbi == "default") + if (floatAbi == "default") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else if (floatAbi == "soft") + } else if (floatAbi == "soft") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else if (floatAbi == "hard") + } else if (floatAbi == "hard") { target.llvmTargetOptions.FloatABIType = llvm::FloatABI::Default; - else { + } else { emitError(loc) << "executable config unexpected value for 'float_abi'"; return {}; } @@ -389,8 +398,9 @@ createTargetMachine(const LLVMTarget &target) { std::string errorMessage; auto llvmTarget = llvm::TargetRegistry::lookupTarget( llvm::Triple(target.getTriple()), errorMessage); - if (!llvmTarget) + if (!llvmTarget) { return nullptr; + } llvm::Triple triple(target.getTriple()); std::unique_ptr machine(llvmTarget->createTargetMachine( triple, target.getCpu() /* cpu e.g k8 */, diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp index d765b2adfafb..83c724ccaf1a 100644 --- a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp +++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp @@ -520,8 +520,9 @@ LibraryBuilder::buildLibraryV0ImportTable(std::string libraryName) { SmallVector symbolNameValues; for (auto &import : imports) { auto symbolName = import.symbol_name; - if (import.weak) + if (import.weak) { symbolName = "?" 
+ symbolName; + } symbolNameValues.push_back(createStringConstant(symbolName, module)); } symbolNames = createArrayConstant(libraryName + "_import_names", @@ -552,8 +553,9 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { // iree_hal_executable_export_table_v0_t::ptrs SmallVector exportPtrValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportPtrValues.push_back(dispatch.func); + } llvm::Constant *exportPtrs = createArrayConstant( libraryName + "_funcs", ptrType, exportPtrValues, module); @@ -603,8 +605,9 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { llvm::Constant *exportNames = llvm::Constant::getNullValue(ptrType); if (mode == Mode::INCLUDE_REFLECTION_ATTRS) { SmallVector exportNameValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportNameValues.push_back(createStringConstant(dispatch.name, module)); + } exportNames = createArrayConstant(libraryName + "_names", ptrType, exportNameValues, module); } @@ -615,9 +618,10 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { exports, [](auto &dispatch) { return !dispatch.tag.empty(); }); if (mode == Mode::INCLUDE_REFLECTION_ATTRS && hasAnyTags) { SmallVector exportTagValues; - for (auto dispatch : exports) + for (auto dispatch : exports) { exportTagValues.push_back( createStringConstantOrNull(dispatch.tag, module)); + } exportTags = createArrayConstant(libraryName + "_tags", ptrType, exportTagValues, module); } diff --git a/compiler/plugins/target/LLVMCPU/LinkerTool.cpp b/compiler/plugins/target/LLVMCPU/LinkerTool.cpp index b763830da3d6..ef13059db486 100644 --- a/compiler/plugins/target/LLVMCPU/LinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/LinkerTool.cpp @@ -56,8 +56,9 @@ Artifact Artifact::createVariant(StringRef basePath, StringRef suffix) { } void Artifact::keep() const { - if (outputFile) + if (outputFile) { outputFile->keep(); + } } std::optional> Artifact::read() const { @@ -129,8 +130,9 @@ 
LogicalResult LinkerTool::runLinkCommand(std::string commandLine, commandLine = escapeCommandLineComponent(commandLine); } int exitCode = system(commandLine.c_str()); - if (exitCode == 0) + if (exitCode == 0) { return success(); + } llvm::errs() << "Linking failed; escaped command line returned exit code " << exitCode << ":\n\n" << commandLine << "\n\n"; diff --git a/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp index 6d4c208cd48a..1ceb47dd5ac2 100644 --- a/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/AndroidLinkerTool.cpp @@ -104,8 +104,9 @@ class AndroidLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // ANDROID_NDK must be set for us to infer the tool path. char *androidNDKPath = std::getenv("ANDROID_NDK"); @@ -216,8 +217,9 @@ class AndroidLinkerTool : public LinkerTool { flagsToPrefixForLinker.clear(); auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } }; diff --git a/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp index 61f9a82ed4f5..1dd6eeef5316 100644 --- a/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/EmbeddedLinkerTool.cpp @@ -50,15 +50,17 @@ class EmbeddedLinkerTool : public LinkerTool { // Fall back to check for setting the linker explicitly via environment // variables. 
char *envVarPath = std::getenv("IREE_LLVM_EMBEDDED_LINKER_PATH"); - if (envVarPath && envVarPath[0] != '\0') + if (envVarPath && envVarPath[0] != '\0') { return std::string(envVarPath); + } // No explicit linker specified, search the install/build dir or env. const SmallVector &toolNames{"iree-lld", "lld", "ld.lld", "lld-link"}; std::string toolPath = findTool(toolNames); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "error: required embedded linker tool (typically `lld`) not found " @@ -119,8 +121,9 @@ class EmbeddedLinkerTool : public LinkerTool { artifacts.libraryFile.close(); std::string embeddedToolPath = getEmbeddedToolPath(); - if (embeddedToolPath.empty()) + if (embeddedToolPath.empty()) { return std::nullopt; + } SmallVector flags = { embeddedToolPath, diff --git a/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp index aa71989094a8..529a036182dc 100644 --- a/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/UnixLinkerTool.cpp @@ -24,8 +24,9 @@ class UnixLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { // First check for setting the linker explicitly. auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // No explicit linker specified, search the environment for common tools. // We want LLD: @@ -53,8 +54,9 @@ class UnixLinkerTool : public LinkerTool { // of these, at least given current behavior. 
toolPath = findToolInEnvironment({"ld.lld", "ld"}); } - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Unix linker tool found in environment.\n"; return ""; @@ -129,8 +131,9 @@ class UnixLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } diff --git a/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp index 9e2e4fe6c6fc..5ae6f60fad79 100644 --- a/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/WasmLinkerTool.cpp @@ -54,8 +54,9 @@ class WasmLinkerTool : public LinkerTool { // or install directories) for common tools. std::string toolPath = findToolFromExecutableDir( {"wasm-ld", "iree-lld", "lld", "ld.lld", "lld-link"}); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Wasm linker tool specified or discovered\n"; return ""; @@ -131,8 +132,9 @@ class WasmLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } return artifacts; } }; diff --git a/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp b/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp index 8c3502771144..824af9a90a3c 100644 --- a/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp +++ b/compiler/plugins/target/LLVMCPU/internal/WindowsLinkerTool.cpp @@ -23,14 +23,16 @@ class WindowsLinkerTool : public LinkerTool { std::string getSystemToolPath() const override { // First check for setting the linker explicitly. 
auto toolPath = LinkerTool::getSystemToolPath(); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } // No explicit linker specified, search the executable directory (i.e. our // own build or install directories) for common tools. toolPath = findToolFromExecutableDir({"lld-link"}); - if (!toolPath.empty()) + if (!toolPath.empty()) { return toolPath; + } llvm::errs() << "No Windows linker tool specified or discovered\n"; return ""; @@ -273,8 +275,9 @@ class WindowsLinkerTool : public LinkerTool { } auto commandLine = llvm::join(flags, " "); - if (failed(runLinkCommand(commandLine))) + if (failed(runLinkCommand(commandLine))) { return std::nullopt; + } // PDB file gets generated wtih the same path + .pdb. artifacts.debugFile = diff --git a/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp b/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp index 2937fefc354e..c50969dab99d 100644 --- a/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp +++ b/compiler/plugins/target/MetalSPIRV/MSLToMetalLib.cpp @@ -54,8 +54,9 @@ static std::string getMetalCompileCommand(MetalTargetPlatform platform, static LogicalResult runSystemCommand(llvm::StringRef command) { LLVM_DEBUG(llvm::dbgs() << "Running system command: '" << command << "'\n"); int exitCode = system(command.data()); - if (exitCode == 0) + if (exitCode == 0) { return success(); + } llvm::errs() << "Failed to run system command '" << command << "' with error code: " << exitCode << "\n"; return failure(); @@ -78,8 +79,9 @@ compileMSLToMetalLib(MetalTargetPlatform targetPlatform, std::string command = getMetalCompileCommand(targetPlatform, mslFile, libFile); - if (failed(runSystemCommand(command))) + if (failed(runSystemCommand(command))) { return nullptr; + } auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(libFile, /*isText=*/false); diff --git a/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp b/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp index a86849ab565a..413056dfe2e2 100644 --- 
a/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp +++ b/compiler/plugins/target/MetalSPIRV/SPIRVToMSL.cpp @@ -33,8 +33,9 @@ class SPIRVToMSLCompiler : public SPIRV_CROSS_NAMESPACE::CompilerMSL { entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute); const auto &workgroupSize = entryPoint.workgroup_size; // TODO(antiagainst): support specialization constant. - if (workgroupSize.constant != 0) + if (workgroupSize.constant != 0) { return {0, 0, 0}; + } return {workgroupSize.x, workgroupSize.y, workgroupSize.z}; } @@ -127,8 +128,9 @@ crossCompileSPIRVToMSL(IREE::HAL::MetalTargetPlatform targetPlatform, SmallVector descriptors; bool hasPushConstant = false; - if (!spvCrossCompiler.getResources(&descriptors, &hasPushConstant)) + if (!spvCrossCompiler.getResources(&descriptors, &hasPushConstant)) { return std::nullopt; + } // Explicitly set the argument buffer [[id(N)]] location for each SPIR-V // resource variable. diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 516dead102d5..f994f722b356 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -608,8 +608,9 @@ class ROCMTargetBackend final : public TargetBackend { for (auto func : innerModuleOp.getOps()) { llvm::Function *llvmFunc = llvmModule->getFunction(func.getName()); - if (llvmFunc->isDeclaration()) + if (llvmFunc->isDeclaration()) { continue; + } // Override flags as given by target func attrs. if (auto funcAttrs = @@ -702,8 +703,9 @@ class ROCMTargetBackend final : public TargetBackend { llvmModule->addModuleFlag(llvm::Module::Error, "amdhsa_code_object_version", abiVersion); - for (llvm::Function &f : llvmModule->functions()) + for (llvm::Function &f : llvmModule->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); + } // Link user-provided modules. llvm::Linker linker(*llvmModule); @@ -814,8 +816,9 @@ class ROCMTargetBackend final : public TargetBackend { // final FlatBuffer. 
std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine); targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName); - if (targetHSACO.empty()) + if (targetHSACO.empty()) { return failure(); + } if (options.enableRegSpillWarning) { checkRegisterSpilling(variantOp, targetObj); diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp index ce4120d2d96b..e6bc477fbdc1 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp @@ -50,8 +50,9 @@ loadIRModule(Location loc, const std::string &filename, static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module, ArrayRef bitcodePaths) { - if (bitcodePaths.empty()) + if (bitcodePaths.empty()) { return success(); + } llvm::Linker linker(*module); for (auto &bitcodePath : bitcodePaths) { if (!llvm::sys::fs::exists(bitcodePath)) { @@ -62,8 +63,9 @@ static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module, } std::unique_ptr bitcodeModule = loadIRModule(loc, bitcodePath, &module->getContext()); - if (!bitcodeModule) + if (!bitcodeModule) { return failure(); + } // Ignore the data layout of the module we're importing. This avoids a // warning from the linker. bitcodeModule->setDataLayout(module->getDataLayout()); @@ -107,8 +109,9 @@ static void overridePlatformGlobal(llvm::Module *module, StringRef globalName, uint32_t newValue, llvm::Type *globalTy) { // NOTE: the global will not be defined if it is not used in the module. 
auto *globalValue = module->getNamedGlobal(globalName); - if (!globalValue) + if (!globalValue) { return; + } globalValue->setDSOLocal(true); globalValue->setConstant(true); globalValue->setInitializer(llvm::ConstantInt::get( @@ -160,10 +163,11 @@ LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module, for (const llvm::Function &function : module->functions()) { if (!function.isIntrinsic() && function.isDeclaration()) { auto functionName = function.getName(); - if (functionName.starts_with("__ocml_")) + if (functionName.starts_with("__ocml_")) { usesOCML = true; - else if (functionName.starts_with("__ockl_")) + } else if (functionName.starts_with("__ockl_")) { usesOCKL = true; + } } } From b7a46df44c7f92481c29e7b42735ee14fb650b22 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 15 Jan 2026 13:00:25 -0800 Subject: [PATCH 45/71] [bazel][NFC] Add `# keep sorted` to enforce_glob calls in BUILD.bazel files. (#23149) This change adds buildifier's `# keep sorted` directive to all `enforce_glob` calls in `compiler/` and `tests/` directories, ensuring file lists stay alphabetically sorted. Also converts one remaining `glob()` call to `enforce_glob()` in the ROCM ukernel BUILD.bazel. 
Signed-off-by: hanhanW --- .../Conversion/Preprocessing/test/BUILD.bazel | 1 + .../StableHLO/Conversion/test/BUILD.bazel | 1 + .../TOSA/InputConversion/test/BUILD.bazel | 1 + compiler/plugins/target/CUDA/test/BUILD.bazel | 1 + .../plugins/target/LLVMCPU/test/BUILD.bazel | 1 + .../target/MetalSPIRV/test/BUILD.bazel | 1 + .../target/ROCM/Dialect/ROCM/IR/BUILD.bazel | 1 + .../ROCM/builtins/mlir_ukernel/BUILD.bazel | 10 +++++- .../ROCM/builtins/mlir_ukernel/CMakeLists.txt | 7 +++-- compiler/plugins/target/VMVX/test/BUILD.bazel | 1 + .../target/VulkanSPIRV/test/BUILD.bazel | 1 + .../Native/Transforms/test/BUILD.bazel | 1 + .../TFLite/Transforms/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Common/TransformExtensions/BUILD.bazel | 1 + .../compiler/Codegen/Common/test/BUILD.bazel | 13 ++++---- .../Codegen/Dialect/CPU/IR/BUILD.bazel | 1 + .../Codegen/Dialect/CPU/IR/test/BUILD.bazel | 1 + .../Codegen/Dialect/Codegen/IR/BUILD.bazel | 1 + .../Dialect/Codegen/IR/test/BUILD.bazel | 1 + .../Codegen/Dialect/GPU/IR/BUILD.bazel | 1 + .../Codegen/Dialect/GPU/IR/test/BUILD.bazel | 1 + .../GPU/TransformExtensions/BUILD.bazel | 1 + .../GPU/TransformExtensions/test/BUILD.bazel | 3 +- .../Dialect/GPU/Transforms/test/BUILD.bazel | 1 + .../PCF/ExternalInterfaces/test/BUILD.bazel | 1 + .../Codegen/Dialect/PCF/IR/BUILD.bazel | 1 + .../Codegen/Dialect/PCF/IR/test/BUILD.bazel | 1 + .../Dialect/PCF/Transforms/test/BUILD.bazel | 1 + .../Codegen/Dialect/VectorExt/IR/BUILD.bazel | 3 +- .../Dialect/VectorExt/IR/test/BUILD.bazel | 1 + .../VectorExt/Transforms/test/BUILD.bazel | 1 + .../ExternalInterfaces/test/BUILD.bazel | 3 +- .../compiler/Codegen/Interfaces/BUILD.bazel | 3 +- .../LLVMCPU/TransformExtensions/BUILD.bazel | 1 + .../LLVMGPU/TransformExtensions/BUILD.bazel | 1 + .../compiler/Codegen/LLVMGPU/test/BUILD.bazel | 31 ++++++++++--------- .../Codegen/LLVMGPU/test/ROCDL/BUILD.bazel | 7 +++-- .../compiler/Codegen/SPIRV/test/BUILD.bazel | 9 +++--- 
.../compiler/Codegen/WGSL/test/BUILD.bazel | 1 + .../iree/compiler/ConstEval/test/BUILD.bazel | 1 + .../compiler/Dialect/Encoding/IR/BUILD.bazel | 1 + .../Dialect/Encoding/IR/test/BUILD.bazel | 1 + .../Conversion/ShardToFlow/test/BUILD.bazel | 1 + .../Conversion/TensorToFlow/test/BUILD.bazel | 3 +- .../iree/compiler/Dialect/Flow/IR/BUILD.bazel | 1 + .../compiler/Dialect/Flow/IR/test/BUILD.bazel | 1 + .../Flow/TransformExtensions/BUILD.bazel | 1 + .../Dialect/Flow/Transforms/test/BUILD.bazel | 1 + .../HAL/Conversion/HALToHAL/test/BUILD.bazel | 1 + .../HAL/Conversion/HALToVM/test/BUILD.bazel | 1 + .../Conversion/StandardToHAL/test/BUILD.bazel | 1 + .../Conversion/StreamToHAL/test/BUILD.bazel | 1 + .../HAL/Conversion/UtilToHAL/test/BUILD.bazel | 1 + .../iree/compiler/Dialect/HAL/IR/BUILD.bazel | 1 + .../compiler/Dialect/HAL/IR/test/BUILD.bazel | 3 +- .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../compiler/Dialect/LinalgExt/IR/BUILD.bazel | 1 + .../Dialect/LinalgExt/IR/test/BUILD.bazel | 1 + .../LinalgExt/TransformExtensions/BUILD.bazel | 1 + .../TransformExtensions/test/BUILD.bazel | 1 + .../LinalgExt/Transforms/test/BUILD.bazel | 1 + .../Conversion/FlowToStream/test/BUILD.bazel | 1 + .../Conversion/HALToStream/test/BUILD.bazel | 1 + .../StandardToStream/test/BUILD.bazel | 1 + .../Conversion/UtilToStream/test/BUILD.bazel | 1 + .../compiler/Dialect/Stream/IR/BUILD.bazel | 1 + .../Dialect/Stream/IR/test/BUILD.bazel | 1 + .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/e2e/BUILD.bazel | 1 + .../compiler/Dialect/TensorExt/IR/BUILD.bazel | 1 + .../Dialect/TensorExt/IR/test/BUILD.bazel | 1 + .../TensorExt/Transforms/test/BUILD.bazel | 1 + .../Conversion/FuncToUtil/test/BUILD.bazel | 1 + .../Conversion/MemRefToUtil/test/BUILD.bazel | 1 + .../Dialect/Util/Conversion/test/BUILD.bazel | 1 + .../iree/compiler/Dialect/Util/IR/BUILD.bazel | 1 + .../compiler/Dialect/Util/IR/test/BUILD.bazel | 1 + .../Dialect/Util/TransformOps/BUILD.bazel | 1 + 
.../Util/TransformOps/test/BUILD.bazel | 1 + .../Dialect/Util/Transforms/test/BUILD.bazel | 1 + .../Dialect/VM/Analysis/test/BUILD.bazel | 1 + .../VM/Conversion/ArithToVM/test/BUILD.bazel | 1 + .../VM/Conversion/MathToVM/test/BUILD.bazel | 1 + .../Conversion/StandardToVM/test/BUILD.bazel | 1 + .../VM/Conversion/UtilToVM/test/BUILD.bazel | 1 + .../VM/Conversion/VMToEmitC/test/BUILD.bazel | 17 +++++----- .../iree/compiler/Dialect/VM/IR/BUILD.bazel | 1 + .../compiler/Dialect/VM/IR/test/BUILD.bazel | 1 + .../VM/Target/Bytecode/test/BUILD.bazel | 1 + .../Dialect/VM/Transforms/test/BUILD.bazel | 1 + .../Conversion/HALToVMVX/test/BUILD.bazel | 1 + .../StandardToVMVX/test/BUILD.bazel | 1 + .../VMVX/Conversion/VMVXToVM/test/BUILD.bazel | 1 + .../iree/compiler/Dialect/VMVX/IR/BUILD.bazel | 1 + .../compiler/Dialect/VMVX/IR/test/BUILD.bazel | 1 + .../Dialect/VMVX/Transforms/test/BUILD.bazel | 1 + .../DispatchCreation/test/BUILD.bazel | 27 ++++++++-------- .../GlobalOptimization/test/BUILD.bazel | 1 + .../InputConversion/Common/test/BUILD.bazel | 1 + .../compiler/Modules/Check/IR/BUILD.bazel | 1 + .../compiler/Modules/Check/test/BUILD.bazel | 1 + .../Conversion/HALInlineToVM/test/BUILD.bazel | 1 + .../HALToHALInline/test/BUILD.bazel | 1 + .../StreamToHALInline/test/BUILD.bazel | 1 + .../Modules/HAL/Inline/IR/BUILD.bazel | 1 + .../Modules/HAL/Inline/IR/test/BUILD.bazel | 1 + .../HAL/Inline/Transforms/test/BUILD.bazel | 1 + .../Conversion/HALLoaderToVM/test/BUILD.bazel | 1 + .../StreamToHALLoader/test/BUILD.bazel | 1 + .../Modules/HAL/Loader/IR/BUILD.bazel | 1 + .../Modules/HAL/Loader/IR/test/BUILD.bazel | 1 + .../HAL/Loader/Transforms/test/BUILD.bazel | 1 + .../Conversion/ParamsToVM/test/BUILD.bazel | 1 + .../StreamToParams/test/BUILD.bazel | 1 + .../Modules/IO/Parameters/IR/BUILD.bazel | 1 + .../Modules/IO/Parameters/IR/test/BUILD.bazel | 1 + .../IO/Parameters/Transforms/test/BUILD.bazel | 1 + .../Preprocessing/Common/test/BUILD.bazel | 1 + .../TransformExtensions/BUILD.bazel | 
1 + compiler/src/iree/compiler/Utils/BUILD.bazel | 1 + tests/compiler_driver/BUILD.bazel | 1 + tests/e2e/parameters/BUILD.bazel | 1 + tests/e2e/regression/stablehlo/BUILD.bazel | 6 ++++ tests/e2e/stablehlo_models/BUILD.bazel | 1 + tests/e2e/stablehlo_ops/BUILD.bazel | 1 + tests/e2e/subbyte_types/BUILD.bazel | 1 + tests/e2e/tosa_ops/BUILD.bazel | 2 ++ 128 files changed, 200 insertions(+), 59 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel index 988c92ddc26d..1ddae17e0da1 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalization.mlir", "canonicalize_dot_general.mlir", diff --git a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel index c8f2420d6ae4..ca0da7f359e9 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel +++ b/compiler/plugins/input/StableHLO/Conversion/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "auto_input_conversion.mlir", "convert_collectives.mlir", diff --git a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel index 8d592371c72d..9c9e6a66d044 100644 --- a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel +++ b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "apply_pdl_patterns_tosa.mlir", "auto_input_conversion.mlir", diff --git a/compiler/plugins/target/CUDA/test/BUILD.bazel b/compiler/plugins/target/CUDA/test/BUILD.bazel index 
4bb0e5e82b7c..51edda9651ef 100644 --- a/compiler/plugins/target/CUDA/test/BUILD.bazel +++ b/compiler/plugins/target/CUDA/test/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "smoketest.mlir", ], diff --git a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel index a8ce13583dd3..496cb7698608 100644 --- a/compiler/plugins/target/LLVMCPU/test/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "hal_target_device_attributes.mlir", "materialize_homogeneous_encodings.mlir", diff --git a/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel b/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel index 0bf6c5a2c1da..9eeac27b6974 100644 --- a/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel +++ b/compiler/plugins/target/MetalSPIRV/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["smoketest.mlir"], include = ["*.mlir"], ), diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel index 144d37168e2c..3c37fbacaba3 100644 --- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel +++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "ROCMAttrs.td", "ROCMDialect.td", diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel index 8b3a36ede083..ee2a7e155773 100644 --- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel +++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/BUILD.bazel @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception load("//build_tools/bazel:build_defs.oss.bzl", 
"iree_cmake_extra_content") +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") load("//build_tools/embed_data:build_defs.bzl", "iree_c_embed_data") @@ -23,7 +24,14 @@ endif() inline = True, ) -ukernel_patterns_mlir_files = glob(["ukernel_patterns_*.mlir"]) +ukernel_patterns_mlir_files = enforce_glob( + # keep sorted + [ + "ukernel_patterns_gfx942.mlir", + "ukernel_patterns_gfx950.mlir", + ], + include = ["ukernel_patterns_*.mlir"], +) iree_c_embed_data( name = "iree_mlir_ukernel_patterns_amdgpu", diff --git a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt index f392a56c96c3..9af5d0c64320 100644 --- a/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/builtins/mlir_ukernel/CMakeLists.txt @@ -14,12 +14,12 @@ if(NOT IREE_TARGET_BACKEND_ROCM) return() endif() -file(GLOB _GLOB_UKERNEL_PATTERNS_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS ukernel_patterns_*.mlir) iree_c_embed_data( NAME iree_mlir_ukernel_patterns_amdgpu SRCS - "${_GLOB_UKERNEL_PATTERNS_X_MLIR}" + "ukernel_patterns_gfx942.mlir" + "ukernel_patterns_gfx950.mlir" C_FILE_OUTPUT "iree_mlir_ukernel_patterns_amdgpu.c" H_FILE_OUTPUT @@ -32,7 +32,8 @@ iree_lit_test_suite( NAME verify_mlir_ukernel_patterns_amdgpu SRCS - "${_GLOB_UKERNEL_PATTERNS_X_MLIR}" + "ukernel_patterns_gfx942.mlir" + "ukernel_patterns_gfx950.mlir" TOOLS iree-opt ) diff --git a/compiler/plugins/target/VMVX/test/BUILD.bazel b/compiler/plugins/target/VMVX/test/BUILD.bazel index f53b5b481690..1fdf96d9a36f 100644 --- a/compiler/plugins/target/VMVX/test/BUILD.bazel +++ b/compiler/plugins/target/VMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "smoketest.mlir", ], diff --git 
a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel index b8394437156e..ca0a92f251c8 100644 --- a/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel +++ b/compiler/plugins/target/VulkanSPIRV/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_homogeneous_encodings.mlir", "smoketest.mlir", diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel index ec389ac0bb29..e44c2a853dc8 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_streamable_ops.mlir", "wrap_entry_points.mlir", diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel index 1579818a792e..c6692436852a 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "wrap_entry_points.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 2673c42bbdf5..5e9bbe2b9fed 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "amdgpu_lower_coalesced_dma_to_gather_lds.mlir", "decompose_horizontally_fused_gemms.mlir", diff --git 
a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel index 60f8016aa9e3..1b6a4d0a58ce 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CommonExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 4c1ceff17ddd..702a432bde7b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "add_fmfs.mlir", "affinemin_canonicalization.mlir", @@ -27,19 +28,17 @@ iree_lit_test_suite( "bufferize_dispatch_tensor_load_store.mlir", "canonicalize_early_bufferization_ops.mlir", "canonicalize_interface_load_store.mlir", - "forall_to_for.mlir", "check_for_config.mlir", "combine_layout_transformation.mlir", "convert_accgemm_to_gemm.mlir", - "convert_bf16_to_uint16_buffers.mlir", "convert_bf16_arith_to_f32.mlir", + "convert_bf16_to_uint16_buffers.mlir", "convert_hal_descriptor_type_to_gpu_address_space.mlir", "convert_to_destination_passing_style.mlir", "convert_unsupported_float_arith.mlir", "convert_workgroup_forall_to_pcf.mlir", "convolution_to_igemm.mlir", "convolutions.mlir", - "erase_dead_alloc_and_stores.mlir", "decompose_affine_ops.mlir", "decompose_boundary_pack_unpack_ops.mlir", "decompose_conv2d.mlir", @@ -49,6 +48,7 @@ iree_lit_test_suite( "decompose_softmax.mlir", "eliminate_empty_tensors.mlir", "emulate_narrow_type.mlir", + "erase_dead_alloc_and_stores.mlir", "erase_hal_descriptor_type.mlir", "extract_address_computation.mlir", "fission_transfer_ops_control_flow.mlir", @@ -60,14 
+60,15 @@ iree_lit_test_suite( "fold_reshape_into_interface_tensor.mlir", "fold_split_reduction_workgroup_mapping_loops.mlir", "fold_tensor_extract_op.mlir", + "forall_to_for.mlir", "forop_canonicalization.mlir", "generic_vectorization.mlir", "hoist_statically_bound_allocations.mlir", "hoist_unrolled_vector_extract_insert_slice.mlir", "iree_codegen_canonicalize.mlir", "iree_comprehensive_bufferize.mlir", - "iree_expand_strided_metadata_with_subview_expansion.mlir", "iree_expand_strided_metadata.mlir", + "iree_expand_strided_metadata_with_subview_expansion.mlir", "iree_inject_assume_alignment.mlir", "iree_loop_invariant_code_motion.mlir", "link_tuning_specs.mlir", @@ -109,10 +110,10 @@ iree_lit_test_suite( "rematerialize_parallel_ops.mlir", "remove_dead_allocs.mlir", "remove_single_iteration_loop.mlir", - "resolve_swizzle_hints.mlir", - "resolve_workgroup_count_hints.mlir", "repeated_matcher_use.mlir", "replace_slow_min_max_ops.mlir", + "resolve_swizzle_hints.mlir", + "resolve_workgroup_count_hints.mlir", "specialize_exports.mlir", "strip_compilation_info.mlir", "test_partitionable_loops_interface.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel index 676a4fd8c4b9..52c452d0804c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREECPUAttrs.td", "IREECPUDialect.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel index c7ebf7648fd1..9f62cb2c2724 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/CPU/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = 
enforce_glob( + # keep sorted [ "invalid.mlir", "roundtrip.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel index 1612effd8a3d..84708c271ef0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel @@ -25,6 +25,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREECodegenAttrs.td", "IREECodegenDialect.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel index 588d9dea767c..53856548c216 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "invalid.mlir", "lowering_config_attr.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel index e82cc074eb70..aa1a4f229316 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel @@ -23,6 +23,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREEGPUAttrs.td", "IREEGPUDialect.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel index 88b751a73805..6698292164a1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bufferize_coalesced_gather_dma.mlir", 
"canonicalize.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel index 7fcef1e752f8..898344518ba9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IREEGPUExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel index 405750a481cf..5cc241ac631f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_to_multi_mma.mlir", "distribute_inner_tiled.mlir", @@ -28,8 +29,8 @@ iree_lit_test_suite( "transform_fuse_extract_slice_with_forall.mlir", "transform_fuse_forall.mlir", "transform_lower_barrier_region.mlir", - "vectorize_iree_gpu_ops.mlir", "unroll_multi_mma.mlir", + "vectorize_iree_gpu_ops.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel index 3868bd7e564b..46574c8f7180 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "combine_barrier_regions.mlir", "distribute_inner_tiled_to_lanes.mlir", diff --git 
a/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel index 63fd7bf4c497..cdb68196b384 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/ExternalInterfaces/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bufferize.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel index 1c16ae00608c..ca840f957570 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PCFBase.td", "PCFInterfaces.td", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel index 7fa829131cde..4a76b1f09230 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "control_flow_ops.mlir", "folders.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel index 8e3484164d2f..49034482436a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "convert_forall_to_loops.mlir", "convert_sref_to_memref.mlir", diff --git 
a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel index c165a4adc525..212b45745f7a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel @@ -23,11 +23,12 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VectorExtAttrs.td", "VectorExtBase.td", - "VectorExtOps.td", "VectorExtInterfaces.td", + "VectorExtOps.td", ], include = ["*.td"], ), diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel index f4b3ce71c919..18b6c1bbbda4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalize.mlir", "invalid.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel index d6f29e9f09ca..55e1e50d8431 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "vector_ext_fold_unit_extent_dims.mlir", "vectorize_vector_ext_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel index b1524933aa58..9f62cb2c2724 100644 --- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/BUILD.bazel @@ -15,9 
+15,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ - "roundtrip.mlir", "invalid.mlir", + "roundtrip.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel index 1501c7838e1d..e9812c35456f 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel @@ -22,11 +22,12 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PartitionableLoopsInterface.td", "ProcessorOpInterfaces.td", - "UKernelOpInterface.td", "TensorMaskingOpInterface.td", + "UKernelOpInterface.td", ], include = ["*.td"], ), diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel index 8e73e87b943c..522ae9359b0e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LLVMCPUExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel index 3227ac21ea14..a152783be7e4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LLVMGPUExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index d5e77632a5a3..631efee31d63 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel 
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -17,17 +17,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "amdgpu_emulate_narrow_type.mlir", "assign_constant_ordinals.mlir", - "conv_pipeline_test_cuda.mlir", - "convert_to_nvvm.mlir", - "convert_to_rocdl.mlir", - "convert_to_rocdl_gfx950.mlir", - "create_async_groups.mlir", - "create_tile_sizes.mlir", - "distribute_to_thread.mlir", - "elementwise_pipeline.mlir", "cast_address_space_function.mlir", "cast_type_to_fit_mma.mlir", "config_custom_op.mlir", @@ -38,23 +31,32 @@ iree_lit_test_suite( "config_root_op_attribute.mlir", "config_sort.mlir", "config_winograd.mlir", + "configure_tensor_layout.mlir", + "conv_pipeline_test_cuda.mlir", + "convert_to_nvvm.mlir", + "convert_to_rocdl.mlir", + "convert_to_rocdl_gfx950.mlir", + "create_async_groups.mlir", + "create_tile_sizes.mlir", + "distribute_to_thread.mlir", + "elementwise_pipeline.mlir", "extract_address_computation_gpu.mlir", "gpu_pipeline_data_tiling.mlir", "gpu_pipeline_relayout_ops.mlir", "horizontal_fusion_pipeline.mlir", - "link_executables.mlir", - "reduction_pipeline_cuda.mlir", - "reduction_pipeline_rocm.mlir", - "reduction_pipeline_softmax_rocm.mlir", - "reuse_shared_memory_allocs.mlir", - "rocdl_pipeline_test.mlir", "legalize.mlir", "linalg_transform.mlir", + "link_executables.mlir", "llvmgpu_bufferize.mlir", "nvvm_pipeline_test.mlir", "pack_shared_memory_alloc.mlir", "pipeline_coalesced_dma.mlir", "prefetch_shared_memory.mlir", + "reduction_pipeline_cuda.mlir", + "reduction_pipeline_rocm.mlir", + "reduction_pipeline_softmax_rocm.mlir", + "reuse_shared_memory_allocs.mlir", + "rocdl_pipeline_test.mlir", "sort_pipeline_test.mlir", "tensorcore_vectorization.mlir", "transform_dialect_bufferize.mlir", @@ -68,7 +70,6 @@ iree_lit_test_suite( "transform_gpu_pipelining.mlir", "transform_vector_to_mma.mlir", "transpose_pipeline_test.mlir", - "configure_tensor_layout.mlir", "vector_lowering.mlir", 
"vector_to_gpu.mlir", "winograd_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel index 05ccee9b3112..342984a9a8ec 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_kernel_for_translation.mlir", "buffer_instructions_optimization.mlir", @@ -24,12 +25,12 @@ iree_lit_test_suite( "config_igemm_tile_and_fuse.mlir", "config_tile_and_fuse.mlir", "config_tile_and_fuse_gfx950.mlir", + "config_user_vector_distribute.mlir", "config_vector_distribute_gfx1100.mlir", "config_vector_distribute_gfx942.mlir", "config_vector_distribute_gfx950.mlir", "config_vector_distribute_reduction_gfx942.mlir", "config_vector_distribute_reduction_gfx950.mlir", - "config_user_vector_distribute.mlir", "configure_buffer_instructions.mlir", "pipeline_direct_conv_tile_and_fuse.mlir", "pipeline_elementwise_f8fnuz.mlir", @@ -40,10 +41,10 @@ iree_lit_test_suite( "pipeline_tile_and_fuse.mlir", "pipeline_tile_and_fuse_gfx950.mlir", "pipeline_vector_distribute_dynamic_shapes_gfx942.mlir", + "pipeline_vector_distribute_gfx1100.mlir", "pipeline_vector_distribute_gfx942.mlir", - "pipeline_vector_distribute_reduction_gfx942.mlir", "pipeline_vector_distribute_gfx950.mlir", - "pipeline_vector_distribute_gfx1100.mlir", + "pipeline_vector_distribute_reduction_gfx942.mlir", ], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel index afeb7e4a61ea..47d4524d055d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted 
[ "annotate_winograd_loops.mlir", "break_down_large_vector.mlir", @@ -38,8 +39,8 @@ iree_lit_test_suite( "config_nvidia_matmul.mlir", "config_nvidia_matmul_cooperative_ops.mlir", "config_user.mlir", - "convert_to_spirv.mlir", "convert_gpu_target.mlir", + "convert_to_spirv.mlir", "emulate_i64.mlir", "erase_storage_buffer_static_shape.mlir", "illegal_configuration.mlir", @@ -49,17 +50,17 @@ iree_lit_test_suite( "lowering_matmul_fusion.mlir", "lowering_matmul_promotion.mlir", "lowering_matvec.mlir", - "lowering_scalar_dispatch.mlir", "lowering_reduction.mlir", + "lowering_scalar_dispatch.mlir", "map_memref_storage_class.mlir", "materialize_executable_conditions.mlir", + "physical_storage_buffer_addresses.mlir", "pipeline_matmul_cooperative_ops.mlir", "pipeline_matmul_promotion.mlir", "pipeline_matmul_vectorization.mlir", "pipeline_matvec.mlir", "pipeline_reduction_subgroup.mlir", "pipeline_sub_byte_dequant.mlir", - "physical_storage_buffer_addresses.mlir", "tile_and_distribute.mlir", "tile_and_distribute_scatter.mlir", "tile_and_distribute_sort.mlir", @@ -74,8 +75,8 @@ iree_lit_test_suite( "vectorize_conv.mlir", "vectorize_elementwise_ops.mlir", "vectorize_gather.mlir", - "vectorize_matmul.mlir", "vectorize_load_store.mlir", + "vectorize_matmul.mlir", "vectorize_reduction.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel index 26b0d8f876e2..f9654b309494 100644 --- a/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/WGSL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel index 428ca152d8ac..5c4f59509016 100644 --- a/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/ConstEval/test/BUILD.bazel @@ -16,6 +16,7 @@ iree_lit_test_suite( name = "lit", timeout = "moderate", srcs = enforce_glob( + # keep sorted [ "compile_regressions.mlir", "failing.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel index 49e728b41187..df2e48c9a953 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "EncodingAttrs.td", "EncodingBase.td", diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel index 55216bea1866..18bbcfefa300 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "invalid.mlir", "roundtrip.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel index 5f15951f4cc8..1e48034998f7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/ShardToFlow/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "channel_creation.mlir", "collectives.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel index 80695f1e9590..c3164bbea997 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "bitcast.mlir", "cast.mlir", @@ -23,8 +24,8 @@ iree_lit_test_suite( "extract_slice.mlir", "fill.mlir", "from_elements.mlir", - "insert_slice.mlir", "insert.mlir", + "insert_slice.mlir", "reshape.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel index 4ecd72fbfa41..65e4304312a5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "FlowBase.td", "FlowInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel index 0e8edf6c4512..111c516e29aa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "call_ops.mlir", "dispatch_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel index dba8a9ed4574..aec29f76740e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "FlowExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel index 44335a5d6b87..cfa0a91985f9 100644 --- 
a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_dispatches.mlir", "canonicalize.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel index b38bbbea7b32..a4085ef63ddd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "device_ops.mlir", "pseudo_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel index c949d6efdf21..0a16a83b5434 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "allocator_ops.mlir", "buffer_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel index 779a62ec7217..853292ccd18f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "shape_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel 
index 2d6f7779493a..c91e92dab276 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "channel_ops.mlir", "cmd_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel index b9874626c5aa..f35d67c7b58b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["global_ops.mlir"], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index 9d935df7f9b5..cac7679b8232 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -21,6 +21,7 @@ exports_files([ iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALAttrs.td", "HALBase.td", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel index 39c200fd5d06..18609d466992 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "allocator_folding.mlir", "allocator_ops.mlir", @@ -34,8 +35,8 @@ iree_lit_test_suite( "experimental_ops.mlir", "fence_folding.mlir", "fence_ops.mlir", - "interface_ops.mlir", "interface_folding.mlir", + "interface_ops.mlir", "invalid.mlir", "tensor_folding.mlir", "tensor_ops.mlir", diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index d100a35b1ba0..0f87d6865510 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_target_devices.mlir", "assign_legacy_target_devices.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel index 7209bda88aa7..244aac9a100a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LinalgExtBase.td", "LinalgExtInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel index aff5921fd57b..1d224bb0bcc3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalize.mlir", "decompose_aggregate_op.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel index d70b7ab22f1f..65619199475c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "LinalgExtExtensionsOps.td", ], diff --git 
a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel index 603f52011f7a..43f44a7a07e0 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/TransformExtensions/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel index 6e36532763a8..2176e3431a51 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "conv2d_to_winograd.mlir", "conv_to_im2col.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel index eb5c9b643c5d..7ad8f03376e4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "call_ops.mlir", "collective_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel index 34c6442f57a3..e1497efcb477 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( 
iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "abi_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel index 54fbdc060085..cd8d750cc3fa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "constant_ops.mlir", "structural_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel index ceb82144e9f2..eecd75ff1755 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "compiler_hints.mlir", "global_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel index 91244ba3b01e..33c213a8a6e4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "StreamAttrs.td", "StreamBase.td", diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel index 7d7ee55b763a..cda6a510571e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = 
"lit", srcs = enforce_glob( + # keep sorted [ "async_folding.mlir", "async_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index f186cdcc207b..98a16f2cda2b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_affinities.mlir", "annotate_constant_transient_size.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel index e17c889f3868..a7c22e2d0f88 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/e2e/BUILD.bazel @@ -20,6 +20,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "async_parameters.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel index a91cd2f27e82..a278378e9ce0 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "TensorExtAttrs.td", "TensorExtBase.td", diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel index 9b34141476ab..41a26b2ff6ad 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ 
"dispatch_tensor_folding.mlir", "dispatch_workload_ordinal_folding.mlir", diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel index 2b205eaaa6ef..83ce0a2d4733 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "sparse_interface_methods.mlir", "sparse_interface_methods_estimated_loop_range_fail.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel index 5db13f411ee6..3e92f3d8edc3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "func_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel index 794f7ca76376..67706eb1ea6a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "memref_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel index 54a003dd2c21..a4b61738a425 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/BUILD.bazel @@ -15,6 +15,7 
@@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "compiler_hints.mlir", "structural_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel index 2ca83b21fdd2..0a8bf199db8c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["UtilBase.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "UtilAttrs.td", "UtilBase.td", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel index a1c60408471f..abb8771cf69c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "alignment_folding.mlir", "alignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel index 5abb54d70073..143b5be5ee62 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "UtilTransformOps.td", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel index 76afbf00d106..138cbadc3b25 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "create_serialized_module.mlir", "symbol_transforms.mlir", diff 
--git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel index d3fe86862e8b..f4b05a1ef1fe 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_op_ordinals.mlir", "attribute_call_graph.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel index 90d56411b87c..5f09cbe79644 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "live_intervals.mlir", "register_allocation.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel index 9a0cbbace1fc..3076f512dc6f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_ops.mlir", "assignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel index 206cdadd7685..07c99862721b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_ops.mlir", ], diff --git 
a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel index bc6b7ac5b474..35075c7270dc 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "control_flow_ops.mlir", "func_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel index 3a4dd9ee2458..1dc2826c8774 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "alignment_ops.mlir", "assignment_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel index 2d9dc4c22200..1862bd578552 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/BUILD.bazel @@ -24,35 +24,36 @@ endif() iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ + "arithmetic_ops.mlir", "arithmetic_ops_f32.mlir", "arithmetic_ops_i64.mlir", - "arithmetic_ops.mlir", + "assignment_ops.mlir", "assignment_ops_f32.mlir", "assignment_ops_i64.mlir", - "assignment_ops.mlir", "buffer_ops.mlir", "buffer_ops_f32.mlir", "buffer_ops_f64.mlir", "buffer_ops_i64.mlir", + "comparison_ops.mlir", "comparison_ops_f32.mlir", "comparison_ops_i64.mlir", - "comparison_ops.mlir", + "const_ops.mlir", "const_ops_f32.mlir", "const_ops_i64.mlir", - "const_ops.mlir", 
"control_flow_ops.mlir", + "conversion_ops.mlir", "conversion_ops_f32.mlir", "conversion_ops_i64.mlir", - "conversion_ops.mlir", "func_op.mlir", + "global_ops.mlir", "global_ops_f32.mlir", "global_ops_i64.mlir", - "global_ops.mlir", - "list_ops_i64.mlir", "list_ops.mlir", - "shift_ops_i64.mlir", + "list_ops_i64.mlir", "shift_ops.mlir", + "shift_ops_i64.mlir", "type_conversion.mlir", ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel index d5dff039a403..803e15bf969e 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["VMOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VMBase.td", "VMOpcodesCore.td", diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel index a886e9ec5890..0f638408a177 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "arithmetic_folding.mlir", "arithmetic_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel index 8ef4a8724d18..99467bbbb317 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "constant_encoding.mlir", "dependencies.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel index ae4b2025c32c..3290f81e6711 100644 --- 
a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_functions.mlir", "convert_to_yieldable_calls.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel index 99b86ec28216..3ab8248cceff 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted ["interface_ops.mlir"], include = ["*.mlir"], ), diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel index 5684262af8b2..5a72a1402263 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/StandardToVMVX/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel index 4995c8bdca7b..5e2a2c5c577e 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "binary.mlir", "copy.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel index 
b05ab342f6ab..4e49ae3ad0b6 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["VMLXOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "VMVXBase.td", "VMVXInterfaces.td", diff --git a/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel index 5684262af8b2..5a72a1402263 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel index c02abbca8ad2..2320156440e3 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_constants.mlir", "resolve_buffer_descriptors.mlir", diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel index 98f8f232ddfb..cdc161e94f1c 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel @@ -15,30 +15,27 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "annotate_data_tiling_hints.mlir", "bitcast_unsupported_element_types.mlir", + "bubble_up_expand_shapes.mlir", + "bubble_up_extract_slice.mlir", "clone_producers_into_dispatch_regions.mlir", "collapse_dimensions.mlir", "collapse_linalg_generic_on_tensors.mlir", - "elementwise_op_fusion.mlir", - 
"dispatch_region_formation_preprocessing.mlir", - "fold_reshapes_into_tensor_barriers.mlir", - "fold_unit_dims.mlir", - "form_dispatch_regions.mlir", - "dispatch_linalg_on_tensors.mlir", "convert_encoding_to_flow.mlir", "convert_region_to_workgroups.mlir", - "bubble_up_expand_shapes.mlir", - "bubble_up_extract_slice.mlir", - "form_dispatch_workgroups.mlir", "dispatch_linalg_ext_fusion.mlir", - "hoist_encoding_ops.mlir", - "hoist_uniform_scalar_compute.mlir", - "insert_tensor_barriers.mlir", - "remove_tensor_barriers.mlir", + "dispatch_linalg_on_tensors.mlir", "dispatch_linalg_on_tensors_default.mlir", "dispatch_linalg_on_tensors_fusion_with_transpose.mlir", + "dispatch_region_formation_preprocessing.mlir", + "elementwise_op_fusion.mlir", + "fold_reshapes_into_tensor_barriers.mlir", + "fold_unit_dims.mlir", + "form_dispatch_regions.mlir", + "form_dispatch_workgroups.mlir", "form_scalar_dispatches.mlir", "form_split_reduction_dispatches.mlir", "fuse_encoding_ops_into_dispatch_regions.mlir", @@ -46,6 +43,9 @@ iree_lit_test_suite( "fuse_multiuse_elementwise_producer.mlir", "fuse_multiuse_intra_dispatch.mlir", "fusion_preprocessing.mlir", + "hoist_encoding_ops.mlir", + "hoist_uniform_scalar_compute.mlir", + "insert_tensor_barriers.mlir", "materialize_default_workgroup_count_region.mlir", "pad_fusion_with_consumer.mlir", "pad_fusion_with_producer.mlir", @@ -53,6 +53,7 @@ iree_lit_test_suite( "pipeline_tests_aggressive.mlir", "pipeline_tests_split_reduction.mlir", "propagate_encodings.mlir", + "remove_tensor_barriers.mlir", "set_encoding.mlir", "set_encoding_padding.mlir", "set_encoding_pipeline.mlir", diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel index 76233b25c578..6a2a09440819 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = 
"lit", srcs = enforce_glob( + # keep sorted [ "cleanup_numeric_narrowing.mlir", "conv1x1_to_matmul.mlir", diff --git a/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel index 3cb0af85328c..918a3d5a6671 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/InputConversion/Common/test/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "demote_f32_to_f16.mlir", "demote_f64_to_f32.mlir", diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel index 3f2bea0aef55..0a9d1dc304bd 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CheckBase.td", "CheckOps.td", diff --git a/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel index 8f8f49a2a7ec..6f031429a476 100644 --- a/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "canonicalize.mlir", "ops.mlir", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel index 315246764f8e..accd28591098 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALInlineToVM/test/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ ], include = ["*.mlir"], 
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel index a872428de0ba..0b0b3144856e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "buffer_ops.mlir", "buffer_view_ops.mlir", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel index 4c424a92daff..744d2d40971e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "cmd_ops.mlir", "debug_ops.mlir", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel index 6543c578a5dc..52bdf367df85 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["HALInlineOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALInlineBase.td", "HALInlineOps.td", diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel index 5cf787ee14bd..892b95fb1904 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name 
= "lit", srcs = enforce_glob( + # keep sorted [ "buffer_folding.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel index 647c5145a792..637ea14df48d 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "inline_executables.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel index e2feb4e13e13..5a047972de3b 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "executable_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel index 04aeda54bf77..3a5ddbbb438f 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/StreamToHALLoader/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "cmd_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel index 74739f603623..0f575dec174e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/BUILD.bazel @@ -18,6 +18,7 @@ 
exports_files(["HALLoaderOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "HALLoaderBase.td", "HALLoaderOps.td", diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel index ad58a2475fbd..8e3ed66d6a33 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "dispatch_folding.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel index a445578a0001..9e19da435caa 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "materialize_executables.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/ParamsToVM/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Modules/IO/Parameters/Conversion/StreamToParams/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel index affb48d05d05..6c6b395f2d55 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/BUILD.bazel @@ -18,6 +18,7 @@ exports_files(["IOParametersOps.td"]) iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "IOParametersBase.td", "IOParametersOps.td", diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel index a58f8e507078..6530aa2037ee 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/IR/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "parameter_ops.mlir", ], diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel index eb9ea3fe8998..241c1800eb40 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "export_parameters.mlir", "generate_splat_parameter_archive.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index 68e8d19d31c1..60f73989ce8a 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ 
b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "attr_based_pipeline.mlir", "conv2d_to_img2col.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel index 9a8958a13f22..a49cab6a6e63 100644 --- a/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel @@ -16,6 +16,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "PreprocessingExtensionsOps.td", ], diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel index 8471e3665f49..050bed8656ff 100644 --- a/compiler/src/iree/compiler/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Utils/BUILD.bazel @@ -18,6 +18,7 @@ package( iree_td_library( name = "td_files", srcs = enforce_glob( + # keep sorted [ "CommonTypeConstraints.td", "DocMetadata.td", diff --git a/tests/compiler_driver/BUILD.bazel b/tests/compiler_driver/BUILD.bazel index a9b99b8558d2..fafafa864bf0 100644 --- a/tests/compiler_driver/BUILD.bazel +++ b/tests/compiler_driver/BUILD.bazel @@ -17,6 +17,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "executable_benchmarks.mlir", "hal_executable.mlir", diff --git a/tests/e2e/parameters/BUILD.bazel b/tests/e2e/parameters/BUILD.bazel index 3bdb9d6daf98..fe20806c512f 100644 --- a/tests/e2e/parameters/BUILD.bazel +++ b/tests/e2e/parameters/BUILD.bazel @@ -15,6 +15,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "encode_parameters.mlir", "export_parameters.mlir", diff --git a/tests/e2e/regression/stablehlo/BUILD.bazel b/tests/e2e/regression/stablehlo/BUILD.bazel index 7e2fff3b9132..1a99b7fac129 100644 --- 
a/tests/e2e/regression/stablehlo/BUILD.bazel +++ b/tests/e2e/regression/stablehlo/BUILD.bazel @@ -35,6 +35,7 @@ NON_CHECK_TESTS = [ iree_check_single_backend_test_suite( name = "check_stablehlo_regression_llvm-cpu", srcs = enforce_glob( + # keep sorted CHECK_TESTS + CPU_SPECIFIC_TESTS, include = ["*.mlir"], exclude = NON_CHECK_TESTS, @@ -47,6 +48,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_vmvx", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -58,6 +60,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_vulkan-spirv", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -69,6 +72,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_cuda", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -88,6 +92,7 @@ iree_check_single_backend_test_suite( iree_check_single_backend_test_suite( name = "check_stablehlo_regression_hip", srcs = enforce_glob( + # keep sorted CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + NON_CHECK_TESTS, @@ -109,6 +114,7 @@ iree_check_single_backend_test_suite( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted NON_CHECK_TESTS, include = ["*.mlir"], exclude = CPU_SPECIFIC_TESTS + CHECK_TESTS, diff --git a/tests/e2e/stablehlo_models/BUILD.bazel b/tests/e2e/stablehlo_models/BUILD.bazel index 3e68c6253689..78b4a6457c89 100644 --- a/tests/e2e/stablehlo_models/BUILD.bazel +++ b/tests/e2e/stablehlo_models/BUILD.bazel @@ -20,6 +20,7 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( + # keep sorted [ "collatz.mlir", "edge_detection.mlir", diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel 
b/tests/e2e/stablehlo_ops/BUILD.bazel index 5b8fc814f49b..9118e78eb292 100644 --- a/tests/e2e/stablehlo_ops/BUILD.bazel +++ b/tests/e2e/stablehlo_ops/BUILD.bazel @@ -13,6 +13,7 @@ package( ) ALL_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", diff --git a/tests/e2e/subbyte_types/BUILD.bazel b/tests/e2e/subbyte_types/BUILD.bazel index ff1b1f3ea643..5782fd122dec 100644 --- a/tests/e2e/subbyte_types/BUILD.bazel +++ b/tests/e2e/subbyte_types/BUILD.bazel @@ -21,6 +21,7 @@ package( iree_check_single_backend_test_suite( name = "check_llvm-cpu_subbyte_emulation", srcs = enforce_glob( + # keep sorted [ "subbyte_types.mlir", ], diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel index c43028acb86d..fd7c52baac42 100644 --- a/tests/e2e/tosa_ops/BUILD.bazel +++ b/tests/e2e/tosa_ops/BUILD.bazel @@ -13,6 +13,7 @@ package( ) ALL_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", @@ -99,6 +100,7 @@ iree_check_single_backend_test_suite( ) ROCM_AND_CUDA_SRCS = enforce_glob( + # keep sorted [ "abs.mlir", "add.mlir", From dc3cf9338d2954f4fd4765cbc9f830b4928592f1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:00:40 -0500 Subject: [PATCH 46/71] Add braces in Codegen/Common (core). NFC. 
2/n (#23144) --- .../Codegen/Common/BlockDynamicDimensions.cpp | 21 ++- .../Codegen/Common/BufferizationAnalysis.cpp | 33 ++-- .../Codegen/Common/BufferizationAnalysis.h | 3 +- .../BufferizeCopyOnlyDispatchesPass.cpp | 6 +- .../Common/ConcretizePadResultShape.cpp | 12 +- .../Common/ConfigTrackingCanonicalizer.cpp | 6 +- .../Codegen/Common/ConvertBf16ArithToF32.cpp | 15 +- .../Common/ConvertBf16ToUInt16Buffers.cpp | 24 ++- .../ConvertToDestinationPassingStylePass.cpp | 50 ++++-- .../DecomposeConvolutionToLowerDimOps.cpp | 12 +- .../Codegen/Common/DecomposePackUnPackOps.cpp | 6 +- .../Codegen/Common/DecomposeSoftmax.cpp | 3 +- .../Codegen/Common/EmulateNarrowType.cpp | 9 +- .../EraseHALDescriptorTypeFromMemRef.cpp | 6 +- .../Common/ExtractAddressComputation.h | 3 +- .../Common/FlattenMemRefSubspanPass.cpp | 39 +++-- .../Codegen/Common/FlattenMemRefs.cpp | 21 ++- .../FoldAffineMinInDistributedLoops.cpp | 27 ++-- .../Common/FoldTensorExtractOpPass.cpp | 3 +- .../FoldTensorSubsetIntoVectorTransferOps.cpp | 67 +++++--- .../Common/ForOpCanonicalizationPass.cpp | 9 +- .../Codegen/Common/GenericVectorization.cpp | 15 +- .../HoistStaticallyBoundAllocations.cpp | 3 +- .../HoistUnrolledVectorExtractInsertSlice.cpp | 30 ++-- .../Common/IREECodegenCanonicalizer.cpp | 9 +- .../Common/IREEComprehensiveBufferizePass.cpp | 27 ++-- .../Common/IREEExpandStridedMetadata.cpp | 9 +- .../Codegen/Common/LinkTuningSpecsPass.cpp | 3 +- .../Common/LowerUKernelDescriptors.cpp | 3 +- .../Codegen/Common/MaterializeEncoding.cpp | 3 +- .../Common/MaterializeEncodingPatterns.cpp | 12 +- .../Common/MaterializeTuningSpecsPass.cpp | 3 +- .../Codegen/Common/MathTransformPass.cpp | 3 +- .../Codegen/Common/MemrefCopyToLinalg.cpp | 3 +- .../Codegen/Common/NormalizeLoopBounds.cpp | 3 +- .../OptimizeTensorInsertExtractSlices.cpp | 10 +- .../Common/OptimizeVectorTransferPass.cpp | 3 +- .../Codegen/Common/PadDynamicAlloc.cpp | 18 ++- .../Common/PropagateConstantOffsets.cpp | 9 +- 
.../Common/PropagateReshapesByExpansion.cpp | 43 +++-- .../Codegen/Common/ReshapePatterns.cpp | 12 +- .../Common/StripCompilationInfoPass.cpp | 6 +- .../Common/TensorDynamicDimAnalysis.cpp | 6 +- .../Common/TensorToVectorVectorizePad.cpp | 18 ++- .../Common/TestExecutablePreprocessing.cpp | 3 +- .../TileAndDistributeToWorkgroupsPass.cpp | 9 +- .../Codegen/Common/TileAndFuseUtils.cpp | 29 ++-- .../Common/TileDispatchUsingForall.cpp | 6 +- .../Common/TileDispatchUsingInterface.cpp | 24 ++- .../TransformDialectInterpreterPass.cpp | 3 +- .../TransformExtensions/CommonExtensions.cpp | 81 ++++++---- .../compiler/Codegen/Common/Transforms.cpp | 9 +- .../Codegen/Common/TypePropagationPass.cpp | 9 +- .../compiler/Codegen/Common/UserConfig.cpp | 3 +- .../Codegen/Transforms/ReshapeFusion.cpp | 6 +- .../Codegen/Dialect/Codegen/Utils/Utils.cpp | 9 +- .../Interfaces/BufferizationInterfaces.cpp | 51 ++++-- ...ffineMinDistributedSCFCanonicalization.cpp | 33 ++-- .../Codegen/Transforms/Transforms.cpp | 54 ++++--- .../compiler/Codegen/Transforms/Transforms.h | 6 +- .../iree/compiler/Codegen/Utils/CPUUtils.cpp | 3 +- .../compiler/Codegen/Utils/EncodingUtils.cpp | 3 +- .../iree/compiler/Codegen/Utils/GPUUtils.cpp | 148 ++++++++++++------ .../compiler/Codegen/Utils/LinalgOpInfo.cpp | 3 +- .../compiler/Codegen/Utils/LinkingUtils.cpp | 12 +- .../compiler/Codegen/Utils/MarkerUtils.cpp | 3 +- .../iree/compiler/Codegen/Utils/MarkerUtils.h | 3 +- .../src/iree/compiler/Codegen/Utils/Utils.cpp | 87 ++++++---- 68 files changed, 813 insertions(+), 410 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp index a4e8a0f8606b..fc1a56a7cc6f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -75,12 +75,14 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis 
&dynamicDimAnalysis, } for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { - if (!tensorType.isDynamicDim(index)) + if (!tensorType.isDynamicDim(index)) { continue; + } std::optional dimDivisibility = dynamicDimAnalysis.getDivisibilityInfo(v, index); - if (!dimDivisibility) + if (!dimDivisibility) { continue; + } divisibilityInfo[index] = std::move(dimDivisibility.value()); } @@ -191,14 +193,17 @@ static LogicalResult blockDynamicDimensions( Operation *operation, llvm::SmallDenseSet limitToOperandNumbers, llvm::SmallDenseSet limitToResultNumbers) { for (OpOperand &operand : operation->getOpOperands()) { - if (!limitToOperandNumbers.contains(operand.getOperandNumber())) + if (!limitToOperandNumbers.contains(operand.getOperandNumber())) { continue; - if (operand.get().getDefiningOp()) + } + if (operand.get().getDefiningOp()) { continue; + } TensorDivisibilityInfo operandDivisibilityInfo = getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get()); - if (operandDivisibilityInfo.empty()) + if (operandDivisibilityInfo.empty()) { continue; + } std::optional reshapes = blockDynamicDimensionsOfValue( rewriter, operandDivisibilityInfo, operand.get()); if (reshapes) { @@ -210,12 +215,14 @@ static LogicalResult blockDynamicDimensions( OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(operation); for (OpResult result : operation->getResults()) { - if (!limitToResultNumbers.contains(result.getResultNumber())) + if (!limitToResultNumbers.contains(result.getResultNumber())) { continue; + } TensorDivisibilityInfo resultDivisibilityInfo = getTensorDivisibilityInfo(dynamicDimAnalysis, result); - if (resultDivisibilityInfo.empty()) + if (resultDivisibilityInfo.empty()) { continue; + } std::optional reshapes = blockDynamicDimensionsOfValue(rewriter, resultDivisibilityInfo, result); if (reshapes) { diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp 
index db3f6fd95c83..c10f29fbbc42 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp @@ -88,8 +88,9 @@ static bool isFromReadOnlyTensor(Value v, const BufferizationPlan &plan) { /// here). static LogicalResult analyseConstantOp(arith::ConstantOp constantOp, BufferizationPlan &plan) { - if (!isa(constantOp.getResult().getType())) + if (!isa(constantOp.getResult().getType())) { return success(); + } plan.insert(constantOp.getResult()); return success(); } @@ -112,12 +113,14 @@ static OpType getEquivalentOpOfType(Value value, BufferizationPlan &plan) { SmallVector mappedTensors = plan.getTensorsMappedToSameSet(value); for (auto v : mappedTensors) { auto definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } assert((!equivalentOp || equivalentOp == definingOp) && "found two interface binding ops marked as equivalent"); - if (!equivalentOp) + if (!equivalentOp) { equivalentOp = definingOp; + } } return equivalentOp; } @@ -252,12 +255,14 @@ getTiedOperandsForDPSOps(DestinationStyleOpInterface dpsOp, /// same equivalence class. 
static LogicalResult analyseDPSOps(DestinationStyleOpInterface dpsOp, BufferizationPlan &plan) { - if (!dpsOp.hasPureTensorSemantics()) + if (!dpsOp.hasPureTensorSemantics()) { return success(); + } auto results = dpsOp->getResults(); auto tiedOperands = getTiedOperandsForDPSOps(dpsOp, plan); - if (tiedOperands.empty()) + if (tiedOperands.empty()) { return failure(); + } for (auto [index, resultTensor, tiedOperand] : llvm::zip_equal( llvm::seq(0, results.size()), results, tiedOperands)) { if (tiedOperand) { @@ -328,13 +333,15 @@ static LogicalResult analyseDestructiveUpdateOp(Operation *op, Value source, } static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { - if (!ifOp.getNumResults()) + if (!ifOp.getNumResults()) { return success(); + } for (auto [result, thenOperand, elseOperand] : llvm::zip_equal(ifOp.getResults(), ifOp.thenYield().getOperands(), ifOp.elseYield().getOperands())) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } // All results and yields of the if-then-else are tied together. plan.unionSets(result, thenOperand); plan.unionSets(result, elseOperand); @@ -344,8 +351,9 @@ static LogicalResult analyseScfIfOp(scf::IfOp ifOp, BufferizationPlan &plan) { static LogicalResult analyseScfForOp(scf::ForOp forOp, BufferizationPlan &plan) { - if (forOp.getResults().empty()) + if (forOp.getResults().empty()) { return success(); + } if (!llvm::all_of(forOp->getResultTypes(), [](Type resultType) { return isa(resultType); })) { @@ -406,8 +414,9 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { for (OpOperand &use : source.getUses()) { auto user = use.getOwner(); // Process only update ops uses here. - if (!isUpdateOp(user)) + if (!isUpdateOp(user)) { continue; + } // If this is not the first use in a tensor::InsertSliceOp abort. 
if (updateOp) { return; @@ -432,8 +441,9 @@ static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) { Block *updateOpBlock = updateOp->getBlock(); for (OpOperand &use : source.getUses()) { Operation *user = use.getOwner(); - if (user == updateOp) + if (user == updateOp) { continue; + } if (isReadOp(user)) { Value source = getSource(user); assert(source && "unable to find source from read op"); @@ -494,8 +504,9 @@ void BufferizationPlan::dump() { unsigned numSets = 0; for (auto it = mappedTensors.begin(), ie = mappedTensors.end(); it != ie; ++it) { - if (!(*it)->isLeader()) + if (!(*it)->isLeader()) { continue; + } llvm::dbgs() << "\tSet " << numSets; if (storeLeaders.count( getLeaderValue(getValue(*mappedTensors.member_begin(**it))))) { diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h index 6523d7771b0a..7dc3015cefd2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.h @@ -62,8 +62,9 @@ class BufferizationPlan { /// the dispatch region. bool isInStoreSet(Value v) { Value leader = getLeaderValue(v); - if (!leader) + if (!leader) { return false; + } return storeLeaders.count(leader); } diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp index bc8d6a07830e..d77447291b7b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp @@ -63,11 +63,13 @@ void BufferizeCopyOnlyDispatchesPass::runOnOperation() { hasDispatchStore = true; return success(isReadOnly(storeOp.getValue())); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return; + } // The function is just a copy and is not yet bufferized. 
- if (!hasDispatchStore) + if (!hasDispatchStore) { return; + } // Apply the bufferization passes. std::optional maybeBufferizationPipeline = diff --git a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp index d4b782b67eb9..6458cb6f2702 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConcretizePadResultShape.cpp @@ -32,8 +32,9 @@ static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, Location loc) { IntegerAttr attr; if (Value val = dyn_cast(attrOrValue)) { - if (val.getType().isIndex()) + if (val.getType().isIndex()) { return val; + } matchPattern(val, m_Constant(&attr)); } else { attr = cast(cast(attrOrValue)); @@ -52,8 +53,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override { // If the result shape is already static, then nothing to do. 
- if (padOp.getResultType().hasStaticShape()) + if (padOp.getResultType().hasStaticShape()) { return failure(); + } int rank = padOp.getResultType().getRank(); SmallVector staticShape; @@ -61,8 +63,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { auto sourceIfxOp = dyn_cast_if_present( padOp.getSource().getDefiningOp()); - if (!sourceIfxOp) + if (!sourceIfxOp) { return failure(); + } SmallVector lowPad = padOp.getMixedLowPad(); SmallVector source = sourceIfxOp.getMixedSizes(); @@ -111,8 +114,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern { affine::canonicalizeMapAndOperands(&map, &valueSizes); cstExpr = dyn_cast(map.getResult(0)); } - if (!cstExpr) + if (!cstExpr) { return failure(); + } staticShape.push_back(cstExpr.getValue()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp index 3ff9ece926f7..71e3b933b9f3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConfigTrackingCanonicalizer.cpp @@ -77,10 +77,12 @@ struct ConfigTrackingCanonicalizerPass final GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) + } + for (RegisteredOperationName op : context->getRegisteredOperations()) { op.getCanonicalizationPatterns(owningPatterns, context); + } patterns = std::make_shared(std::move(owningPatterns)); diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp index fd33f84de711..b7941bcb2b7e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp @@ -46,8 +46,9 @@ Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } if (inputETy.getIntOrFloatBitWidth() > eTy.getIntOrFloatBitWidth()) { return arith::TruncFOp::create(builder, loc, type, inputs[0]); @@ -66,8 +67,9 @@ struct PrimitiveTypeConverter : public TypeConverter { explicit PrimitiveTypeConverter() { addConversion([](Type type) { return type; }); addConversion([&](SourceType type) -> Type { - if (!isSourceType(type)) + if (!isSourceType(type)) { return type; + } return getTargetType(type); }); addConversion([&](ComplexType type) { @@ -262,16 +264,19 @@ struct ConvertBf16ArithToF32Pass final auto checkOp = [&](Operation *op) { for (Type type : op->getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (auto ®ion : op->getRegions()) { - if (!typeConverter.isLegal(®ion)) + if (!typeConverter.isLegal(®ion)) { return false; + } } return true; }; diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp index 94b32b9db14d..458ea0d894de 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp @@ -48,8 +48,9 @@ class Bf16EmulationConverter : public TypeConverter { // Scalar case. 
addConversion([](FloatType ty) -> std::optional { - if (ty.isBF16()) + if (ty.isBF16()) { return IntegerType::get(ty.getContext(), 16); + } return ty; }); @@ -59,12 +60,14 @@ class Bf16EmulationConverter : public TypeConverter { addConversion([this](FunctionType ty) -> std::optional { SmallVector inputs; - if (failed(convertTypes(ty.getInputs(), inputs))) + if (failed(convertTypes(ty.getInputs(), inputs))) { return std::nullopt; + } SmallVector results; - if (failed(convertTypes(ty.getResults(), results))) + if (failed(convertTypes(ty.getResults(), results))) { return std::nullopt; + } return FunctionType::get(ty.getContext(), inputs, results); }); @@ -82,10 +85,11 @@ struct ConvertHalInterfaceBindingSubspan final matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResultTy = getTypeConverter()->convertType(op.getType()); - if (!newResultTy) + if (!newResultTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to legalize memref type: {}", op.getType())); + } auto newOp = rewriter.replaceOpWithNewOp( @@ -105,10 +109,11 @@ struct ConvertMemRefAlloc final : OpConversionPattern { matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getType()); - if (!newTy) + if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {}", op.getType())); + } rewriter.replaceOpWithNewOp( op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), @@ -191,10 +196,11 @@ struct ConvertMemRefLoad final : OpConversionPattern { matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResTy = getTypeConverter()->convertType(op.getType()); - if (!newResTy) + if (!newResTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: 
{}", op.getMemRefType())); + } rewriter.replaceOpWithNewOp( op, newResTy, adaptor.getMemref(), adaptor.getIndices(), @@ -210,10 +216,11 @@ struct ConvertMemRefStore final : OpConversionPattern { matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getMemRefType()); - if (!newTy) + if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {}", op.getMemRefType())); + } rewriter.replaceOpWithNewOp( op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), @@ -327,8 +334,9 @@ struct ConvertBf16ToUInt16BuffersPass final RewritePatternSet patterns(ctx); populateIreeBf16EmulationPatterns(patterns, typeConverter); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp index 46c140f1acd6..d718f83ef523 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp @@ -147,8 +147,9 @@ walkUseToGetDispatchStoreOp(Value value, const BufferizationPlan &plan, return user; } value = getTiedResultForOperand(use, plan); - if (!value) + if (!value) { return nullptr; + } traversedUses.push_back(&use); } // If the value has a use which is a store, then use that directly. 
@@ -271,8 +272,9 @@ convertToDestinationPassingStyle(OpBuilder &b, auto walkResult = funcOp.walk( [&](tensor::EmptyOp emptyOp) -> WalkResult { for (auto result : emptyOp->getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (plan.isInStoreSet(result) && !processed.count(result)) { return modifyResultToUseStoreBuffer(b, result, plan, processed); } @@ -291,20 +293,23 @@ canUseInOperandAsInitOperand(OpOperand *inOperand, OpOperand *initOperand, return false; } - if (inOperand->getOwner() != initOperand->getOwner()) + if (inOperand->getOwner() != initOperand->getOwner()) { return false; + } auto linalgOp = dyn_cast(inOperand->getOwner()); - if (!linalgOp) + if (!linalgOp) { return false; + } if (linalgOp.getMatchingIndexingMap(inOperand) != linalgOp.getMatchingIndexingMap(initOperand)) { return false; } - if (inOperand->get().getType() != initOperand->get().getType()) + if (inOperand->get().getType() != initOperand->get().getType()) { return false; + } if (useWARForCooperativeMatrixCodegen) { return true; @@ -330,8 +335,9 @@ canModifyUseToGetValueIntoStoreSet(BufferizationPlan &plan, OpOperand *use, // Currently only look at use in linalg.generic ops. auto genericOpConsumer = dyn_cast(use->getOwner()); - if (!genericOpConsumer) + if (!genericOpConsumer) { return std::nullopt; + } // All loops need to be parallel. if (genericOpConsumer.getNumLoops() != @@ -339,17 +345,20 @@ canModifyUseToGetValueIntoStoreSet(BufferizationPlan &plan, OpOperand *use, return std::nullopt; } - if (genericOpConsumer.isDpsInit(use)) + if (genericOpConsumer.isDpsInit(use)) { return std::nullopt; + } for (auto [index, initOperand] : llvm::enumerate(genericOpConsumer.getDpsInitsMutable())) { // Output tensor is unused in the body computation. - if (genericOpConsumer.payloadUsesValueFromOperand(&initOperand)) + if (genericOpConsumer.payloadUsesValueFromOperand(&initOperand)) { continue; + } // The result of this operation needs to be in a store set. 
- if (!plan.isInStoreSet(genericOpConsumer->getResult(index))) + if (!plan.isInStoreSet(genericOpConsumer->getResult(index))) { continue; + } if (!canUseInOperandAsInitOperand(use, &initOperand, useWARForCooperativeMatrixCodegen)) { continue; @@ -441,8 +450,9 @@ static LogicalResult adaptComputeConsumerToAvoidStackAllocation( [&](TilingInterface computeOp) -> WalkResult { for (auto result : computeOp->getResults()) { // If result is already in a store set. Nothing to do. - if (plan.isInStoreSet(result)) + if (plan.isInStoreSet(result)) { continue; + } // Check if there are any uses that can be modified to reuse the output // buffer. @@ -450,11 +460,13 @@ static LogicalResult adaptComputeConsumerToAvoidStackAllocation( std::optional reusableOperand = canModifyUseToGetValueIntoStoreSet( plan, &use, useWARForCooperativeMatrixCodegen); - if (!reusableOperand) + if (!reusableOperand) { continue; - if (failed(modifyUseToGetValueIntoStoreSet(rewriter, &use, - reusableOperand.value()))) + } + if (failed(modifyUseToGetValueIntoStoreSet( + rewriter, &use, reusableOperand.value()))) { continue; + } return WalkResult::interrupt(); } } @@ -486,8 +498,9 @@ replaceUnpackEmptyWithAllocTensor(OpBuilder &b, return; } auto emptyOp = unpackOp.getDest().getDefiningOp(); - if (!emptyOp) + if (!emptyOp) { return; + } OpBuilder::InsertionGuard g(b); b.setInsertionPointAfter(emptyOp); @@ -511,13 +524,16 @@ struct RemoveCstOutsDependency Location loc = op.getLoc(); for (OpOperand &opOperand : op.getDpsInitsMutable()) { ElementsAttr attr; - if (!matchPattern(opOperand.get(), m_Constant(&attr))) + if (!matchPattern(opOperand.get(), m_Constant(&attr))) { continue; - if (!attr.isSplat()) + } + if (!attr.isSplat()) { continue; + } auto type = dyn_cast(attr.getType()); - if (!type) + if (!type) { continue; + } TypedAttr scalarAttr = attr.getValues()[0]; modifiedOutput = true; diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp 
b/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp index f7dcced3766a..88107abf00bc 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeConvolutionToLowerDimOps.cpp @@ -72,19 +72,22 @@ computeDecomposedLoweringConfig(ArrayRef computeOps, // ATM only folding of the H dim is supported. // TODO: Add support for cases where the W dim is folded. - if (!foldHDim(convOp)) + if (!foldHDim(convOp)) { return failure(); + } // 2. Get the current lowering config attached to the Conv Op. FailureOr loweringConfigAttr = getFirstLoweringConfig(computeOps); - if (failed(loweringConfigAttr)) + if (failed(loweringConfigAttr)) { return failure(); + } // TODO: Either remove "interchange" from lowering_config or add support in // this pass. - if (!loweringConfigAttr->isInterchangeEmpty()) + if (!loweringConfigAttr->isInterchangeEmpty()) { return failure(); + } // 3. Calculate new tiling levels. 
// Note that this will basically erase the _H_ dims from the orignal lowering @@ -159,8 +162,9 @@ class DecomposeConvolutionToLowerDimOpsPass final if (numConvOps == 1 && succeeded(newLoweringConfig)) { auto computeOps = getComputeOps(funcOp); for (auto computeOp : computeOps) { - if (isa(computeOp)) + if (isa(computeOp)) { setLoweringConfig(computeOp, newLoweringConfig.value()); + } } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index 154bbd313dd5..736e69212536 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -169,8 +169,9 @@ static LogicalResult commonRunOnOperation( scf::tileConsumerAndFuseProducersUsingSCF( rewriter, cast(op.getOperation()), packOptions); - if (failed(tileAndFuseResult)) + if (failed(tileAndFuseResult)) { return WalkResult::interrupt(); + } rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]); return WalkResult::advance(); }); @@ -203,8 +204,9 @@ static LogicalResult commonRunOnOperation( FailureOr tilingResult = scf::tileUsingSCF( rewriter, cast(op.getOperation()), unpackTilingOptions); - if (failed(tilingResult)) + if (failed(tilingResult)) { return WalkResult::interrupt(); + } rewriter.replaceOp(op, tilingResult->replacements); return WalkResult::advance(); }); diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp index bcdb06e0f3b3..3e3ca38bed75 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposeSoftmax.cpp @@ -43,8 +43,9 @@ struct FuseElementWiseGenericOps : public OpRewritePattern { // Find the first operand that is defined by another generic op on tensors. 
for (OpOperand &opOperand : genericOp->getOpOperands()) { - if (!linalg::areElementwiseOpsFusable(&opOperand)) + if (!linalg::areElementwiseOpsFusable(&opOperand)) { continue; + } // Don't fuse if it has external capture. For e.g., the gather like // payload operation like 'tensor.extract' would be cloned in // every consumer op, which is not what we want. diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp index 82ec1b7b8d0a..8bede523de65 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp @@ -130,8 +130,9 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, // When extracting all available elements, just use the source vector as the // result. - if (vectorType.getNumElements() == numElemsToExtract) + if (vectorType.getNumElements() == numElemsToExtract) { return src; + } auto offsets = rewriter.getI64ArrayAttr({offset}); auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract}); @@ -160,8 +161,9 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, "expected source and dest to be rank-1 vector types"); // If overwritting the destination vector, just return the source. 
- if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0) + if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0) { return src; + } auto offsets = rewriter.getI64ArrayAttr({offset}); auto strides = rewriter.getI64ArrayAttr({1}); @@ -344,9 +346,10 @@ struct IREEConvertVectorStore final : OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // See #115653 - if (op.getValueToStore().getType().getRank() != 1) + if (op.getValueToStore().getType().getRank() != 1) { return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported ATM"); + } auto loc = op.getLoc(); diff --git a/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp b/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp index ae957c479fa9..8d978d85c473 100644 --- a/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/EraseHALDescriptorTypeFromMemRef.cpp @@ -46,8 +46,9 @@ struct EraseHALDescriptorTypeFromMemRefPass final AttrTypeReplacer replacer; replacer.addReplacement( [](BaseMemRefType memRefType) -> std::optional { - if (isLegalType(memRefType)) + if (isLegalType(memRefType)) { return std::nullopt; + } // Erase the #hal.descriptor_type memory space. 
if (auto rankedType = dyn_cast(memRefType)) { @@ -74,8 +75,9 @@ struct ConvertHALDescriptorTypeToGPUAddressSpacePass final AttrTypeReplacer replacer; replacer.addReplacement( [](BaseMemRefType memRefType) -> std::optional { - if (isLegalType(memRefType)) + if (isLegalType(memRefType)) { return std::nullopt; + } Attribute globalSpace = gpu::AddressSpaceAttr::get( memRefType.getContext(), gpu::AddressSpace::Global); diff --git a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h index e18de97bae05..69ecb1b09d29 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h +++ b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.h @@ -40,8 +40,9 @@ struct StoreLoadLikeOpRewriter : public OpRewritePattern { auto ldTy = cast(srcMemRef.getType()); unsigned storeLoadRank = ldTy.getRank(); // Don't waste compile time if there is nothing to rewrite. - if (storeLoadRank == 0) + if (storeLoadRank == 0) { return failure(); + } // If our load already has only zeros as indices there is nothing // to do. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp index 814b73300be3..8e160b8fc4ec 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp @@ -157,8 +157,9 @@ struct FlattenAlloc final : public OpConversionPattern { matchAndRewrite(AllocOpTy allocOp, typename AllocOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(allocOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } Value dynamicDim = createTotalElementCountValue( oldType, allocOp.getDynamicSizes(), allocOp.getLoc(), rewriter); @@ -176,8 +177,9 @@ struct FlattenGlobal final : public OpConversionPattern { using Base::Base; static Attribute flattenAttribute(Attribute value, ShapedType newType) { - if (!value) + if (!value) { return value; + } if (auto splatAttr = dyn_cast(value)) { return splatAttr.reshape(newType); } else if (auto denseAttr = dyn_cast(value)) { @@ -194,8 +196,9 @@ struct FlattenGlobal final : public OpConversionPattern { matchAndRewrite(memref::GlobalOp globalOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(globalOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } auto tensorType = RankedTensorType::get({oldType.getNumElements()}, oldType.getElementType()); @@ -221,13 +224,15 @@ struct FlattenGetGlobal final matchAndRewrite(memref::GetGlobalOp getOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldType = dyn_cast(getOp.getType()); - if (!oldType || !oldType.getLayout().isIdentity()) + if (!oldType || !oldType.getLayout().isIdentity()) { return failure(); + } auto globalOp 
= dyn_cast_if_present( SymbolTable::lookupNearestSymbolFrom(getOp, getOp.getNameAttr())); - if (!globalOp) + if (!globalOp) { return failure(); + } auto loadedValue = rewriter.createOrFold( getOp.getLoc(), globalOp.getType(), getOp.getNameAttr()); @@ -250,8 +255,9 @@ struct FlattenBindingSubspan final auto oldType = dyn_cast(subspanOp.getType()); // IREE subspan ops only use memref types with the default identity // layout maps. - if (!oldType) + if (!oldType) { return failure(); + } OpFoldResult linearShape; if (oldType.hasStaticShape()) { @@ -441,8 +447,9 @@ struct FlattenSubView final : public OpConversionPattern { } Type neededResultType = getTypeConverter()->convertType(op.getResult().getType()); - if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType)) + if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType)) { return failure(); + } Value size = createTotalElementCountValue(op.getType(), op.getSizes(), op.getLoc(), rewriter); SmallVector offsets = mlir::getValueOrCreateConstantIndexOp( @@ -651,13 +658,15 @@ struct AdjustConversionCast final LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (castOp->getNumOperands() != 1) + if (castOp->getNumOperands() != 1) { return failure(); + } Value input = adaptor.getOperands().front(); // We only want to handle cases where the cast op handles memref types. 
- if (!isa(input.getType())) + if (!isa(input.getType())) { return failure(); + } if (!isRankZeroOrOneMemRef(input.getType())) { return rewriter.notifyMatchFailure( @@ -695,8 +704,9 @@ struct FoldMemRefReshape final : public OpConversionPattern { Type newSourceType = adaptor.getSrc().getType(); Type neededResultType = typeConverter->convertType(op.getResult().getType()); - if (!neededResultType) + if (!neededResultType) { return failure(); + } if (newSourceType == neededResultType) { rewriter.replaceOp(op, adaptor.getSrc()); return success(); @@ -769,8 +779,9 @@ struct FlattenMemRefSubspanPass final [](MemRefType type) -> std::optional { // 0-D MemRef types can be used to represent raw pointers for // micro-kernel ABI purposes. Specially allow it. - if (isRankZeroMemRef(type)) + if (isRankZeroMemRef(type)) { return type; + } // Fall back to the default conversion flow. return std::nullopt; @@ -786,8 +797,9 @@ struct FlattenMemRefSubspanPass final internalTypeConverter.addConversion( [](MemRefType type) -> std::optional { // 0-D or 1-D MemRef types are okay. - if (isRankZeroOrOneMemRef(type)) + if (isRankZeroOrOneMemRef(type)) { return type; + } // Fall back to the default conversion flow. 
return std::nullopt; @@ -857,8 +869,9 @@ struct FlattenMemRefSubspanPass final }); target.addDynamicallyLegalOp( [](UnrealizedConversionCastOp castOp) { - if (castOp->getNumOperands() != 1) + if (castOp->getNumOperands() != 1) { return false; + } Type inputType = castOp->getOperandTypes().front(); return !isa(inputType) || diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp index d880ff25396e..9d29253fb12b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp @@ -45,8 +45,9 @@ static OpFoldResult computeProduct(Location loc, OpBuilder &builder, SmallVector dynamicPart; AffineExpr result = builder.getAffineConstantExpr(1); for (OpFoldResult term : terms) { - if (!term) + if (!term) { return term; + } std::optional maybeConst = getConstantIntValue(term); if (maybeConst) { result = result * builder.getAffineConstantExpr(*maybeConst); @@ -55,8 +56,9 @@ static OpFoldResult computeProduct(Location loc, OpBuilder &builder, result = result * builder.getAffineSymbolExpr(nDynamic++); } } - if (auto constant = dyn_cast(result)) + if (auto constant = dyn_cast(result)) { return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); + } return affine::AffineApplyOp::create(builder, loc, result, dynamicPart) .getResult(); } @@ -245,9 +247,10 @@ struct MemRefRewritePatternBase : public OpRewritePattern { LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { Value memref = getTargetMemref(op); - if (!needFlattenning(memref) || !checkLayout(memref)) + if (!needFlattenning(memref) || !checkLayout(memref)) { return rewriter.notifyMatchFailure(op, "nothing to do or unsupported layout"); + } auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( rewriter, op->getLoc(), memref, op.getIndices()); replaceOp(op, rewriter, flatMemref, offset); @@ -301,11 +304,13 @@ struct 
FlattenSubview : public OpRewritePattern { LogicalResult matchAndRewrite(memref::SubViewOp op, PatternRewriter &rewriter) const override { Value memref = op.getSource(); - if (!needFlattenning(memref)) + if (!needFlattenning(memref)) { return rewriter.notifyMatchFailure(op, "nothing to do"); + } - if (!checkLayout(memref)) + if (!checkLayout(memref)) { return rewriter.notifyMatchFailure(op, "unsupported layout"); + } Location loc = op.getLoc(); SmallVector subOffsets = op.getMixedOffsets(); @@ -327,8 +332,9 @@ struct FlattenSubview : public OpRewritePattern { finalStrides.reserve(subRank); for (auto i : llvm::seq(0u, static_cast(srcType.getRank()))) { - if (droppedDims.test(i)) + if (droppedDims.test(i)) { continue; + } finalSizes.push_back(subSizes[i]); finalStrides.push_back(strides[i]); @@ -354,8 +360,9 @@ struct DecomposeMemrefsPass mlir::iree_compiler::populateDecomposeMemrefsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); + } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp index 34fb202ead83..3f3ea86d8aad 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldAffineMinInDistributedLoops.cpp @@ -57,8 +57,9 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, rewriter.setInsertionPoint(op); FailureOr simplified = mlir::affine::simplifyConstrainedMinMaxOp(op, std::move(constraints)); - if (failed(simplified)) + if (failed(simplified)) { return failure(); + } return rewriter.replaceOpWithNewOp( op, simplified->getAffineMap(), simplified->getOperands()); } @@ -89,22 +90,26 @@ struct FoldAffineMinOverDistributedLoopInductionVariable final auto loopMatcher = [&](Value iv, OpFoldResult &lb, 
OpFoldResult &ub, OpFoldResult &step) { scf::ForOp forOp = scf::getForInductionVarOwner(iv); - if (!forOp) + if (!forOp) { return failure(); + } auto loopInfo = isTiledAndDistributedLoop(forOp); - if (!loopInfo) + if (!loopInfo) { return failure(); + } LLVM_DEBUG(llvm::dbgs() << *loopInfo); std::optional untiledStep = getConstantIntValue(loopInfo->untiledStep); // For IREE right now the original untiled loop should have step 1.. - if (!untiledStep || *untiledStep != 1) + if (!untiledStep || *untiledStep != 1) { return failure(); + } // ..and we tile according to some static tile sizes for processors. - if (!loopInfo->tileSize) + if (!loopInfo->tileSize) { return failure(); + } lb = loopInfo->untiledLowerBound; ub = loopInfo->untiledUpperBound; @@ -132,17 +137,21 @@ struct FoldAffineMinOverWorkgroupIDs final // Find all iteration variables among `minOp`'s operands add constrain them. for (Value operand : minOp->getOperands()) { // Skip duplicate ids. - if (!allIds.insert(operand).second) + if (!allIds.insert(operand).second) { continue; + } auto idOp = operand.getDefiningOp(); - if (!idOp) + if (!idOp) { continue; + } // Can't infer the range when workroupCount is unknown. 
unsigned index = idOp.getDimension().getZExtValue(); - if (index >= numWorkgroup.size()) + if (index >= numWorkgroup.size()) { return failure(); - if (numWorkgroup[index] == ShapedType::kDynamic) + } + if (numWorkgroup[index] == ShapedType::kDynamic) { continue; + } constraints.appendDimVar({idOp}); constraints.addBound(presburger::BoundType::LB, idOp, 0); constraints.addBound(presburger::BoundType::UB, idOp, diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp index c4cc7e2c77a5..7ddd611d8a2e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorExtractOpPass.cpp @@ -59,7 +59,8 @@ class FoldTensorExtractOpPass final void FoldTensorExtractOpPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp index 4e32b7bf1c6b..0041574d3a04 100644 --- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp @@ -22,8 +22,9 @@ using namespace mlir; static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp, unsigned trailingRank) { // If no ranks are reduced at all, it's a degenerated case; always true. 
- if (extractOp.getSourceType().getRank() == extractOp.getType().getRank()) + if (extractOp.getSourceType().getRank() == extractOp.getType().getRank()) { return true; + } RankedTensorType inferredType = extractOp.inferResultType( extractOp.getSourceType(), extractOp.getMixedSizes()); @@ -57,19 +58,25 @@ class FoldExtractSliceIntoTransferRead final LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. - if (xferOp.getTransferRank() == 0) + if (xferOp.getTransferRank() == 0) { return failure(); - if (xferOp.hasOutOfBoundsDim()) + } + if (xferOp.hasOutOfBoundsDim()) { return failure(); - if (!xferOp.getPermutationMap().isMinorIdentity()) + } + if (!xferOp.getPermutationMap().isMinorIdentity()) { return failure(); - if (xferOp.getMask()) + } + if (xferOp.getMask()) { return failure(); + } auto extractOp = xferOp.getBase().getDefiningOp(); - if (!extractOp) + if (!extractOp) { return failure(); - if (!extractOp.hasUnitStride()) + } + if (!extractOp.hasUnitStride()) { return failure(); + } // Bail on illegal rank-reduction: we need to check that the rank-reduced // dims are exactly the leading dims. I.e. the following is illegal: @@ -87,8 +94,10 @@ class FoldExtractSliceIntoTransferRead final // ``` // For this, check the trailing `vectorRank` dims of the extract_slice // result tensor match the trailing dims of the inferred result tensor. - if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank())) + if (!areAllRankReducedLeadingDim(extractOp, + extractOp.getType().getRank())) { return failure(); + } int64_t rankReduced = extractOp.getSourceType().getRank() - extractOp.getType().getRank(); @@ -132,12 +141,15 @@ class FoldExtractSliceIntoTransferRead final /// dynamic tensors, where it resolves the tensor sizes via value-bounds /// analysis, and then checks if the vector type fully overwrites the tensor. 
static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { - if (writeOp.hasOutOfBoundsDim()) + if (writeOp.hasOutOfBoundsDim()) { return false; - if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank()) + } + if (writeOp.getVectorType().getRank() != writeOp.getShapedType().getRank()) { return false; - if (writeOp.getMask()) + } + if (writeOp.getMask()) { return false; + } std::optional vscaleRange; auto vecType = writeOp.getVectorType(); @@ -155,8 +167,9 @@ static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { [&](unsigned dimIndex) -> FailureOr { auto size = destShape[dimIndex]; // Fixed-size dimensions are simply included in the shape. - if (size != ShapedType::kDynamic) + if (size != ShapedType::kDynamic) { return iree_compiler::DimBoundSize{size}; + } // (Attempt to) resolve dynamic dimensions via value-bounds analysis. return iree_compiler::computeDimUpperBound(dest, dimIndex, vscaleRange); }; @@ -165,12 +178,15 @@ static bool isDestinationFullyOverwritten(vector::TransferWriteOp writeOp) { ArrayRef vecScalableFlags = vecType.getScalableDims(); for (unsigned d = 0, e = destShape.size(); d < e; ++d) { auto dimSize = resolveDestinationDimSize(d); - if (failed(dimSize)) + if (failed(dimSize)) { return false; - if (dimSize->scalable && !vecScalableFlags[d]) + } + if (dimSize->scalable && !vecScalableFlags[d]) { return false; - if (vecShape[d] != dimSize->baseSize) + } + if (vecShape[d] != dimSize->baseSize) { return false; + } } return true; } @@ -198,23 +214,28 @@ class FoldInsertSliceIntoTransferWrite final LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { - if (!insertOp.hasUnitStride()) + if (!insertOp.hasUnitStride()) { return failure(); + } auto xferOp = insertOp.getSource().getDefiningOp(); - if (!xferOp) + if (!xferOp) { return failure(); + } // TODO: support 0-d corner case. 
- if (xferOp.getTransferRank() == 0) + if (xferOp.getTransferRank() == 0) { return failure(); - if (!xferOp.getPermutationMap().isIdentity()) + } + if (!xferOp.getPermutationMap().isIdentity()) { return failure(); + } // Fold only if the TransferWriteOp completely overwrites the `source` with // a vector. I.e., the result of the TransferWriteOp is a new tensor whose // content is the data of the vector. - if (!isDestinationFullyOverwritten(xferOp)) + if (!isDestinationFullyOverwritten(xferOp)) { return failure(); + } // Bail on illegal rank-reduction: we need to check that the rank-reduced // dims are exactly the leading dims. I.e. the following is illegal: @@ -241,8 +262,9 @@ class FoldInsertSliceIntoTransferWrite final auto actualSourceTensorShape = insertOp.getSourceType().getShape(); if (rankReduced > 0 && actualSourceTensorShape.take_back(vectorRank) != - inferredSourceTensorType.getShape().take_back(vectorRank)) + inferredSourceTensorType.getShape().take_back(vectorRank)) { return failure(); + } SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); @@ -328,8 +350,9 @@ class FoldExtractSliceIntoTransferWrite final if (!maybeDestSize || !maybeIndex) { continue; } - if (vecSize + *maybeIndex <= *maybeDestSize) + if (vecSize + *maybeIndex <= *maybeDestSize) { inBounds[idx] = true; + } } rewriter.replaceOpWithNewOp( diff --git a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp index 9c11c2a30050..8f230dfaafd0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp @@ -158,8 +158,9 @@ struct CanonicalizeForOpInductionVarShape final mapping.map(loopIndVar, start); initArgs[index] = rewriter.clone(*finalIvUser, mapping)->getResult(0); } - if (iteratorFolded.empty()) + if (iteratorFolded.empty()) { return 
failure(); + } auto newLoop = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), @@ -230,8 +231,9 @@ struct PackForOpInductionVarVector final : public OpRewritePattern { targetTypes.push_back(targetType); } } - if (ivIndices.empty()) + if (ivIndices.empty()) { return failure(); + } // Bit cast all init values to the smaller vector (fewer elements). auto ivInitValues = llvm::to_vector<8>(forOp.getInitArgs()); @@ -287,8 +289,9 @@ struct PackForOpInductionVarVector final : public OpRewritePattern { yieldOp->setOperands(ivRetValues); SmallVector forRetValues; - for (Value result : newLoop.getResults()) + for (Value result : newLoop.getResults()) { forRetValues.push_back(result); + } // Bit cast return values to the old type to fix for op uses. rewriter.setInsertionPointAfter(newLoop); diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp index a6114b60cc0a..fe13e195b271 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp @@ -96,8 +96,9 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) { auto ty = padOp.getResultType(); // TODO(hanchung): Infer the vector sizes for pad op after // maskedVectorize method allows dynamic result shapes. 
- if (!ty.hasStaticShape()) + if (!ty.hasStaticShape()) { return; + } vectorSizes = SmallVector(ty.getShape()); }) .Case([&](IREE::LinalgExt::GatherOp gatherOp) { @@ -121,10 +122,12 @@ static LogicalResult isWithinVectorSizeLimit(linalg::LinalgOp linalgOp, int64_t maxFlatVecSize = 1; for (OpOperand &operand : linalgOp->getOpOperands()) { auto type = dyn_cast(operand.get().getType()); - if (!type) + if (!type) { continue; - if (!type.hasStaticShape()) + } + if (!type.hasStaticShape()) { return failure(); + } maxFlatVecSize = std::max(maxFlatVecSize, type.getNumElements()); } return success(maxFlatVecSize < maxVectorSize); @@ -183,11 +186,13 @@ void GenericVectorizationPass::runOnOperation() { // Do not vectorize the op if the vector size is greater than or equal // to limit. if (enableVectorMasking) { - if (llvm::product_of(vectorSizes) >= maxVectorSize) + if (llvm::product_of(vectorSizes) >= maxVectorSize) { continue; + } } else { - if (failed(isWithinVectorSizeLimit(linalgOp, maxVectorSize))) + if (failed(isWithinVectorSizeLimit(linalgOp, maxVectorSize))) { continue; + } } } // Pad scalable dims with `false` to match the vector sizes. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp index 2fd1b1b6bf47..5879aedbc214 100644 --- a/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/HoistStaticallyBoundAllocations.cpp @@ -36,8 +36,9 @@ void HoistStaticallyBoundAllocationsPass::runOnOperation() { IRRewriter rewriter(funcOp->getContext()); std::optional vscaleRange; - if (this->vscaleMax != 0 && this->vscaleMin <= this->vscaleMax) + if (this->vscaleMax != 0 && this->vscaleMin <= this->vscaleMax) { vscaleRange = {this->vscaleMin, this->vscaleMax}; + } hoistStaticallyBoundAllocationsInFunc(rewriter, funcOp, vscaleRange); diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp index 9c2d0974d21e..ac8026e5d85a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/HoistUnrolledVectorExtractInsertSlice.cpp @@ -37,19 +37,22 @@ getUnrolledExtractSlices( SmallVector res; for (auto user : srcTensor.getUsers()) { auto extractStridedSliceOp = dyn_cast(user); - if (!extractStridedSliceOp) + if (!extractStridedSliceOp) { return failure(); + } res.push_back(extractStridedSliceOp); } - if (res.size() != insertOps.size()) + if (res.size() != insertOps.size()) { return failure(); + } std::reverse(res.begin(), res.end()); for (auto [extractOp, insertOp] : llvm::zip_equal(res, insertOps)) { auto offset0 = insertOp.getOffsets(); auto offset1 = extractOp.getOffsets(); - if (offset0 != offset1) + if (offset0 != offset1) { return failure(); + } } return res; @@ -72,8 +75,9 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, SmallVector res; Value v = yieldOperand.get(); auto insertStridedSliceOp = 
v.getDefiningOp(); - if (!insertStridedSliceOp) + if (!insertStridedSliceOp) { return failure(); + } ArrayRef vecShape = insertStridedSliceOp.getSourceVectorType().getShape(); @@ -81,8 +85,9 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, insertStridedSliceOp.getDestVectorType().getShape(); int numOps = 1; for (auto [vecSize, destSize] : llvm::zip_equal(vecShape, destShape)) { - if (destSize % vecSize) + if (destSize % vecSize) { return failure(); + } numOps *= destSize / vecSize; } @@ -91,19 +96,22 @@ getUnrolledInsertSlices(scf::ForOp forOp, BlockArgument bbArg, insertStridedSliceOp = insertStridedSliceOp.getDest() .getDefiningOp(); } - if (res.size() != numOps) + if (res.size() != numOps) { return failure(); + } std::reverse(res.begin(), res.end()); SmallVector expectedOffsets(vecShape.size(), 0); for (vector::InsertStridedSliceOp op : res) { SmallVector offsets = getI64SubArray(op.getOffsets()); - if (expectedOffsets != offsets) + if (expectedOffsets != offsets) { return failure(); + } expectedOffsets.back() += vecShape.back(); for (int pos = expectedOffsets.size() - 1; pos > 0; pos--) { - if (expectedOffsets[pos] != destShape[pos]) + if (expectedOffsets[pos] != destShape[pos]) { break; + } expectedOffsets[pos] = 0; expectedOffsets[pos - 1] += vecShape[pos - 1]; } @@ -189,11 +197,13 @@ static scf::ForOp hoistUnrolledVectorExtractInsert(RewriterBase &rewriter, LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n"); OpOperand &ret = yield->getOpOperand(it.index()); auto insertOps = getUnrolledInsertSlices(forOp, it.value(), ret); - if (failed(insertOps)) + if (failed(insertOps)) { continue; + } auto extractOps = getUnrolledExtractSlices(it.value(), insertOps.value()); - if (failed(extractOps)) + if (failed(extractOps)) { continue; + } newForOp = hoistVectorExtractInsertSlice(rewriter, extractOps.value(), insertOps.value(), it.value()); break; diff --git a/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp 
b/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp index 20bc8b4b9d6e..45ac22dcd333 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREECodegenCanonicalizer.cpp @@ -24,8 +24,9 @@ namespace { /// shape is same as the size of the subview. In such cases, the subview can /// be folded into its source. static bool isTrivialSubViewOp(memref::SubViewOp subviewOp) { - if (subviewOp.getSourceType().getRank() != subviewOp.getType().getRank()) + if (subviewOp.getSourceType().getRank() != subviewOp.getType().getRank()) { return false; + } if (!areAllConstantIntValue(subviewOp.getMixedOffsets(), 0) || !areAllConstantIntValue(subviewOp.getMixedStrides(), 1)) { @@ -81,8 +82,9 @@ class DynamicTrivialSubViewOpFolder final LogicalResult matchAndRewrite(memref::SubViewOp subViewOp, PatternRewriter &rewriter) const override { - if (!isTrivialSubViewOp(subViewOp)) + if (!isTrivialSubViewOp(subViewOp)) { return failure(); + } if (subViewOp.getSourceType() == subViewOp.getType()) { rewriter.replaceOp(subViewOp, subViewOp.getSource()); return success(); @@ -105,8 +107,9 @@ struct IREECodegenCanonicalizerPass final GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); + } for (RegisteredOperationName op : context->getRegisteredOperations()) { if (op.getStringRef() == memref::CopyOp::getOperationName()) { owningPatterns.add(context); diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp index c6acdd001089..3aabe98df1e9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp 
@@ -63,9 +63,10 @@ static FailureOr defaultAllocationFn(OpBuilder &builder, Location loc, // type memory space; that's runtime allocations. So erase and fallback to // the default 0 memory space. It is fine given this is just the default // allocator; backends are expected to control by themselves. - if (isa(storage)) + if (isa(storage)) { type = MemRefType::get(type.getShape(), type.getElementType(), type.getLayout()); + } } return memref::AllocOp::create(builder, loc, type, dynamicSizes).getResult(); } @@ -172,12 +173,14 @@ eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, const OneShotBufferizationOptions &options) { // Analyze IR. OneShotAnalysisState state(op, options); - if (failed(analyzeOp(op, state))) + if (failed(analyzeOp(op, state))) { return failure(); + } // Rewrite tensor.empty ops that are anchored on specific ops. - if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) { return failure(); + } return success(); } @@ -215,11 +218,13 @@ void EliminateEmptyTensorsPass::runOnOperation() { auto bufferizationOptions = getBufferizationOptions(); OneShotAnalysisState state(funcOp, bufferizationOptions); // Analyze IR. - if (failed(analyzeOp(funcOp, state))) + if (failed(analyzeOp(funcOp, state))) { return signalPassFailure(); + } // Eliminate empty tensors. 
- if (failed(bufferization::eliminateEmptyTensors(rewriter, funcOp, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, funcOp, state))) { return signalPassFailure(); + } } // The following is copied from bufferization::runOneShotBufferize with @@ -229,10 +234,12 @@ runIREEOneShotBufferize(Operation *op, const IREEOneShotBufferizationOptions &options, bufferization::BufferizationState &state) { OneShotAnalysisState analyzeState(op, options); - if (failed(analyzeOp(op, analyzeState))) + if (failed(analyzeOp(op, analyzeState))) { return failure(); - if (options.testAnalysisOnly) + } + if (options.testAnalysisOnly) { return success(); + } return bufferization::runOneShotBufferize(op, options, state); } @@ -302,10 +309,12 @@ std::unique_ptr> createIREEComprehensiveBufferizePass( std::optional allocationFn, std::optional memCpyFn) { - if (!allocationFn) + if (!allocationFn) { allocationFn = defaultAllocationFn; - if (!memCpyFn) + } + if (!memCpyFn) { memCpyFn = defaultMemCpyFn; + } return std::make_unique(allocationFn.value(), memCpyFn.value()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp index d2a9ed3ca3ed..a02432b02be6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp @@ -158,8 +158,9 @@ struct ConvertMemRefExtractMetadataToIREECodegen using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { - if (!getSourceInterfaceBinding(op.getSource())) + if (!getSourceInterfaceBinding(op.getSource())) { return failure(); + } // Replace with iree_codegen version which doesn't fold. 
rewriter.replaceOpWithNewOp( op, op.getSource()); @@ -173,8 +174,9 @@ struct ResolveExtractMetadataFromHalInterfaceBindingSubspan LogicalResult matchAndRewrite(IREE::Codegen::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { auto binding = getSourceInterfaceBinding(op.getSource()); - if (!binding) + if (!binding) { return failure(); + } auto memRefType = cast(binding->getResult().getType()); auto loc = op.getLoc(); @@ -287,8 +289,9 @@ struct ConvertIREECodegenExtractMetadataToMemRef // Pattern ResolveExtractMetadataFromHalInterfaceBindingSubspan must // resolve these first to preserve SSA links through buffer binding // optimizations. - if (getSourceInterfaceBinding(op.getSource())) + if (getSourceInterfaceBinding(op.getSource())) { return failure(); + } // Only convert ops that don't have HAL bindings (or are already resolved). rewriter.replaceOpWithNewOp( diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp index 1980cd3d220a..f66fb308e4e4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp @@ -161,8 +161,9 @@ static void updateNamedSequenceOp( seenNames.insert(newSeqName); // Skip updating ForeachMatchOp if the NamedSequenceOp is not used in it. 
- if (!namedSequenceToUser.contains(op)) + if (!namedSequenceToUser.contains(op)) { return; + } ForeachMatchOp foreachMatchOp = namedSequenceToUser[op]; diff --git a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp index f35fd48c60e0..74b460ebeafb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp @@ -120,8 +120,9 @@ convertToUKernelGeneric(RewriterBase &rewriter, Operation *op, StringRef name, provider.createAndReplaceWithUkernelOp( rewriter, name, targetConfiguration, op, tensorInputs, tensorOutputs, otherOperands); - if (retVal) + if (retVal) { return retVal.value(); + } } // Default ukernel generic op is created when a provider doesn't exist or when // the provider doesn't implement the replacement method. diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp index c606152324b1..d8686b812d6a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncoding.cpp @@ -204,8 +204,9 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp, // the pipeline. 
if (isa(consumer) && isa_and_nonnull(producer) && - !producer->hasOneUse()) + !producer->hasOneUse()) { return false; + } return true; }); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp index 191c78306dac..b4db56291403 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp @@ -461,8 +461,9 @@ struct MaterializeOperation : public OpConversionPattern { this->template getTypeConverter(); FailureOr convertedOp = lowerOpWithEncoding(rewriter, op, adaptor.getOperands(), *converter); - if (failed(convertedOp)) + if (failed(convertedOp)) { return failure(); + } rewriter.replaceOp(op, convertedOp.value()); return success(); @@ -705,8 +706,9 @@ void populateMaterializeEncodingPatterns( auto resultType = dyn_cast( subspanOp.getResult().getType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. - if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addIllegalOp( storeOp.getTargetType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. - if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addDynamicallyLegalOp( @@ -725,8 +728,9 @@ void populateMaterializeEncodingPatterns( auto resultType = dyn_cast( loadOp.getSourceType()); // For types that are not `TensorExt::DispatchTensorType` mark as legal. 
- if (!resultType) + if (!resultType) { return true; + } return resultType == typeConverter.convertType(resultType); }); target.addDynamicallyLegalOp([](func::ReturnOp returnOp) { diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp index a11fa423ddf4..b008bfeaeb50 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeTuningSpecsPass.cpp @@ -149,9 +149,10 @@ getDefaultTuningSpec(ModuleOp module, #ifndef NDEBUG if (succeeded(defaultTransformLibrary) && - failed(mlir::verify(*defaultTransformLibrary))) + failed(mlir::verify(*defaultTransformLibrary))) { return (*defaultTransformLibrary).emitError() << "Default tuning spec from " << storageAttr << " failed to verify"; + } #endif return defaultTransformLibrary; diff --git a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp index cd8d08df66c1..532730af1b85 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp @@ -142,8 +142,9 @@ static bool predicateDeviceLibImpl(StringRef name, bool hasFastExp = isROCMBackend(target); // If fast exp is not available, don't use device-lib implementations. - if (!hasFastExp) + if (!hasFastExp) { return false; + } // Only apply to erf for now. 
StringRef erf = math::ErfOp::getOperationName(); diff --git a/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp b/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp index 0839fb76a747..f1c85625e0a5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MemrefCopyToLinalg.cpp @@ -25,8 +25,9 @@ struct MemrefCopyOpToLinalg : public OpRewritePattern { Operation *linalgCopy = createLinalgCopyOp(rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget(), copyOp->getAttrs()); - if (!linalgCopy) + if (!linalgCopy) { return failure(); + } rewriter.replaceOp(copyOp, linalgCopy->getResults()); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp index 6a237b1d8d3a..d9802f233b33 100644 --- a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp @@ -137,8 +137,9 @@ LogicalResult normalizeLoopBounds(RewriterBase &rewriter, scf::ForOp forOp) { LogicalResult normalizeLoopBounds(RewriterBase &rewriter, scf::ForallOp forallOp) { OpBuilder::InsertionGuard g(rewriter); - if (forallOp.isNormalized()) + if (forallOp.isNormalized()) { return success(); + } // `scf.forall` requires that all lbs/ubs/steps/ivs are index type so no need // to check here. 
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp index e65554190878..ef4fbe5ad5f9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp @@ -49,8 +49,9 @@ class OptimizeTensorInsertExtractSlicesPass final static bool canBeHoisted(LoopLikeOpInterface loopLike, SubsetInsertionOpInterface insertion) { // Do not move terminators. - if (insertion->hasTrait()) + if (insertion->hasTrait()) { return false; + } // Walk the nested operations and check that all used values are either // defined outside of the loop or in a nested region, but not at the level of @@ -58,8 +59,10 @@ static bool canBeHoisted(LoopLikeOpInterface loopLike, auto walkFn = [&](Operation *child) { for (OpOperand &operand : child->getOpOperands()) { // Ignore values defined in a nested region. - if (insertion->isAncestor(operand.get().getParentRegion()->getParentOp())) + if (insertion->isAncestor( + operand.get().getParentRegion()->getParentOp())) { continue; + } if (!loopLike.isDefinedOutsideOfLoop(operand.get()) && &operand != &insertion.getSourceOperand()) { return WalkResult::interrupt(); @@ -310,8 +313,9 @@ struct FoldMaskedTransferRAW : OpRewritePattern { [](Value v) { return !isZeroInteger(v); }) || llvm::any_of(writeOp.getIndices(), [](Value v) { return !isZeroInteger(v); })) && - (op.getIndices() != writeOp.getIndices())) + (op.getIndices() != writeOp.getIndices())) { return failure(); + } // Work only with minor identity mappings. 
if (!op.getPermutationMap().isMinorIdentity() || diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index 1aee9088d051..d639f8d79edf 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -44,8 +44,9 @@ class TransposeUnitDimToShapeCast unsigned numNonUnitSrcDim = llvm::count_if(op.getSourceVectorType().getShape(), [](int64_t dim) { return dim != 1; }); - if (numNonUnitSrcDim > 1) + if (numNonUnitSrcDim > 1) { return failure(); + } rewriter.replaceOpWithNewOp( op, op.getResultVectorType(), op.getVector()); return success(); diff --git a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp index b7794ffaea06..46c6b559f175 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp @@ -23,15 +23,18 @@ namespace mlir::iree_compiler { /// compute alloc sizes. 
static Value skipAffineMaxZero(Value dim) { auto affineMax = dim.getDefiningOp(); - if (!affineMax) + if (!affineMax) { return dim; + } for (AffineExpr expr : affineMax.getMap().getResults()) { if (auto cst = dyn_cast(expr)) { - if (cst.getValue() == 0) + if (cst.getValue() == 0) { continue; + } } else if (auto symExpr = dyn_cast(expr)) { - if (symExpr.getPosition() == 0) + if (symExpr.getPosition() == 0) { continue; + } } return dim; } @@ -62,8 +65,9 @@ static LogicalResult padAlloc(MLIRContext *context, AllocLikeOp allocOp, dimSize = *ub; sizes.push_back(dim); } - if (dynamicDimIdx == 0) + if (dynamicDimIdx == 0) { return success(); + } Type elType = allocOp.getType().getElementType(); MemRefType allocType = MemRefType::get(shape, elType, AffineMap(), allocOp.getType().getMemorySpace()); @@ -98,8 +102,9 @@ struct PadDynamicAllocPass final SmallVector allocs; funcOp.walk([&](memref::AllocOp allocOp) { allocs.push_back(allocOp); }); for (memref::AllocOp alloc : allocs) { - if (failed(padAlloc(context, alloc, solver))) + if (failed(padAlloc(context, alloc, solver))) { return signalPassFailure(); + } } // Collect all the alloca operations. 
@@ -107,8 +112,9 @@ struct PadDynamicAllocPass final funcOp.walk( [&](memref::AllocaOp allocaOp) { allocas.push_back(allocaOp); }); for (memref::AllocaOp alloca : allocas) { - if (failed(padAlloc(context, alloca, solver))) + if (failed(padAlloc(context, alloca, solver))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp index a16b661dc5e1..dd4ad8651fae 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp @@ -105,8 +105,9 @@ struct FoldApplySymbolOrDimSum final : OpRewritePattern { replacements.reserve(map.getNumInputs()); int64_t numDims = map.getNumDims(); auto getCurrExpr = [&](int64_t i) -> AffineExpr { - if (i >= numDims) + if (i >= numDims) { return rewriter.getAffineSymbolExpr(i - numDims); + } return rewriter.getAffineDimExpr(i); }; bool didReplace = false; @@ -157,8 +158,9 @@ struct PropagateConstantAddsThroughLinearize final int64_t runningOffset = 0; Value zero = nullptr; auto getZero = [&]() { - if (zero) + if (zero) { return zero; + } zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); return zero; }; @@ -252,8 +254,9 @@ struct FoldDivisibleConstantMulsIntoLinearize final SmallVector newStaticBasis; Value zero = nullptr; auto getZero = [&]() { - if (zero) + if (zero) { return zero; + } zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); return zero; }; diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index 20a7a3c8700b..e06fb170c6af 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -33,11 +33,13 @@ getExpandedShape(SmallVector reIndices, 
SmallVectorImpl &expandedShape, SmallVectorImpl &totalInnerSizes) { auto destType = dyn_cast(dest.getType()); - if (!destType) + if (!destType) { return failure(); + } // TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice. - if (reIndices.size() != destType.getShape().size()) + if (reIndices.size() != destType.getShape().size()) { return failure(); + } // Iterator to insert outer sizes. auto outerShapeIdx = 0; for (auto [reassociations, destSize] : @@ -58,13 +60,15 @@ getExpandedShape(SmallVector reIndices, for (int64_t reasociation : llvm::drop_begin(reassociations)) { int64_t expandedInnerSize = sliceStaticSizes[reasociation]; // It is not safe to do this pattern if inner dimensions are dynamic. - if (ShapedType::isDynamic(expandedInnerSize)) + if (ShapedType::isDynamic(expandedInnerSize)) { return failure(); + } expandedShape.push_back(expandedInnerSize); totalInnerSize *= expandedInnerSize; } - if (destSize % totalInnerSize != 0) + if (destSize % totalInnerSize != 0) { return failure(); + } totalInnerSizes.push_back(totalInnerSize); // insert the outer size in front of any inner sizes. 
expandedShape.insert(expandedShape.begin() + outerShapeIdx, @@ -88,20 +92,26 @@ static LogicalResult verifyAndCollectExpandableUsers( continue; } auto extractSliceOp = dyn_cast(user); - if (!extractSliceOp) + if (!extractSliceOp) { return failure(); - if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) + } + if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) { return failure(); - if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets()) + } + if (extractSliceOp.getMixedOffsets() != + parallelInsertOp.getMixedOffsets()) { return failure(); + } for (Operation *user : extractSliceOp->getUsers()) { auto expandShapeOp = dyn_cast(user); - if (!expandShapeOp) + if (!expandShapeOp) { return failure(); + } SmallVector expandReIndices = expandShapeOp.getReassociationIndices(); - if (reIndices != expandReIndices) + if (reIndices != expandReIndices) { return failure(); + } } expandableUsers.push_back(extractSliceOp); } @@ -187,8 +197,9 @@ struct ExpandDestinationForallOp final auto collapseOp = parallelInsertOp.getSource().getDefiningOp(); // No collapse op to hoist out. - if (!collapseOp) + if (!collapseOp) { return failure(); + } // Ignore trivially foldable collapse ops. 
if (collapseOp.getSrcType().getRank() == @@ -204,8 +215,9 @@ struct ExpandDestinationForallOp final int64_t tiedResultIdx = tiedResult.getResultNumber(); auto forallOp = dyn_cast(tiedResult.getOwner()); - if (!forallOp) + if (!forallOp) { return failure(); + } SmallVector expandedDestShape; SmallVector totalInnerSizes; @@ -227,16 +239,19 @@ struct ExpandDestinationForallOp final auto storeOp = dyn_cast(foralluser); if (storeOp && isFullSlice(storeOp, storeOp.getTargetType(), - storeOp.getTargetDims())) + storeOp.getTargetDims())) { continue; + } auto storeToBufferOp = dyn_cast(foralluser); - if (!storeToBufferOp) + if (!storeToBufferOp) { return failure(); + } MemRefType bufferType = storeToBufferOp.getBuffer().getType(); if (failed(memref::ExpandShapeOp::computeExpandedType( - bufferType, expandedDestShape, reIndices))) + bufferType, expandedDestShape, reIndices))) { return failure(); + } } // This allows us to assume that the extract/inserts in the loop are diff --git a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp index d53084d66cea..a1b82815e72f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp @@ -99,8 +99,9 @@ struct FoldCollapseShapeIntoInterfaceTensorLoad auto reshapeSrcType = cast(reshapeSrc.getType()); auto loadOp = reshapeSrc.getDefiningOp(); - if (!loadOp) + if (!loadOp) { return failure(); + } // Make sure we are loading the full incoming subspan. Otherwise we cannot // simply adjust the subspan's resultant type later. 
@@ -110,8 +111,9 @@ struct FoldCollapseShapeIntoInterfaceTensorLoad auto subspanOp = loadOp.getSource() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(subspanOp); @@ -200,8 +202,9 @@ struct FoldExpandShapeIntoInterfaceTensorLoad auto subspanOp = loadOp.getSource() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(subspanOp); @@ -305,8 +308,9 @@ struct FoldExpandShapeIntoInterfaceTensorStore auto subspanOp = storeOp.getTarget() .getDefiningOp(); - if (!subspanOp) + if (!subspanOp) { return failure(); + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subspanOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp b/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp index b742e844888f..66b088db3684 100644 --- a/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp @@ -21,8 +21,9 @@ struct StripFuncOpTranslationInfo final using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(mlir::FunctionOpInterface funcOp, PatternRewriter &rewriter) const final { - if (!getTranslationInfo(funcOp)) + if (!getTranslationInfo(funcOp)) { return failure(); + } rewriter.modifyOpInPlace(funcOp, [&]() { // If the function has translation info, erase it. 
@@ -38,8 +39,9 @@ struct StripLinalgOpCompilationInfo final using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const final { - if (!getCompilationInfo(linalgOp) && !getLoweringConfig(linalgOp)) + if (!getCompilationInfo(linalgOp) && !getLoweringConfig(linalgOp)) { return failure(); + } rewriter.modifyOpInPlace(linalgOp, [&]() { if (getCompilationInfo(linalgOp)) { // Erase the compilation info configuration if it exists. diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp index b3c49aebb466..b29265471b81 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp @@ -138,8 +138,9 @@ static void updateTensorDimInfo( auto resultType = cast(result.getType()); int dimOperandIndex = 0; for (auto [index, shape] : llvm::enumerate(resultType.getShape())) { - if (ShapedType::isStatic(shape)) + if (ShapedType::isStatic(shape)) { continue; + } updateTensorDimInfo(result, index, dimOperands[dimOperandIndex++], solver, divisibilityInfo, rangeInfo); } @@ -185,8 +186,9 @@ static void updateTensorDimInfo( LLVM_DEBUG({ for (auto [resultIndex, result] : llvm::enumerate(op->getResults())) { auto tensorType = dyn_cast(result.getType()); - if (!tensorType) + if (!tensorType) { continue; + } for (auto index : llvm::seq(0, tensorType.getRank())) { std::optional range; std::optional divisibility; diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp index 03672231dfd6..1f41ff706aa5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp @@ -34,8 +34,9 @@ static Value 
getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder, Location loc) { IntegerAttr attr; if (Value val = dyn_cast(attrOrValue)) { - if (val.getType().isIndex()) + if (val.getType().isIndex()) { return val; + } matchPattern(val, m_Constant(&attr)); } else { attr = cast(cast(attrOrValue)); @@ -84,13 +85,15 @@ struct VectorizePadWithConditions final PatternRewriter &rewriter) const override { // Static result shape is needed to reading padded dimensions in an // unrolled manner. - if (!padOp.getType().hasStaticShape()) + if (!padOp.getType().hasStaticShape()) { return failure(); + } // Only support constant padding value cases. Value paddingValue = padOp.getConstantPaddingValue(); - if (!paddingValue) + if (!paddingValue) { return failure(); + } Attribute paddingAttr; if (!matchPattern(paddingValue, m_Constant(&paddingAttr))) { return failure(); @@ -127,8 +130,9 @@ struct VectorizePadWithConditions final SmallVector paddedDimLBs(tensorRank); SmallVector paddedDimUBs(tensorRank); for (int i = 0; i < tensorRank; ++i) { - if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) + if (isConstantZero(lowPads[i]) && isConstantZero(highPads[i])) { continue; + } paddedDimIndices.push_back(i); auto srcDimSize = @@ -147,8 +151,9 @@ struct VectorizePadWithConditions final loc, SplatElementsAttr::get(fullVectorType, {paddingAttr})); auto sliceVectorShape = llvm::to_vector(paddedTensorShape); - for (int dim : paddedDimIndices) + for (int dim : paddedDimIndices) { sliceVectorShape[dim] = 1; + } auto sliceVectorType = VectorType::get(dropLeadingOne(sliceVectorShape), elementType); Value cstSliceVector = rewriter.createOrFold( @@ -157,8 +162,9 @@ struct VectorizePadWithConditions final // Calculate the total count of all padded dimensions. We need to generate // vector read ops with scf.if guards for each of them. 
int totalCount = 1; - for (int dim : paddedDimIndices) + for (int dim : paddedDimIndices) { totalCount *= paddedTensorShape[dim]; + } auto zeroIndex = rewriter.createOrFold(loc, 0); auto trueAttr = rewriter.getBoolAttr(true); diff --git a/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp b/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp index 217d20ba1040..2779c7ff4486 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TestExecutablePreprocessing.cpp @@ -29,8 +29,9 @@ struct TestExecutablePreprocessingPass final // whatever it needed to the executable instead. getOperation()->walk([&](IREE::HAL::ExecutableVariantOp variantOp) { auto configAttr = variantOp.getTarget().getConfiguration(); - if (!configAttr) + if (!configAttr) { return; + } auto replacementAttr = configAttr.getAs("replace_i64"); if (!replacementAttr) { // Skip variants that don't request modification. diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp index fe52ed733687..59c230de7494 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp @@ -106,8 +106,9 @@ getTileAndDistributeConfig(ArrayRef computeOps, partitionableLoopsSet.insert(partitionableLoops.begin(), partitionableLoops.end()); for (auto loopId : llvm::seq(0, tileSizes.size())) { - if (partitionableLoopsSet.count(loopId)) + if (partitionableLoopsSet.count(loopId)) { continue; + } tileSizes[loopId] = 0; } @@ -181,10 +182,12 @@ static LogicalResult lowerDispatchWorkgroupCountForDagRootOp( // slowest varying. 
SmallVector numWorkgroups; for (auto partitionedLoop : llvm::reverse(partitionedLoops)) { - if (partitionedLoop >= tileSizes.size()) + if (partitionedLoop >= tileSizes.size()) { continue; - if (isZeroInteger(tileSizes[partitionedLoop])) + } + if (isZeroInteger(tileSizes[partitionedLoop])) { continue; + } Value numTileAlongDim = getValueOrCreateConstantIndexOp( rewriter, loc, numTiles[partitionedLoop]); if (numWorkgroups.size() == maxWorkgroupParallelDims) { diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp index 126edd72b9c3..1a4a5a40c943 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp @@ -35,23 +35,26 @@ void fuseProducersOfSlices(RewriterBase &rewriter, auto fusableProducer = candidateSlice.getSource().getDefiningOp(); - if (!fusableProducer) + if (!fusableProducer) { continue; + } std::optional controlFnResult = options.fusionControlFn(candidateSlice, cast(candidateSlice.getSource()), /*destinationInitArg=*/false); - if (!controlFnResult) + if (!controlFnResult) { continue; + } // The operands of the fused producer might themselves be slices of // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. 
std::optional fusedResult = scf::tileAndFuseProducerOfSlice(rewriter, candidateSlice, loops); - if (!fusedResult) + if (!fusedResult) { continue; + } for (auto newSlice : fusedResult->generatedSlices) { worklist.push(newSlice); @@ -70,8 +73,9 @@ void collectTiledAndFusedOps(Operation *rootOp, for (OpOperand &operand : current->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); if (!producer || !isa(producer) || - result.count(producer)) + result.count(producer)) { continue; + } worklist.push_back(producer); result.insert(producer); } @@ -181,10 +185,11 @@ fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef tiledOps, // list of slices to handle. Otherwise, insert it into the right // position based on dominance. auto *it = llvm::lower_bound(candidates, entry, comp); - if (it != candidates.end() && it->fusableUser == fusableUser) + if (it != candidates.end() && it->fusableUser == fusableUser) { *it = std::move(entry); - else + } else { candidates.insert(it, std::move(entry)); + } } } } @@ -250,15 +255,17 @@ collectTiledAndFusedOps(Operation *op, Operation *current = worklist.pop_back_val(); for (OpOperand &operand : current->getOpOperands()) { auto producer = operand.get().getDefiningOp(); - if (!producer || ops.contains(producer) || exclude.contains(producer)) + if (!producer || ops.contains(producer) || exclude.contains(producer)) { continue; + } worklist.push_back(producer); ops.insert(producer); } for (auto user : current->getUsers()) { auto consumer = dyn_cast(user); - if (!consumer || ops.contains(consumer) || exclude.contains(consumer)) + if (!consumer || ops.contains(consumer) || exclude.contains(consumer)) { continue; + } worklist.push_back(consumer); ops.insert(consumer); } @@ -374,8 +381,9 @@ LogicalResult applyTileAndFuseToEachRoot( // We dont want this for reduction tiling as it can lead to large tensors // being yielded. 
if (tilingLevel != IREE::GPU::TilingLevel::Reduction && - tilingLevel != IREE::GPU::TilingLevel::PartialReduction) + tilingLevel != IREE::GPU::TilingLevel::PartialReduction) { yieldProducerReplacement = yieldReplacementsFor.contains(owner); + } bool shouldFuse = false; if (auto tilingOwner = dyn_cast(owner)) { shouldFuse = !payloadOps.contains(tilingOwner); @@ -440,7 +448,7 @@ LogicalResult applyTileAndFuseToEachRoot( SmallVector opsToReplace{tilingInterfaceOp}; llvm::append_range(opsToReplace, tiledResults->fusedProducers); for (Operation *toReplace : opsToReplace) { - for (OpResult res : toReplace->getResults()) + for (OpResult res : toReplace->getResults()) { if (auto replacement = tiledResults->replacements.lookup(res)) { Operation *replacementOp = replacement.getDefiningOp(); rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { @@ -448,6 +456,7 @@ LogicalResult applyTileAndFuseToEachRoot( return dominanceInfo.properlyDominates(replacementOp, user); }); } + } if (toReplace->use_empty()) { rewriter.eraseOp(toReplace); diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp index f3d156cc37c7..6e7005a67a12 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp @@ -216,8 +216,9 @@ static bool verifyComputeOpsAfterDistribution(FunctionOpInterface funcOp) { /// for the DPS `user`. Returns false if the user is not in DPS. 
static bool isUsedAsInit(Operation *producer, Operation *user) { auto dpsIface = dyn_cast(user); - if (!dpsIface) + if (!dpsIface) { return false; + } ValueRange results = producer->getResults(); return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) { return llvm::is_contained(results, operand); @@ -251,8 +252,9 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() { // an init of a DPS op, the user currently cannot be fused. Having a // replacement for it would attempt fusion and fail, so avoid such cases. if (llvm::any_of(op->getUsers(), [&](Operation *user) { - if (isUsedAsInit(op, user)) + if (isUsedAsInit(op, user)) { return false; + } return dominanceInfo.properlyDominates(tilableOp, user) || !tiledAndFusedOps.contains(user); })) { diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp index cb9718857383..d3aec86e681f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp @@ -36,14 +36,16 @@ static SmallVector fillInterchangeVector(ArrayRef interchangeVector, size_t iterationDomainSize) { SmallVector filledVector; - for (auto v : interchangeVector) + for (auto v : interchangeVector) { filledVector.push_back(v); + } if (filledVector.size() < iterationDomainSize) { auto range = llvm::seq(filledVector.size(), iterationDomainSize); filledVector.append(range.begin(), range.end()); } - if (filledVector.size() > iterationDomainSize) + if (filledVector.size() > iterationDomainSize) { filledVector.resize(iterationDomainSize); + } return filledVector; } @@ -208,8 +210,9 @@ static LogicalResult replaceStoresWithTiledVersion( storeOps.push_back(storeOp); } } - if (storeOps.empty()) + if (storeOps.empty()) { return success(); + } if (storeOps.size() != 1) { return rewriter.notifyMatchFailure(untiledValue.getOwner(), "expected a 
single store for the op"); @@ -398,9 +401,10 @@ tileDispatchUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, }); // 4. Generate the tiled implementation within the inner most loop. - if (!tilingResult.loops.empty()) + if (!tilingResult.loops.empty()) { rewriter.setInsertionPoint( tilingResult.loops.back().getBody()->getTerminator()); + } FailureOr tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); if (failed(tiledImplementation)) { @@ -480,8 +484,9 @@ getAllFusableProducerUses(Operation *untiledOp, for (auto tiledOp : llvm::reverse(tiledOps)) { for (OpOperand &operand : llvm::reverse(tiledOp->getOpOperands())) { auto sliceOp = operand.get().getDefiningOp(); - if (!sliceOp || sliceOp.getSource().getDefiningOp() != untiledOp) + if (!sliceOp || sliceOp.getSource().getDefiningOp() != untiledOp) { continue; + } sliceOps.push_back(sliceOp); } } @@ -572,8 +577,9 @@ struct SwapExtractSliceWithDispatchTensorLoad PatternRewriter &rewriter) const override { auto loadOp = sliceOp.getSource() .getDefiningOp(); - if (!loadOp) + if (!loadOp) { return failure(); + } SmallVector combinedOffsets, combinedSizes, combinedStrides; if (failed(affine::mergeOffsetsSizesAndStrides( @@ -602,8 +608,9 @@ struct SwapExtractSliceWithTensorEmpty LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { auto emptyTensorOp = sliceOp.getSource().getDefiningOp(); - if (!emptyTensorOp) + if (!emptyTensorOp) { return failure(); + } SmallVector mixedSizes = sliceOp.getMixedSizes(); if (mixedSizes.size() != sliceOp.getType().getRank()) { @@ -611,8 +618,9 @@ struct SwapExtractSliceWithTensorEmpty rankReducedMixedSizes.reserve(sliceOp.getType().getRank()); auto droppedDims = sliceOp.getDroppedDims(); for (auto [index, size] : llvm::enumerate(mixedSizes)) { - if (droppedDims.test(index)) + if (droppedDims.test(index)) { continue; + } rankReducedMixedSizes.push_back(size); } std::swap(mixedSizes, rankReducedMixedSizes); diff 
--git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp index fa95a8151c3e..b11f9f0e962e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp @@ -69,8 +69,9 @@ class TransformDialectInterpreterPass final } if (failed(transform::applyTransformNamedSequence( payloadRoot, transformEntryPoint, transformModule, - options.enableExpensiveChecks(true)))) + options.enableExpensiveChecks(true)))) { return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 1cb8615d66e6..5af5331db67f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -88,8 +88,9 @@ void mlir::iree_compiler::registerTransformDialectCommonExtension( //===---------------------------------------------------------------------===// static void addOperands(Operation *op, SetVector &operandSet) { - if (!op) + if (!op) { return; + } TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { SmallVector inputOperands = linalgOp.getDpsInputs(); @@ -103,12 +104,14 @@ static void addOperands(Operation *op, SetVector &operandSet) { template static bool setFusedOpOperandLimit(OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); - if (!producer) + if (!producer) { return false; + } Operation *consumer = fusedOperand->getOwner(); SetVector fusedOpOperands; - if (producer->getNumResults() != 1) + if (producer->getNumResults() != 1) { return false; + } addOperands(consumer, fusedOpOperands); fusedOpOperands.remove(producer->getResult(0)); addOperands(producer, 
fusedOpOperands); @@ -148,8 +151,9 @@ void transform_dialect::ApplyUnrollVectorsGpuMmaSyncPatternsOp:: populatePatterns(RewritePatternSet &patterns) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return mlir::iree_compiler::gpuMmaUnrollOrder(contract); }; vector::populateVectorUnrollPatterns( @@ -171,8 +175,9 @@ void transform_dialect::ApplyUnrollVectorsGpuWmmaSyncPatternsOp:: populatePatterns(RewritePatternSet &patterns) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return mlir::iree_compiler::gpuMmaUnrollOrder(contract); }; vector::populateVectorUnrollPatterns( @@ -280,8 +285,9 @@ static bool isAscendingRelativeMapping(ArrayRef mapping) { static FailureOr flattenForallOp(RewriterBase &rewriter, scf::ForallOp forallOp) { - if (!forallOp.getMapping().has_value()) + if (!forallOp.getMapping().has_value()) { return forallOp->emitError("mapping must be present"); + } SmallVector mapping = llvm::to_vector(forallOp.getMapping()->getValue()); if (!(llvm::all_of(mapping, llvm::IsaPred) || @@ -403,15 +409,18 @@ static LogicalResult rewriteForallToWorkgroup(RewriterBase &rewriter, Attribute bX = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimX); Attribute bY = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimY); Attribute bZ = gpu::GPUBlockMappingAttr::get(ctx, gpu::MappingId::DimZ); - if (forallOp.getNumResults() > 0) + if (forallOp.getNumResults() > 0) { return forallOp->emitError( "only bufferized scf.forall lowers to workgroup"); - if (forallOp.getRank() > 3) + } + if (forallOp.getRank() > 3) { return forallOp->emitError( "scf.forall with rank > 3 does not lower to workgroup"); + } - if (!forallOp.getMapping().has_value()) + if (!forallOp.getMapping().has_value()) { return forallOp->emitError("mapping must be present"); + } SmallVector blockMapping = 
llvm::to_vector(forallOp.getMapping()->getValue()); if (llvm::any_of(blockMapping, [](Attribute map) { @@ -492,10 +501,12 @@ DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne( scf::ForallOp topLevelForallOp; auto walkResult = target->walk([&](scf::ForallOp forallOp) { - if (forallOp->getParentOfType()) + if (forallOp->getParentOfType()) { return WalkResult::advance(); - if (topLevelForallOp) + } + if (topLevelForallOp) { return WalkResult::interrupt(); + } topLevelForallOp = forallOp; return WalkResult::advance(); }); @@ -506,8 +517,9 @@ DiagnosedSilenceableFailure transform_dialect::ForallToWorkgroupOp::applyToOne( } rewriter.setInsertionPoint(topLevelForallOp); - if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp))) + if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp))) { return mlir::emitDefiniteFailure(target, "rewriteForallToWorkgroup failed"); + } return DiagnosedSilenceableFailure::success(); } @@ -531,29 +543,34 @@ transform_dialect::GpuDistributeSharedMemoryCopyOp::applyToOne( // Look for ops that move to workgroup memory and mark as copies for // distribution. target.walk([&](linalg::GenericOp copyOp) { - if (copyOp.getNumDpsInputs() != 1 || copyOp.getNumDpsInits() != 1) + if (copyOp.getNumDpsInputs() != 1 || copyOp.getNumDpsInits() != 1) { return; + } auto dest = dyn_cast>(copyOp.getDpsInitOperand(0)->get()); - if (!dest) + if (!dest) { return; + } MemRefType destType = dest.getType(); // Check if the only operation in the possible copy op region is a // terminator. Block &body = copyOp.getRegion().front(); - if (!std::begin(body)->hasTrait()) + if (!std::begin(body)->hasTrait()) { return; + } auto destSpace = dyn_cast_if_present(destType.getMemorySpace()); - if (!destSpace) + if (!destSpace) { return; + } // The destination space must be shared memory. 
- if (destSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace()) + if (destSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace()) { return; + } // Mark this copy operation as a copy to workgroup memory. setMarker(copyOp, getCopyToWorkgroupMemoryMarker()); @@ -682,8 +699,9 @@ transform_dialect::IREEApplyLoopIndependentCodeMotionOp::applyToOne( // Do not hoist from scf.forall ops. These capture isolated computations // that will be mapped to a certain level in the GPU hierarchy (e.g., // GPU blocks), so hoisting is not desired. - if (!isa(loopLike.getOperation())) + if (!isa(loopLike.getOperation())) { moveLoopInvariantCode(loopLike); + } }); // For now, put single loop promotion as part of licm. Underlying // implementations perform splice operations which shouldn't need @@ -803,16 +821,18 @@ static LogicalResult gpuComprehensiveBufferizeCopyFn(OpBuilder &builder, hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } // TODO: ideally we should use linalg.copy which was recently reintroduced // as an OpDSL named op. However, IREE-specific patterns to cleanup spurious // post-bufferization copies do not trigger properly. // So we keep using `createLinalgCopyOp` which builds a GenericOp. // linalg::CopyOp::create(builder, loc, from, to); mlir::iree_compiler::createLinalgCopyOp(builder, loc, from, to); - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } return success(); } @@ -889,8 +909,9 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply( return mlir::emitDefiniteFailure(target, "greedy pattern application failed"); } - if (listener.failed()) + if (listener.failed()) { return listener.checkAndResetError(); + } } // 2. Run one-shot-bufferize, without the pass baggage. 
@@ -933,9 +954,10 @@ transform_dialect::IREEEliminateEmptyTensorsOp::applyToOne( ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state) { if (failed( - eliminateEmptyTensors(rewriter, target, getBufferizationOptions()))) + eliminateEmptyTensors(rewriter, target, getBufferizationOptions()))) { return emitDefaultDefiniteFailure(target) << "failed to eliminate tensor.empty ops"; + } return DiagnosedSilenceableFailure::success(); } @@ -983,8 +1005,9 @@ transform_dialect::ShareForallOperandsOp::applyToOne( llvm::to_vector(llvm::seq(0, forallOp.getOutputs().size())); } for (int64_t outputIdx : getShareOperands()) { - if (outputIdx < 0 || outputIdx >= forallOp.getOutputs().size()) + if (outputIdx < 0 || outputIdx >= forallOp.getOutputs().size()) { return mlir::emitDefiniteFailure(forallOp, "operand idx overflow"); + } Value toShare = forallOp.getOutputs()[outputIdx]; if (std::distance(toShare.getUses().begin(), toShare.getUses().end()) != 2) { @@ -997,8 +1020,9 @@ transform_dialect::ShareForallOperandsOp::applyToOne( tensor::ExtractSliceOp extractSliceOp; for (Operation *user : toShare.getUsers()) { extractSliceOp = dyn_cast(user); - if (extractSliceOp) + if (extractSliceOp) { break; + } } if (!extractSliceOp) { /*return mlir::emitSilenceableFailure( @@ -1013,10 +1037,12 @@ transform_dialect::ShareForallOperandsOp::applyToOne( // (i.e., same source/target, offsets, sizes and strides). 
auto isMatchingParallelInsertSlice = [&](Operation &op) { auto insertSlice = dyn_cast(&op); - if (!insertSlice) + if (!insertSlice) { return false; - if (insertSlice.getDest() != bbArg) + } + if (insertSlice.getDest() != bbArg) { return false; + } return llvm::equal(insertSlice.getMixedOffsets(), extractSliceOp.getMixedOffsets()) && llvm::equal(insertSlice.getMixedSizes(), @@ -1115,8 +1141,9 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, FailureOr fuseConsumerResults = scf::tileAndFuseConsumerOfSlices(rewriter, target, loops); - if (failed(fuseConsumerResults)) + if (failed(fuseConsumerResults)) { return failure(); + } // Report back the relevant handles to the transform op. originalConsumerOps.push_back( diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp index fd4b75a4e73c..945ffdffb3a3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp @@ -229,8 +229,9 @@ swapExpandShapeWithSlice(RewriterBase &rewriter, auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isZeroInteger(offset)) + if (!isZeroInteger(offset)) { return false; + } FailureOr maybeEqual = ValueBoundsConstraintSet::areEqual(sliceSize, size); return llvm::succeeded(maybeEqual) && maybeEqual.value(); @@ -275,8 +276,9 @@ swapExpandShapeWithSlice(RewriterBase &rewriter, // Offset = cumulative product of leading unit extracted dims. 
for (; i < e; ++i) { int64_t expandedDim = indices[i]; - if (!isOneInteger(sizes[expandedDim])) + if (!isOneInteger(sizes[expandedDim])) { break; + } basis.push_back(outputShape[expandedDim]); delinOffsets.push_back(offsets[expandedDim]); @@ -719,8 +721,9 @@ swapCollapseShapeWithSlice(RewriterBase &rewriter, for (; idx < reassocGroupSize; ++idx) { int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - if (currentCollapsedsize < expandedShapeSize) + if (currentCollapsedsize < expandedShapeSize) { break; + } // We need to make sure that the slice size can be set to the shape size // and the offset to 0. diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp index 7ea7a24be18c..bbf0572abaef 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp @@ -62,8 +62,9 @@ namespace mlir::iree_compiler { static Value convertElementType(OpBuilder &b, Location loc, Type targetType, Value source) { Type sourceType = source.getType(); - if (sourceType == targetType) + if (sourceType == targetType) { return source; + } if (isa(sourceType) && isa(targetType)) { unsigned sourceBitWidth = sourceType.getIntOrFloatBitWidth(); unsigned destBitWidth = targetType.getIntOrFloatBitWidth(); @@ -82,8 +83,9 @@ static std::optional getLegalizedType(Type t) { if (auto shapedType = dyn_cast(t)) { std::optional legalizedElementType = legalizeStorageElementType(shapedType); - if (!legalizedElementType) + if (!legalizedElementType) { return std::nullopt; + } return RankedTensorType::get(shapedType.getShape(), legalizedElementType.value(), shapedType.getEncoding()); @@ -117,8 +119,9 @@ struct TypePropagationTypeConverter : public TypeConverter { TypePropagationTypeConverter() { addConversion([](Type t) { auto convertedType = getLegalizedType(t); - if (!convertedType) + if (!convertedType) { return t; + } 
return convertedType.value(); }); } diff --git a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp index 2ce80a85b4f7..5e0dca7ccbb8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/UserConfig.cpp @@ -19,8 +19,9 @@ setUserConfig(mlir::FunctionOpInterface entryPointFn, Operation *computeOp, } auto info = compilationInfo.getTranslationInfo(); - if (failed(setTranslationInfo(entryPointFn, info))) + if (failed(setTranslationInfo(entryPointFn, info))) { return failure(); + } setLoweringConfig(computeOp, compilationInfo.getLoweringConfig()); eraseCompilationInfo(computeOp); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp index f5b99f7ce9fc..f216e054b92a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Transforms/ReshapeFusion.cpp @@ -28,8 +28,9 @@ canExpandInnerTiledOp(InnerTiledOp op, OpOperand *fusedOperand, ArrayRef reassociation) { // Only single result inner_tiled ops are tested or used anywhere, so restrict // to single result for now. - if (op->getNumResults() != 1) + if (op->getNumResults() != 1) { return failure(); + } // Only outer dims can be expanded because inner dims depend on the `kind` // attribute's implementation. @@ -82,9 +83,10 @@ static InnerTiledOp expandInnerTiledOp( // dims. Get iteration domain to query sizes of dims not in the fused operand. 
SmallVector iterationDomain = op.getIterationDomain(rewriter); for (int64_t i = 0; i < numIterDims; ++i) { - if (iterDimExpansion[i].empty()) + if (iterDimExpansion[i].empty()) { iterDimExpansion[i].push_back( {expandedDimCounter++, iterationDomain[i].size}); + } } SmallVector newIndexingMaps; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp index 0ac641b33520..bb17a14f64d8 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp @@ -69,13 +69,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, static llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const ScalableTileFlags &scalableTileFlags) { - if (scalableTileFlags.empty()) + if (scalableTileFlags.empty()) { return os; + } os << "scalableTiles = ["; for (unsigned i = 0; i < scalableTileFlags.size(); ++i) { os << (scalableTileFlags[i] ? 
"true" : "false"); - if (i + 1 < scalableTileFlags.size()) + if (i + 1 < scalableTileFlags.size()) { os << ", "; + } } return os; } @@ -279,8 +281,9 @@ deserializeEncodingInfo(DictionaryAttr attr) { } if (attr.contains("scalableTiles")) { auto value = attr.getNamed("scalableTiles"); - if (!value || !isa(value->getValue())) + if (!value || !isa(value->getValue())) { return std::nullopt; + } ScalableTileFlags res = llvm::map_to_vector( cast(value->getValue()), [](Attribute a) { return cast(a).getValue(); }); diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp index ae2d87c67fa6..b6198d6edbae 100644 --- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp @@ -149,15 +149,17 @@ struct DispatchTensorStoreOpInterface auto maybeBuffer = getBuffer(rewriter, storeOp->getOpOperand(0).get(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } Value srcMemref = *maybeBuffer; // If everything bufferized inplace, no copy is needed. We wrote to the // target buffer already. The copy folds away in that case. if (failed(options.createMemCpy(rewriter, storeOp->getLoc(), srcMemref, - target))) + target))) { return failure(); + } rewriter.eraseOp(storeOp); return success(); @@ -176,8 +178,9 @@ struct LoadFromBufferOpInterface getSourceSubspanMemref( cast>(loadFromBufferOp.getBuffer())); // Conservatively return false if the subspan is not found. 
- if (!subspanOp) + if (!subspanOp) { return false; + } std::optional descriptorFlags = subspanOp->getDescriptorFlags(); return !descriptorFlags.has_value() || @@ -219,15 +222,17 @@ struct StoreToBufferOpInterface auto storeOp = cast(op); FailureOr maybeBuffer = getBuffer(rewriter, storeOp.getTensor(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } Value srcMemref = *maybeBuffer; // If everything bufferized inplace, no copy is needed. We wrote to the // target buffer already. The copy folds away in that case. if (failed(options.createMemCpy(rewriter, storeOp.getLoc(), srcMemref, - storeOp.getBuffer()))) + storeOp.getBuffer()))) { return failure(); + } rewriter.eraseOp(storeOp); return success(); @@ -285,13 +290,15 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, rewriter.setInsertionPoint(op); // Nothing to do. This op is already bufferized. - if (dspOp.hasPureBufferSemantics()) + if (dspOp.hasPureBufferSemantics()) { return success(); + } // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need // basis. - if (!dspOp.hasPureTensorSemantics()) + if (!dspOp.hasPureTensorSemantics()) { return op->emitError() << "op does not have tensor semantics"; + } // New input operands for the cloned op. SmallVector newOperands, newOutputBuffers; @@ -305,8 +312,9 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, } if (!dspOp.isDpsInit(&opOperand)) { auto maybeBuffer = getBuffer(rewriter, opOperand.get(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } // Input operands are never written to. 
newOperands.push_back(*maybeBuffer); continue; @@ -319,8 +327,9 @@ static LogicalResult bufferizeLinalgExtOp(RewriterBase &rewriter, FailureOr resultBuffer = getBuffer( rewriter, aliasingOpOperands.getAliases().front().opOperand->get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); + } newOperands.push_back(*resultBuffer); newOutputBuffers.push_back(*resultBuffer); } @@ -385,8 +394,9 @@ getSourceAndDestFromPackUnPackOp(RewriterBase &rewriter, OpTy op, static_assert(llvm::is_one_of::value); Value source; auto maybeBuffer = getBuffer(rewriter, op.getSource(), options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } source = *maybeBuffer; Value dest; @@ -397,8 +407,9 @@ getSourceAndDestFromPackUnPackOp(RewriterBase &rewriter, OpTy op, FailureOr resultBuffer = getBuffer( rewriter, aliasingOpOperands.getAliases().front().opOperand->get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); + } dest = *resultBuffer; return std::make_pair(source, dest); } @@ -412,8 +423,9 @@ static LogicalResult bufferizePackOp(RewriterBase &rewriter, linalg::PackOp op, auto maybeSrcAndDest = getSourceAndDestFromPackUnPackOp(rewriter, op, options, state); - if (failed(maybeSrcAndDest)) + if (failed(maybeSrcAndDest)) { return failure(); + } auto [source, dest] = *maybeSrcAndDest; // Set insertion point now that potential alloc/dealloc are introduced. @@ -438,8 +450,9 @@ static LogicalResult bufferizeUnPackOp(RewriterBase &rewriter, auto maybeSrcAndDest = getSourceAndDestFromPackUnPackOp(rewriter, op, options, state); - if (failed(maybeSrcAndDest)) + if (failed(maybeSrcAndDest)) { return failure(); + } auto [source, dest] = *maybeSrcAndDest; // Set insertion point now that potential alloc/dealloc are introduced. @@ -482,8 +495,9 @@ struct PackUnPackOpInterface auto dspOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. 
- if (dspOp.isDpsInit(&opOperand)) + if (dspOp.isDpsInit(&opOperand)) { return {dspOp.getTiedOpResult(&opOperand)}; + } return {}; } @@ -493,10 +507,11 @@ struct PackUnPackOpInterface auto dspOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (dspOp.isDpsInit(&opOperand)) + if (dspOp.isDpsInit(&opOperand)) { return {AliasingValue(dspOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent, /*isDefinite=*/false)}; + } return {}; } @@ -531,8 +546,9 @@ struct DispatchTensorLoadOpSubsetInterface // DispatchTensorStoreOp result that bufferizes inplace. auto loadOp = cast(op); auto storeOp = dyn_cast(op); - if (!storeOp) + if (!storeOp) { return false; + } return equivalenceFn(loadOp.getSource(), storeOp.getTarget()); } @@ -556,8 +572,9 @@ struct DispatchTensorStoreOpSubsetInterface // DispatchTensorLoadOp result that bufferizes inplace. auto storeOp = cast(op); auto loadOp = dyn_cast(op); - if (!loadOp) + if (!loadOp) { return false; + } return equivalenceFn(loadOp.getSource(), storeOp.getTarget()); } diff --git a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp index 4817e365192a..8e2227590d04 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/AffineMinDistributedSCFCanonicalization.cpp @@ -40,8 +40,9 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { // Check if any of the dimensions is a ForOp or ParallelOp induction variable. 
for (auto dim : minOp.getDimOperands()) { auto ivArg = dyn_cast(dim); - if (!ivArg) + if (!ivArg) { continue; + } Operation *containingOp = ivArg.getOwner()->getParentOp(); auto forOp = dyn_cast_if_present(containingOp); if (forOp && forOp.getInductionVar() == dim) { @@ -52,8 +53,9 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { break; } auto parallelOp = dyn_cast_if_present(containingOp); - if (!parallelOp) + if (!parallelOp) { continue; + } for (auto [index, inductionVar] : llvm::enumerate(parallelOp.getInductionVars())) { if (inductionVar == dim) { @@ -64,11 +66,13 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { break; } } - if (iv) + if (iv) { break; + } } - if (!iv) + if (!iv) { return false; + } // Calculate the affine map representing `%ub - %iv`. AffineExpr ivDim; AffineExpr ubDim; @@ -94,11 +98,13 @@ static bool affineMinOpDivisible(affine::AffineMinOp minOp, int64_t dividend) { // `dividend` or equal to `%ub - %iv`. for (AffineExpr result : minOp.getAffineMap().getResults()) { if (auto cst = dyn_cast(result)) { - if (cst.getValue() <= 0 || cst.getValue() % dividend != 0) + if (cst.getValue() <= 0 || cst.getValue() % dividend != 0) { return false; + } } else { - if (diffExp != result) + if (diffExp != result) { return false; + } } } // Now check that for every value of the induction variable `%ub - %iv` is @@ -121,13 +127,15 @@ static bool isDivisible(Value v, int64_t dividend) { affine::canonicalizeMapAndOperands(&modMap, &ops); modMap = simplifyAffineMap(modMap); auto cst = dyn_cast(modMap.getResult(0)); - if (cst) + if (cst) { return (cst.getValue() == 0); + } // If the map doesn't fold to 0 but simplifies to (d0 %n) with d0 an // affine.min, check if all the results of the affine.min's map are divisible // by `dividend`. 
- if (modMap.getResult(0) != mod) + if (modMap.getResult(0) != mod) { return false; + } assert(ops.size() == 1); auto minOp = ops[0].getDefiningOp(); return (minOp && affineMinOpDivisible(minOp, dividend)); @@ -149,12 +157,14 @@ static std::optional foldAffineMin(affine::AffineMinOp minOp) { constantResult = cst.getValue(); } } - if (constantResult == 0) + if (constantResult == 0) { return {}; + } // If afine.min map's results are all positive and divisible by // `constantResult` then it can be replaced by `constantResult`. - if (affineMinOpDivisible(minOp, constantResult)) + if (affineMinOpDivisible(minOp, constantResult)) { return constantResult; + } return {}; } @@ -167,8 +177,9 @@ struct AffineMinDistributedSCFCanonicalizationPattern matchAndRewrite(mlir::affine::AffineMinOp minOp, mlir::PatternRewriter &rewriter) const override { std::optional cst = foldAffineMin(minOp); - if (!cst) + if (!cst) { return failure(); + } rewriter.replaceOpWithNewOp(minOp, rewriter.getIndexAttr(*cst)); return success(); diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index d1a94c33567f..8165a7743a74 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -45,11 +45,13 @@ namespace mlir::iree_compiler { static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands, Operation *baseOp) { for (auto val : nonIndexComputationOperands) { - if (op == val.getDefiningOp()) + if (op == val.getDefiningOp()) { return false; + } } - if (op->isProperAncestor(baseOp)) + if (op->isProperAncestor(baseOp)) { return false; + } return !isa(op); } @@ -154,16 +156,18 @@ std::optional hoistOneStaticallyBoundAllocation( vector::ScalableValueBoundsConstraintSet::computeScalableBound( value, std::nullopt, vscaleRange->vscaleMin, vscaleRange->vscaleMax, presburger::BoundType::UB); - if (failed(ub)) + if (failed(ub)) { 
return failure(); + } if (ub->map.isSingleConstant()) { auto constantBound = ub->map.getSingleConstantResult(); return OpFoldResult(builder.getIndexAttr(constantBound)); } - if (!vscale) + if (!vscale) { vscale = vector::VectorScaleOp::create(builder, loc); + } return affine::materializeComputedBound( builder, loc, ub->map, {std::make_pair(vscale, std::nullopt)}); } @@ -172,8 +176,9 @@ std::optional hoistOneStaticallyBoundAllocation( presburger::BoundType::UB, {value, std::nullopt}, /*stopCondition=*/nullptr, /*closedUB=*/true); - if (failed(ub)) + if (failed(ub)) { return failure(); + } return OpFoldResult(builder.getIndexAttr(*ub)); }; @@ -202,8 +207,9 @@ std::optional hoistOneStaticallyBoundAllocation( Value dynamicSize = dynamicSizes[index++]; auto ub = computeAllocationBound(dynamicSize); - if (failed(ub)) + if (failed(ub)) { return std::nullopt; + } allocSizes.push_back(*ub); subviewSizes.push_back(dynamicSize); @@ -270,8 +276,9 @@ void hoistStaticallyBoundAllocationsInFunc( // Collect all allocLikes that are hoistable. 
funcOp.walk([&](AllocLikeOpType allocLikeOp) { - if (allocLikeOp->getBlock() == &funcOp.getFunctionBody().front()) + if (allocLikeOp->getBlock() == &funcOp.getFunctionBody().front()) { return; + } if (allocLikeOp.getDynamicSizes().empty()) { allocLikeOps.push_back(allocLikeOp); return; @@ -290,8 +297,9 @@ void hoistStaticallyBoundAllocationsInFunc( SmallVector deallocOps; for (Operation *user : allocLikeOp->getUsers()) { auto dealloc = dyn_cast(user); - if (dealloc) + if (dealloc) { deallocOps.push_back(dealloc); + } } LLVM_DEBUG({ @@ -303,8 +311,9 @@ void hoistStaticallyBoundAllocationsInFunc( }); std::optional replacement = hoistOneStaticallyBoundAllocation( funcOp, rewriter, allocLikeOp, vscaleRange); - if (!replacement) + if (!replacement) { continue; + } LLVM_DEBUG({ llvm::dbgs() << "Replacement : "; replacement->dump(); @@ -312,8 +321,9 @@ void hoistStaticallyBoundAllocationsInFunc( Value replacementVal = replacement.value(); rewriter.replaceOp(allocLikeOp, replacementVal); - for (memref::DeallocOp deallocOp : deallocOps) + for (memref::DeallocOp deallocOp : deallocOps) { rewriter.eraseOp(deallocOp); + } } } @@ -751,10 +761,12 @@ void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) { // like scf.for, since the value bounds interface requires index types. auto maybeLb = getConstantIntValue(lb); auto maybeUb = getConstantIntValue(ub); - if (!maybeLb || !maybeUb) + if (!maybeLb || !maybeUb) { return; - if (*maybeLb >= *maybeUb) + } + if (*maybeLb >= *maybeUb) { return; + } } } @@ -812,8 +824,9 @@ void analyseAllocsForPacking(mlir::FunctionOpInterface funcOp, // Skip the whole analysis if any user is a subview. // TODO: This could be extended if needed by recursively merging // liveness. 
- if (isa(user)) + if (isa(user)) { return; + } if (group.liveness.count(user)) { aliasGroups.push_back(i); break; @@ -851,14 +864,16 @@ void analyseAllocsForPacking(mlir::FunctionOpInterface funcOp, LLVM_DEBUG({ for (size_t i = 0; i < groups.size(); i++) { llvm::dbgs() << "Alias group " << i << ":\n"; - for (Operation *op : groups[i].allocs) + for (Operation *op : groups[i].allocs) { op->dump(); + } } }); for (size_t i = 0; i < groups.size(); i++) { - if (groups[i].allocs.empty()) + if (groups[i].allocs.empty()) { continue; + } aliasGroups.push_back(std::move(groups[i].allocs)); } } @@ -873,8 +888,9 @@ static int64_t getAllocSize(Operation *op, DataLayout &dataLayout) { void packAllocs(OpBuilder &builder, mlir::FunctionOpInterface funcOp, ArrayRef aliasGroups) { - if (aliasGroups.empty()) + if (aliasGroups.empty()) { return; + } DataLayout dataLayout = DataLayout::closest(funcOp); builder.setInsertionPointToStart(&(*funcOp.getFunctionBody().begin())); int64_t maxAlloc = 0; @@ -1061,8 +1077,9 @@ struct HoistForallFromFor : public OpRewritePattern { BlockArgument destBbArg = cast(parallelInsert.getDest()); tensor::ExtractSliceOp destSlice; for (auto user : destBbArg.getUsers()) { - if (user == parallelInsert) + if (user == parallelInsert) { continue; + } auto maybeSlice = dyn_cast(user); if (!maybeSlice) { // Fail if the destination has more users than a direct insert and @@ -1099,8 +1116,9 @@ struct HoistForallFromFor : public OpRewritePattern { for (auto [dim, size] : llvm::enumerate(insert.getMixedSizes())) { FailureOr equalDimSize = ValueBoundsConstraintSet::areEqual( {size}, {insert.getDest(), static_cast(dim)}); - if (failed(equalDimSize) || !*equalDimSize) + if (failed(equalDimSize) || !*equalDimSize) { return false; + } } return true; }; diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h index ec139f1ae402..9435589f61a7 100644 --- 
a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h @@ -226,10 +226,12 @@ struct LinalgBasePromotionPattern : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) + if (failed(filter.checkAndNotify(rewriter, op))) { return failure(); - if (failed(promoteSubviewsPrecondition(op, options))) + } + if (failed(promoteSubviewsPrecondition(op, options))) { return failure(); + } // TODO: We cannot use root update here. This // pattern is creating other ops, so if the diff --git a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp index d15aed405b31..c1f6e8fd31bb 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp @@ -32,8 +32,9 @@ FailureOr getRootOperation(ArrayRef computeOps) { if (auto linalgOp = dyn_cast(op)) { // Do not treat linalg ops that are all parallel as root operations in // this sweep. - if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) + if (linalgOp.getNumLoops() == linalgOp.getNumParallelLoops()) { continue; + } // All other linalg ops are root ops. 
rootOperation = op; diff --git a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp index b2983b5bfe4a..02a03a6d2295 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/EncodingUtils.cpp @@ -57,8 +57,9 @@ FailureOr> getInnerTileSizesOfrImpl( if (ShapedType::isStaticShape(staticTileSizes)) { if (!materializeEncodingInfo.scalableTiles || llvm::none_of(materializeEncodingInfo.scalableTiles.value(), - [](bool scalable) { return scalable; })) + [](bool scalable) { return scalable; })) { return getAsOpFoldResult(rewriter.getI64ArrayAttr(staticTileSizes)); + } // In this case, we have scalable tiles present and we have to generate the // necessary vscale operation and the corresponding static_size * vscale // values. diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 4c711767a532..b444cdf50d0c 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -122,27 +122,32 @@ bool canPerformVectorAccessUsingAllThreads(ArrayRef shape, // Verify that each dimension of the shape can be distributed on the // threads // For zero dim tensor, consider it's too small to access using all threads. - if (shape.size() == 0) + if (shape.size() == 0) { return false; + } int64_t threadsAvailable = threadCount; for (const auto &[index, dim] : llvm::enumerate(llvm::reverse(shape))) { int64_t numElementPerThread = index == 0 ? vectorSize : 1; int64_t numThreads = dim / numElementPerThread; - if (numThreads == 0) + if (numThreads == 0) { return false; + } if (numThreads > threadsAvailable) { // If there are no enough remaining threads to distribute the current // dimension, try to use all remaining threads. But we still need to make // sure all work can be distributed to these threads evenly. 
- if (numThreads % threadsAvailable != 0) + if (numThreads % threadsAvailable != 0) { return false; + } numThreads = threadsAvailable; } - if (threadsAvailable % numThreads != 0) + if (threadsAvailable % numThreads != 0) { return false; + } threadsAvailable = threadsAvailable / numThreads; - if (threadsAvailable == 1) + if (threadsAvailable == 1) { break; + } } return threadsAvailable == 1; } @@ -200,8 +205,9 @@ FailureOr getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr> tileSizes = getGPUTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } scf::SCFTileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) -> SmallVector { @@ -230,16 +236,18 @@ std::optional allocateWorkgroupMemory(OpBuilder &builder, mlir::FunctionOpInterface funcOp = subview->getParentOfType(); - if (!funcOp) + if (!funcOp) { return std::nullopt; + } // The subview size bounds are expected to be constant; they specify the shape // of the allocation. SmallVector shape; for (Value bound : sizeBounds) { APInt value; - if (!matchPattern(bound, m_ConstantInt(&value))) + if (!matchPattern(bound, m_ConstantInt(&value))) { return std::nullopt; + } shape.push_back(value.getSExtValue()); } @@ -272,10 +280,12 @@ static bool propagateCopyDestIntoProducerFill(memref::CopyOp copyOp) { } auto fillOp = dyn_cast(prevOp); - if (!fillOp) + if (!fillOp) { break; - if (fillOp.output() != copyOp.getSource()) + } + if (fillOp.output() != copyOp.getSource()) { break; + } // Move the fillOp and change the destination to the copy destination. 
fillOp->moveBefore(copyOp); fillOp.getOutputsMutable().assign(copyOp.getTarget()); @@ -327,10 +337,12 @@ propagateCopySourceIntoConsumerGeneric(memref::CopyOp copyOp, auto consumer = dyn_cast(nextOp); if (!consumer || consumer.getNumDpsInits() != 1 || !consumer.getMatchingIndexingMap(consumer.getDpsInitOperand(0)) - .isIdentity()) + .isIdentity()) { break; - if (*consumer.getOutputs().begin() != copyOp.getTarget()) + } + if (*consumer.getOutputs().begin() != copyOp.getTarget()) { break; + } insertInputValueIntoGeneric(copyOp.getSource(), consumer); toDelete.push_back(consumer); return true; @@ -346,12 +358,14 @@ void propagateSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { funcOp.walk([&toDelete](memref::CopyOp copyOp) { if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { if (propagateCopyDestIntoProducerFill(copyOp) || - propagateCopySourceIntoConsumerGeneric(copyOp, toDelete)) + propagateCopySourceIntoConsumerGeneric(copyOp, toDelete)) { toDelete.push_back(copyOp.getOperation()); + } } }); - for (Operation *op : toDelete) + for (Operation *op : toDelete) { op->erase(); + } } void insertBarriersAroundSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { @@ -461,16 +475,18 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input, // integer type. 
auto unpack = [loc, &builder, needsPacking, equivIntType, origInputType](Value packedVal) -> Value { - if (!needsPacking) + if (!needsPacking) { return packedVal; + } auto asInt = arith::TruncIOp::create(builder, loc, equivIntType, packedVal); return arith::BitcastOp::create(builder, loc, origInputType, asInt); }; auto pack = [loc, &builder, needsPacking, equivIntType, shuffleIntType](Value unpackedVal) -> Value { - if (!needsPacking) + if (!needsPacking) { return unpackedVal; + } auto asInt = arith::BitcastOp::create(builder, loc, equivIntType, unpackedVal); return arith::ExtUIOp::create(builder, loc, shuffleIntType, asInt); @@ -667,8 +683,9 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { return nativeSize; } if (auto writeOp = dyn_cast(op)) { - if (writeOp.getVectorType().getRank() < 2) + if (writeOp.getVectorType().getRank() < 2) { return std::nullopt; + } SmallVector nativeSize(writeOp.getVectorType().getRank() - 2, 1); nativeSize.append({m, n}); return nativeSize; @@ -679,11 +696,13 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } return llvm::to_vector(sliceType.getShape()); @@ -692,8 +711,9 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { if (auto vecType = dyn_cast(op->getResultTypes()[0])) { // TODO: The condition for unrolling elementwise should be restricted // only to operations that need unrolling (connected to the contract). - if (vecType.getRank() < 2) + if (vecType.getRank() < 2) { return std::nullopt; + } // First check whether there is a slice to infer the shape from. 
This is // required for cases where the accumulator type differs from the input @@ -702,15 +722,18 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } - if (sliceType) + if (sliceType) { return llvm::to_vector(sliceType.getShape()); + } // Else unroll for trailing elementwise. SmallVector nativeSize(vecType.getRank() - 2, 1); @@ -729,12 +752,15 @@ std::optional> getWmmaNativeVectorSize(Operation *op) { static std::optional getVectorContractOpOperandId(vector::ContractionOp contractOp, OpResult result) { - if (contractOp.getLhs() == result) + if (contractOp.getLhs() == result) { return 0; - if (contractOp.getRhs() == result) + } + if (contractOp.getRhs() == result) { return 1; - if (contractOp.getAcc() == result) + } + if (contractOp.getAcc() == result) { return 2; + } return std::nullopt; } @@ -747,24 +773,30 @@ getVectorContractOpOperandIdForVectorReadOp(Operation *op) { // Check if the vector::TransferReadOp is consumed directly by // vector::ContractionOp. - if (op->use_empty()) + if (op->use_empty()) { return std::nullopt; + } Operation *firstLevelUser = *((op->getUsers()).begin()); - if (!firstLevelUser) + if (!firstLevelUser) { return std::nullopt; - if (auto contractOp = dyn_cast(firstLevelUser)) + } + if (auto contractOp = dyn_cast(firstLevelUser)) { return getVectorContractOpOperandId(contractOp, op->getResult(0)); + } // Check if the vector::TransferReadOp is consumed indirectly by // vector::ContractionOp. Only check until the second level of use-def chain. 
- if (firstLevelUser->use_empty()) + if (firstLevelUser->use_empty()) { return std::nullopt; + } Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin()); - if (!secondLevelUser) + if (!secondLevelUser) { return std::nullopt; - if (auto contractOp = dyn_cast(secondLevelUser)) + } + if (auto contractOp = dyn_cast(secondLevelUser)) { return getVectorContractOpOperandId(contractOp, firstLevelUser->getResult(0)); + } return std::nullopt; } @@ -780,15 +812,15 @@ std::optional> getMmaNativeVectorSize(Operation *op) { Type sourceType = contract.getLhsType().getElementType(); // Set mmaShapeK based on sourceType. - if (sourceType.isInteger(4)) + if (sourceType.isInteger(4)) { mmaShapeK = 64; - else if (sourceType.isInteger(8)) + } else if (sourceType.isInteger(8)) { mmaShapeK = 32; - else if (sourceType.isF16() || sourceType.isBF16()) + } else if (sourceType.isF16() || sourceType.isBF16()) { mmaShapeK = 16; - else if (sourceType.isF32()) + } else if (sourceType.isF32()) { mmaShapeK = 8; - else { + } else { LDBG() << "unsupported shape for vector.contract: "; return std::nullopt; } @@ -803,8 +835,9 @@ std::optional> getMmaNativeVectorSize(Operation *op) { // Shape of warp-level vector write operation. 
if (auto writeOp = dyn_cast(op)) { - if (writeOp.getVectorType().getRank() < 2) + if (writeOp.getVectorType().getRank() < 2) { return std::nullopt; + } SmallVector outputShape(writeOp.getVectorType().getRank() - 2, 1); outputShape.append({mmaShapeM, mmaShapeN}); LDBG() << "shape for vector.xfer_write: " << llvm::interleaved(outputShape); @@ -892,11 +925,13 @@ std::optional> getMmaNativeVectorSize(Operation *op) { VectorType sliceType; for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } LDBG() << "shape for vector.xfer_read: " @@ -911,19 +946,24 @@ std::optional> getMmaNativeVectorSize(Operation *op) { bool hasGlobalMemoryAddressSpace(MemRefType memrefType) { Attribute addrSpace = memrefType.getMemorySpace(); - if (!addrSpace) + if (!addrSpace) { return true; + } auto intAttr = dyn_cast(addrSpace); // Accept both default numeric address space and HAL descriptor type address // space--the former is used by LLVMGPU while the latter is used by SPIR-V. 
- if (intAttr && intAttr.getInt() == 0) + if (intAttr && intAttr.getInt() == 0) { return true; + } auto gpuAttr = dyn_cast(addrSpace); - if (gpuAttr && gpuAttr.getValue() == gpu::AddressSpace::Global) + if (gpuAttr && gpuAttr.getValue() == gpu::AddressSpace::Global) { return true; + } auto amdgpuAttr = dyn_cast(addrSpace); - if (amdgpuAttr && amdgpuAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer) + if (amdgpuAttr && + amdgpuAttr.getValue() == amdgpu::AddressSpace::FatRawBuffer) { return true; + } return isa(addrSpace); } @@ -970,8 +1010,9 @@ bool sharedMemTransposeFilter(AffineMap indexMap) { //===----------------------------------------------------------------------===// IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) { - if (clTestTarget.empty()) + if (clTestTarget.empty()) { return nullptr; + } auto [archAndFeatures, backend] = StringRef(clTestTarget).split("@"); if (backend.empty()) { @@ -979,16 +1020,17 @@ IREE::GPU::TargetAttr getCLGPUTarget(MLIRContext *context) { // for cases like "ampere" which can be accepted by both CUDA and Vulkan; // it's very limited. So it's targeting common cases to make writing tests // simpler. - if (StringRef(clTestTarget).starts_with("sm_")) + if (StringRef(clTestTarget).starts_with("sm_")) { backend = "cuda"; - else if (StringRef(clTestTarget).starts_with("gfx")) + } else if (StringRef(clTestTarget).starts_with("gfx")) { backend = "hip"; - else if (StringRef(clTestTarget).starts_with("adreno")) + } else if (StringRef(clTestTarget).starts_with("adreno")) { backend = "vulkan"; - else if (StringRef(clTestTarget).starts_with("apple")) + } else if (StringRef(clTestTarget).starts_with("apple")) { backend = "vulkan"; - else if (StringRef(clTestTarget).starts_with("valhall")) + } else if (StringRef(clTestTarget).starts_with("valhall")) { backend = "vulkan"; + } } auto [arch, features] = StringRef(archAndFeatures).split(':'); // Use the target specified in the command line for testing purposes. 
@@ -1041,11 +1083,13 @@ void addConfigWavesPerEu(MLIRContext *context, int64_t wavesPerEu, std::optional getGPUSubgroupSize(mlir::FunctionOpInterface func) { // First try to see if there is a subgroup size chosen in the CodeGen pipeline // configuration. - if (std::optional subgroupSize = getSubgroupSize(func)) + if (std::optional subgroupSize = getSubgroupSize(func)) { return subgroupSize.value(); + } // Then try to find the subgroup size from the target description. - if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func)) + if (IREE::GPU::TargetAttr target = getGPUTargetAttr(func)) { return target.getPreferredSubgroupSize(); + } return std::nullopt; } diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp index d07d8e28f060..70ed11c1b093 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinalgOpInfo.cpp @@ -127,8 +127,9 @@ void LinalgOpInfo::computeInfo(LinalgOp linalgOp) { bool isMatmulOrBatchMatmul(linalg::LinalgOp linalgOp) { // (Batch) matmul should be a reduction op with 2/3 parallel dimensions. if (!linalg::isaContractionOpInterface(linalgOp) || - !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) + !llvm::is_contained({2u, 3u}, linalgOp.getNumParallelLoops())) { return false; + } // Also exclude the case of matvec, which has only one non-unit parallel dim. // They should go down different pipelines. 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index 40a174f5a877..f7563963b036 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -93,8 +93,9 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, llvm::map_to_vector<8>(sourceBlock, [&](Operation &op) { return &op; }); for (auto &sourceOp : allOps) { - if (sourceOp->hasTrait()) + if (sourceOp->hasTrait()) { continue; + } if (auto symbolOp = dyn_cast(sourceOp)) { auto symbolName = symbolOp.getName(); @@ -172,13 +173,15 @@ replaceEntryPointUses(mlir::ModuleOp moduleOp, auto replaceSymbolRefs = [](Operation *rootOp, const DenseMap &map) { auto allUses = SymbolTable::getSymbolUses(rootOp); - if (!allUses) + if (!allUses) { return; + } for (auto use : *allUses) { auto oldAttr = use.getSymbolRef(); auto newAttr = map.lookup(oldAttr); - if (!newAttr) + if (!newAttr) { continue; + } auto newDict = use.getUser()->getAttrDictionary().replace( [&](Attribute attr) -> std::pair { if (attr == oldAttr) { @@ -267,8 +270,9 @@ LogicalResult linkExecutablesInto( // Merge sources into the linked source listing. if (auto sourcesAttr = variantOp.getSourcesAttr()) { - for (auto sourceAttr : sourcesAttr.getValue()) + for (auto sourceAttr : sourcesAttr.getValue()) { linkedSourceAttrs.set(sourceAttr.getName(), sourceAttr.getValue()); + } } // Remap variant refs. 
diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp index f4797747da7e..12d68539d0e7 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.cpp @@ -124,8 +124,9 @@ StringRef getDeleteMarker() { return "delete"; } StringRef getMarkerOrNull(Operation *op) { StringAttr attr = op->getAttrOfType(LinalgTransforms::kLinalgTransformMarker); - if (!attr) + if (!attr) { return ""; + } return attr.getValue(); } diff --git a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h index f2c9a3fa80f6..3d2325d42c38 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/MarkerUtils.h @@ -51,8 +51,9 @@ struct LinalgTransformationFilter { bool hasReplacementFilter(Operation *op) const; LinalgTransformationFilter &addFilter(const FilterFunction &f) { - if (f) + if (f) { filters.push_back(f); + } return *this; } diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 8e7d511128cf..892194dd2513 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -296,8 +296,9 @@ std::array getMaxWorkgroupCount(Operation *op) { bool isReadOnly(Value v) { Operation *definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } return TypeSwitch(definingOp) .Case( [&](arith::ConstantOp constantOp) { return true; }) @@ -536,8 +537,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( for (Operation &op : dummyFuncOp.getBody().front()) { auto currLoweringConfig = getLoweringConfig(&op); - if (!currLoweringConfig) + if (!currLoweringConfig) { continue; + } // Translate the lowering config to the original operation. 
if (std::optional originalOperation = @@ -546,8 +548,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( } auto currWorkgroupTileSizes = currLoweringConfig.getWorkgroupTileSizes(); - if (currWorkgroupTileSizes.empty()) + if (currWorkgroupTileSizes.empty()) { continue; + } workgroupTileSizes = currWorkgroupTileSizes; workgroupInterchange = currLoweringConfig.getWorkgroupInterchange(); } @@ -572,8 +575,9 @@ LogicalResult setDefaultCustomOpLoweringConfig( /// Returns the first of `exprs` which is of the type `T`. template static AffineExpr getAffineExprOfType(ArrayRef exprs) { - if (auto it = llvm::find_if(exprs, llvm::IsaPred); it != exprs.end()) + if (auto it = llvm::find_if(exprs, llvm::IsaPred); it != exprs.end()) { return *it; + } return nullptr; } @@ -611,8 +615,9 @@ static std::optional getDimension(Operation *op) { } template static std::optional getDimension(Operation *op) { - if (!op) + if (!op) { return std::nullopt; + } if (auto dimension = getDimension(op)) { return dimension; } @@ -630,8 +635,9 @@ checkDimensions(ArrayRef vals, std::optional refDimension = std::nullopt) { for (auto v : vals) { auto currDimension = getDimension(v.getDefiningOp()); - if (!currDimension) + if (!currDimension) { return std::nullopt; + } if (refDimension) { if (refDimension.value() != currDimension.value()) { return std::nullopt; @@ -891,8 +897,9 @@ isTiledAndDistributedLoop(scf::ForOp forOp) { countDim = ifx.getDimIndex(); } - if (!idDim || !countDim) + if (!idDim || !countDim) { return std::nullopt; + } Builder b(forOp.getContext()); loopInfo.untiledLowerBound = b.getIndexAttr(0); @@ -1083,8 +1090,9 @@ FailureOr getSoftwarePipelineStoreStage(DictionaryAttr config) { /// Returns a small tiling factor for the given reduction `dimSize`. /// Returns 0 to avoid tiling. int getReductionTilingFactor(int64_t dimSize) { - if (dimSize % 4 == 0) + if (dimSize % 4 == 0) { return 4; + } // Try to find the smallest prime factor as the tiling factor. 
As a trade off // between generated code size and compilation time, only look at prime @@ -1092,8 +1100,9 @@ int getReductionTilingFactor(int64_t dimSize) { static constexpr std::array primeNumbers = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47}; for (int n : primeNumbers) { - if (dimSize % n == 0) + if (dimSize % n == 0) { return n; + } } return 1; // Otherwise just tile with size 1. @@ -1221,16 +1230,19 @@ Value findOrCreateSubspanBuffer( // Look for an existing op. Block *block = subspanOp->getBlock(); for (Operation &op : *block) { - if (&op == subspanOp.getOperation()) + if (&op == subspanOp.getOperation()) { break; + } auto bufferSubspanOp = dyn_cast(&op); - if (!bufferSubspanOp) + if (!bufferSubspanOp) { continue; + } auto bufferMemrefType = dyn_cast(bufferSubspanOp.getResult().getType()); - if (!bufferMemrefType) + if (!bufferMemrefType) { continue; + } if (bufferSubspanOp.getBinding() != subspanOp.getBinding() || bufferSubspanOp.getDescriptorType() != subspanOp.getDescriptorType() || @@ -1238,14 +1250,16 @@ Value findOrCreateSubspanBuffer( !llvm::equal(bufferSubspanOp.getDynamicDims(), subspanOp.getDynamicDims()) || bufferSubspanOp.getAlignment() != subspanOp.getAlignment() || - memRefType != bufferMemrefType) + memRefType != bufferMemrefType) { continue; + } if (useRocdlBuffers && bufferSubspanOp->hasOneUse()) { auto castOp = dyn_cast( *bufferSubspanOp->getUsers().begin()); - if (!castOp) + if (!castOp) { continue; + } return castOp.getResult(); } return bufferSubspanOp.getResult(); @@ -1284,8 +1298,9 @@ Operation *setInsertionPointAfterLastValue(OpBuilder &builder, definingOp = &cast(val).getOwner()->getOperations().front(); } - if (!definingOp) + if (!definingOp) { continue; + } if (lastOp && definingOp == lastOp) { // Combine 'setInsertionPointBefore' by ANDing because we only want to set // the insertion point before the last op if all values this operation is @@ -1293,8 +1308,9 @@ Operation *setInsertionPointAfterLastValue(OpBuilder 
&builder, setInsertionPointBefore &= isa(val); continue; } - if (lastOp && domInfo.dominates(definingOp, lastOp)) + if (lastOp && domInfo.dominates(definingOp, lastOp)) { continue; + } lastOp = definingOp; // For block arguments we want the insertion point to be at the start of @@ -1591,12 +1607,14 @@ void sinkOpsInCFG(const SmallVector &allocs, SmallVector getStaticNumWorkgroups(mlir::FunctionOpInterface funcOp) { SmallVector result; std::optional exportOp = getEntryPoint(funcOp); - if (!exportOp) + if (!exportOp) { return result; + } Block *body = exportOp->getWorkgroupCountBody(); - if (!body) + if (!body) { return result; + } auto returnOp = cast(body->getTerminator()); assert(returnOp.getNumOperands() == 3); @@ -1684,9 +1702,10 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::UB, {shapedValue, dimNum}, /*stopCondition=*/nullptr, /*closedUB=*/true); - if (succeeded(maybeDimBoundSize)) + if (succeeded(maybeDimBoundSize)) { return DimBoundSize{/*baseSize=*/*maybeDimBoundSize, /*scalable=*/false}; + } return failure(); } FailureOr maybeDimBound = @@ -1694,21 +1713,26 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, shapedValue, dimNum, /*vscaleMin=*/vscaleRange->vscaleMin, /*vscaleMax=*/vscaleRange->vscaleMax, presburger::BoundType::UB); - if (failed(maybeDimBound)) + if (failed(maybeDimBound)) { return failure(); + } auto boundSize = maybeDimBound->getSize(); - if (succeeded(boundSize)) + if (succeeded(boundSize)) { return boundSize; - if (roundUp == RoundUpVscaleMultiple::No) + } + if (roundUp == RoundUpVscaleMultiple::No) { return failure(); + } // If the upper bound map is of the form `add(subExpr, cst)` (cst <= 0), // round it up to `subExpr` (and try matching the bound again). 
auto binOp = dyn_cast(maybeDimBound->map.getResult(0)); - if (!binOp || binOp.getKind() != AffineExprKind::Add) + if (!binOp || binOp.getKind() != AffineExprKind::Add) { return failure(); + } auto cst = dyn_cast(binOp.getRHS()); - if (!cst || cst.getValue() > 0) + if (!cst || cst.getValue() > 0) { return failure(); + } DimBound roundedDimBound{AffineMap::get(maybeDimBound->map.getNumDims(), maybeDimBound->map.getNumSymbols(), binOp.getLHS())}; @@ -2052,8 +2076,9 @@ std::optional static inferSizesFromMixedSizes( } std::optional inferSizesFromIR(Value val) { - if (!val.getDefiningOp()) + if (!val.getDefiningOp()) { return std::nullopt; + } std::optional result; LDBG() << "Inferring sizes for: " << val; @@ -2076,20 +2101,23 @@ std::optional inferSizesFromIR(Value val) { } std::optional getConstantIndex(Value value) { - if (!isa(value.getType())) + if (!isa(value.getType())) { return std::nullopt; + } APInt val; - if (!matchPattern(value, m_ConstantInt(&val))) + if (!matchPattern(value, m_ConstantInt(&val))) { return std::nullopt; + } return val.getSExtValue(); } bool alwaysRunsFirstIteration(scf::ForOp op) { // Can't perform the analysis if the loops's bounds aren't index-typed. - if (!op.getInductionVar().getType().isIndex()) + if (!op.getInductionVar().getType().isIndex()) { return false; + } FailureOr isLb = ValueBoundsConstraintSet::compare( getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::LT, getAsOpFoldResult(op.getUpperBound())); @@ -2098,8 +2126,9 @@ bool alwaysRunsFirstIteration(scf::ForOp op) { bool neverRunsSecondIteration(scf::ForOp op) { // Can't perform the analysis if the loops's bounds aren't index-typed. - if (!op.getInductionVar().getType().isIndex()) + if (!op.getInductionVar().getType().isIndex()) { return false; + } // If the upper bound (ub) is less than or equal to the loop step, then // lower bound + step must be greater than the upper bound, assuming the // lower bound is non-negative. 
From dbfa48e756cbaa97c5ec0de1d6cbe4489f4b45b3 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:01:06 -0500 Subject: [PATCH 47/71] Add braces in Codegen backends (GPU/CPU/SPIRV). NFC. 3/n (#23145) --- .../Codegen/Common/CPU/CPULowerToUKernels.cpp | 3 +- .../Codegen/Common/CPU/CPUPrepareUkernels.cpp | 30 ++- .../Common/GPU/AMDGPUDistributeContract.cpp | 3 +- .../Common/GPU/GPUCheckResourceUsage.cpp | 15 +- .../Common/GPU/GPUConvertToCoalescedDMA.cpp | 24 +- .../Common/GPU/GPUCreateFastSlowPath.cpp | 21 +- .../Codegen/Common/GPU/GPUDistribute.cpp | 21 +- .../Common/GPU/GPUDistributeScfFor.cpp | 3 +- .../GPU/GPUDistributeSharedMemoryCopy.cpp | 35 ++- .../Common/GPU/GPUDistributionPatterns.cpp | 9 +- .../GPU/GPUGreedilyDistributeToThreads.cpp | 3 +- .../Codegen/Common/GPU/GPUHeuristics.cpp | 3 +- .../Codegen/Common/GPU/GPUMultiBuffering.cpp | 9 +- .../GPUNestedLayoutDistributionPatterns.cpp | 9 +- .../Common/GPU/GPUPackToIntrinsics.cpp | 3 +- .../Codegen/Common/GPU/GPUPatterns.cpp | 27 ++- .../Codegen/Common/GPU/GPUPipelining.cpp | 87 ++++--- .../Common/GPU/GPUPromoteMatmulOperands.cpp | 6 +- .../Common/GPU/GPUReduceBankConflicts.cpp | 18 +- .../Codegen/Common/GPU/GPUTensorAlloc.cpp | 24 +- .../Codegen/Common/GPU/GPUTensorTile.cpp | 36 ++- .../compiler/Codegen/Common/GPU/GPUTile.cpp | 12 +- .../GPU/GPUTileAndConvertConvToMatmul.cpp | 3 +- .../Codegen/Common/GPU/GPUTileReduction.cpp | 15 +- .../Common/GPU/GPUVectorDistribution.cpp | 3 +- .../Common/GPU/VectorReductionToGPU.cpp | 48 ++-- .../Common/GPU/WorkgroupReordering.cpp | 9 +- .../Dialect/Codegen/IR/IREECodegenAttrs.cpp | 72 ++++-- .../Dialect/Codegen/IR/IREECodegenAttrs.h | 6 +- .../Dialect/Codegen/IR/IREECodegenDialect.cpp | 3 +- .../Dialect/Codegen/IR/IREECodegenOps.cpp | 18 +- .../Codegen/Dialect/Codegen/IR/UKernelOps.cpp | 3 +- .../Dialect/GPU/IR/DerivedConfigUtils.cpp | 3 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 15 +- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 18 +- 
.../Dialect/GPU/TargetUtils/KnownTargets.cpp | 43 ++-- .../GPU/TargetUtils/ReductionConfigUtils.cpp | 9 +- .../Transforms/BufferizationInterfaces.cpp | 15 +- .../DistributeInnerTiledToLanes.cpp | 6 +- .../Dialect/GPU/Transforms/Transforms.cpp | 15 +- .../Codegen/Dialect/PCF/IR/PCFTypes.cpp | 5 +- .../PCF/Transforms/ConvertSRefToMemRef.cpp | 13 +- .../Dialect/VectorExt/IR/VectorExtAttrs.cpp | 6 +- .../Dialect/VectorExt/IR/VectorExtOps.cpp | 62 +++-- .../Transforms/BufferizationInterfaces.cpp | 3 +- .../Codegen/LLVMCPU/ConvertToLLVM.cpp | 12 +- .../compiler/Codegen/LLVMCPU/DispatchABI.cpp | 30 ++- .../Codegen/LLVMCPU/KernelDispatch.cpp | 153 ++++++++---- .../LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp | 12 +- .../LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp | 6 +- .../LLVMCPU/LLVMCPUAssignImportOrdinals.cpp | 6 +- .../LLVMCPUCheckIRBeforeLLVMConversion.cpp | 6 +- .../LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp | 6 +- .../compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp | 6 +- .../LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp | 3 +- .../compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp | 9 +- .../LLVMCPUVectorTransposeLowering.cpp | 3 +- .../iree/compiler/Codegen/LLVMCPU/Utils.cpp | 19 +- .../LLVMCPU/VectorContractCustomKernels.cpp | 8 +- .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 48 ++-- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 89 ++++--- .../LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp | 6 +- .../LLVMGPUCastAddressSpaceFunction.cpp | 3 +- .../LLVMGPU/LLVMGPULowerExecutableTarget.cpp | 3 +- .../LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp | 3 +- .../LLVMGPUTensorCoreVectorization.cpp | 6 +- .../LLVMGPU/LLVMGPUTileAndDistribute.cpp | 6 +- .../Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp | 19 +- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 12 +- .../ROCDLAnnotateKernelForTranslation.cpp | 12 +- .../ROCDLBufferInstructionsOptimization.cpp | 6 +- .../ROCDLConfigureBufferInstructions.cpp | 12 +- .../TransformExtensions/LLVMGPUExtensions.cpp | 178 +++++++++----- .../Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp | 
39 ++- .../Utils/ROCDLPrefetchSharedMemoryCopy.cpp | 39 ++- .../iree/compiler/Codegen/SPIRV/AMDConfig.cpp | 9 +- .../compiler/Codegen/SPIRV/AdrenoConfig.cpp | 12 +- .../compiler/Codegen/SPIRV/AppleConfig.cpp | 6 +- .../Codegen/SPIRV/ConvertToSPIRVPass.cpp | 27 ++- .../compiler/Codegen/SPIRV/KernelConfig.cpp | 224 ++++++++++++------ .../compiler/Codegen/SPIRV/MaliConfig.cpp | 6 +- .../compiler/Codegen/SPIRV/NVIDIAConfig.cpp | 6 +- .../iree/compiler/Codegen/SPIRV/Passes.cpp | 3 +- .../SPIRV/SPIRVAnnotateWinogradLoops.cpp | 6 +- .../SPIRV/SPIRVBreakDownLargeVector.cpp | 15 +- .../Codegen/SPIRV/SPIRVConvertGPUTarget.cpp | 39 ++- .../Codegen/SPIRV/SPIRVEmulateI64.cpp | 30 ++- .../SPIRVEraseStorageBufferStaticShape.cpp | 6 +- .../SPIRV/SPIRVInitialVectorLowering.cpp | 42 ++-- .../Codegen/SPIRV/SPIRVLinkExecutables.cpp | 6 +- .../SPIRV/SPIRVMapMemRefStorageClass.cpp | 6 +- .../SPIRVMaterializeExecutableConditions.cpp | 6 +- .../SPIRV/SPIRVSelectLoweringStrategy.cpp | 3 +- .../Codegen/SPIRV/SPIRVTileAndDistribute.cpp | 9 +- .../Codegen/SPIRV/SPIRVTileAndPromote.cpp | 27 ++- .../SPIRVTileAndVectorizeToCooperativeOps.cpp | 39 ++- .../Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp | 57 +++-- .../src/iree/compiler/Codegen/SPIRV/Utils.cpp | 9 +- .../iree/compiler/Codegen/SPIRV/Verifiers.cpp | 9 +- .../compiler/Codegen/VMVX/KernelDispatch.cpp | 3 +- .../VMVX/VMVXAssignConstantOrdinals.cpp | 9 +- .../VMVX/VMVXLowerExecutableTargetPass.cpp | 3 +- .../VMVX/VMVXLowerLinalgMicrokernels.cpp | 60 +++-- 103 files changed, 1488 insertions(+), 743 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp index a599c5ce6419..971a4d3c7c18 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp @@ -77,8 +77,9 @@ class CPULowerToUKernelsPass /// Returns `true` if an `outsOperand` value is 
initialized to zero. static bool isInitializedToZero(Value outsOperand) { auto fillOp = outsOperand.getDefiningOp(); - if (!fillOp) + if (!fillOp) { return false; + } Value fillVal = fillOp.getDpsInputOperand(0)->get(); return matchPattern(fillVal, m_Zero()) || matchPattern(fillVal, m_AnyZeroFloat()); diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp index 4c35dd38456c..3c3865e26b00 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp @@ -60,8 +60,9 @@ static void tileNonPackedDimsFor3DPackOps(RewriterBase &rewriter, // Skip the tiling if the size is already 1. RankedTensorType srcType = packOp.getSourceType(); for (auto [idx, val] : llvm::enumerate(tileSizes)) { - if (val && srcType.getDimSize(idx) == 1) + if (val && srcType.getDimSize(idx) == 1) { return; + } } auto outerDimsPerm = packOp.getOuterDimsPerm(); @@ -96,8 +97,9 @@ static void tileNonPackedDimsFor5DPUnpackOps(RewriterBase &rewriter, // Skip the tiling if the size is already 1. 
RankedTensorType destType = unpackOp.getDestType(); for (auto [idx, val] : llvm::enumerate(tileSizes)) { - if (val && destType.getDimSize(idx) == 1) + if (val && destType.getDimSize(idx) == 1) { return; + } } auto tilingInterfaceOp = cast(unpackOp.getOperation()); @@ -157,8 +159,9 @@ dropBatchTileSize(IREE::CPU::LoweringConfigAttr config) { SmallVector newItems; for (auto [level, tileSizes, scalableTileFlags] : tilingInfo) { tileSizes.erase(tileSizes.begin()); - if (!scalableTileFlags.empty()) + if (!scalableTileFlags.empty()) { scalableTileFlags.erase(scalableTileFlags.begin()); + } newItems.emplace_back( IREE::CPU::getTilingLevelName(level), IREE::CPU::LoweringConfigAttr::getTilingLevelAttr( @@ -262,16 +265,18 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern { llvm::SmallDenseSet s; s.insert(packOp.getInnerDimsPos().begin(), packOp.getInnerDimsPos().end()); for (auto dim : llvm::seq(0, packOp.getSourceRank())) { - if (s.contains(dim)) + if (s.contains(dim)) { continue; + } srcPos = dim; break; } int destPos = srcPos; for (auto [idx, val] : llvm::enumerate(packOp.getOuterDimsPerm())) { - if (val == srcPos) + if (val == srcPos) { destPos = idx; + } } if (packOp.getSourceType().getDimSize(srcPos) != 1) { @@ -284,15 +289,17 @@ struct Convert3DPackto2DPackPattern : public OpRewritePattern { SmallVector newInnerDimsPos(packOp.getInnerDimsPos()); for (auto &val : newInnerDimsPos) { assert(val != srcPos); - if (val > srcPos) + if (val > srcPos) { val--; + } } SmallVector newOuterDimsPerm(packOp.getOuterDimsPerm()); if (!newOuterDimsPerm.empty()) { newOuterDimsPerm.erase(newOuterDimsPerm.begin() + destPos); for (auto &val : newOuterDimsPerm) { - if (val > srcPos) + if (val > srcPos) { val--; + } } } @@ -341,8 +348,9 @@ struct Convert5DUnPackto4DUnPackPattern int64_t destPos = 0; for (auto [idx, val] : llvm::enumerate(seqOrOuterDimsPerm)) { - if (s.contains(val)) + if (s.contains(val)) { continue; + } srcPos = idx; destPos = val; break; @@ -361,16 +369,18 
@@ struct Convert5DUnPackto4DUnPackPattern SmallVector newInnerDimsPos(unpackOp.getInnerDimsPos()); for (auto &val : newInnerDimsPos) { assert(val != destPos); - if (val > destPos) + if (val > destPos) { val--; + } } SmallVector newOuterDimsPerm(unpackOp.getOuterDimsPerm()); if (!newOuterDimsPerm.empty()) { newOuterDimsPerm.erase(newOuterDimsPerm.begin() + srcPos); for (auto &val : newOuterDimsPerm) { - if (val > destPos) + if (val > destPos) { val--; + } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp index 3b5958aa2f1b..4324d169ec78 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp @@ -244,8 +244,9 @@ struct DistributeContract final : OpDistributionPattern { int64_t lhsKBatch = lhsLayout.getBatchTile()[lhsK]; int64_t rhsKBatch = rhsLayout.getBatchTile()[rhsK]; - if (lhsKBatch != rhsKBatch) + if (lhsKBatch != rhsKBatch) { return std::nullopt; + } return lhsKBatch; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp index 3db1921caa86..10836994c951 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCheckResourceUsage.cpp @@ -29,8 +29,9 @@ static int shapedTypeStaticSize( std::function getIndexBitwidth) { int allocSize = 1; for (auto dimSize : shapedType.getShape()) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { continue; + } allocSize *= dimSize; } if (auto elementType = dyn_cast(shapedType.getElementType())) { @@ -42,8 +43,9 @@ static int shapedTypeStaticSize( assert(getIndexBitwidth && "getIndexBitwidth should have been set earlier"); allocSize *= getIndexBitwidth(func); - } else + } else { allocSize 
*= IREE::Util::getTypeBitWidth(shapedType.getElementType()); + } } return allocSize; } @@ -53,19 +55,22 @@ static int shapedTypeStaticSize( static LogicalResult checkGPUAllocationSize( mlir::FunctionOpInterface funcOp, unsigned limit, std::function getIndexBitwidth) { - if (funcOp.getFunctionBody().empty()) + if (funcOp.getFunctionBody().empty()) { return success(); + } SmallVector allocOps; funcOp.walk([&](memref::AllocOp allocOp) { allocOps.push_back(allocOp); }); - if (allocOps.empty()) + if (allocOps.empty()) { return success(); + } int cumSize = 0; for (auto allocOp : allocOps) { auto allocType = cast(allocOp.getType()); - if (!hasSharedMemoryAddressSpace(allocType)) + if (!hasSharedMemoryAddressSpace(allocType)) { continue; + } if (!allocOp.getDynamicSizes().empty()) { return allocOp.emitOpError( diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp index a1de04220136..8d4f0fad59a6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp @@ -108,8 +108,9 @@ computeThreadNumThreadsImpl(OpBuilder &builder, Operation *op, // Find minimum elements per transfer across all DMA sizes. 
int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -424,8 +425,9 @@ struct ConvertGatherToCoalescedDMA int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -611,20 +613,23 @@ struct GPUConvertToCoalescedDMAPass final OpTy op) { MLIRContext *context = &getContext(); auto dmaConfig = getLoweringConfig(op); - if (!dmaConfig) + if (!dmaConfig) { return failure(); + } // Get the function containing this operation. auto funcOp = op->template getParentOfType(); - if (!funcOp) + if (!funcOp) { return failure(); + } // Get workgroup size and subgroup size from translation_info. std::optional> workgroupSize = getWorkgroupSize(funcOp); std::optional subgroupSize = getSubgroupSize(funcOp); - if (!workgroupSize || !subgroupSize) + if (!workgroupSize || !subgroupSize) { return failure(); + } // Calculate number of subgroups per dimension. // workgroupSize is [X, Y, Z], and we divide by subgroupSize to get warps. @@ -670,8 +675,9 @@ struct GPUConvertToCoalescedDMAPass final // We need innermostDim >= subgroupSize * minElementsPerLane. 
int64_t minElementsPerTransfer = std::numeric_limits::max(); for (int64_t dmaSize : dmaSizes) { - if (dmaSize % elementBits != 0) + if (dmaSize % elementBits != 0) { continue; + } int64_t elementsPerLane = dmaSize / elementBits; int64_t elementsPerTransfer = *subgroupSize * elementsPerLane; minElementsPerTransfer = @@ -688,8 +694,9 @@ struct GPUConvertToCoalescedDMAPass final auto [tileSizes, numTiledDims] = computeSubgroupTileSizes(rewriter, shape, numWarps); - if (numTiledDims == 0) + if (numTiledDims == 0) { return failure(); + } scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes); @@ -735,8 +742,9 @@ struct GPUConvertToCoalescedDMAPass final }) .Default([](Operation *) { return failure(); }); - if (failed(tilingResult)) + if (failed(tilingResult)) { continue; + } // Replace the original op with the tiled version. rewriter.replaceOp(op, tilingResult->replacements); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp index efefcebb8cd7..b9f4cea258ae 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp @@ -58,21 +58,24 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { // Find the anchor tensor.pad op, from which we get the conditions for // switching between the fast and slow path. auto padOps = llvm::to_vector(body->getOps()); - if (llvm::size(padOps) != 1) + if (llvm::size(padOps) != 1) { return; + } tensor::PadOp padOp = *padOps.begin(); // If all padding sizes are zero, we don't need to do anything. 
SmallVector lowPads = padOp.getMixedLowPad(); SmallVector highPads = padOp.getMixedHighPad(); - if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) + if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero)) { return; + } IRRewriter rewriter(funcOp.getContext()); rewriter.setInsertionPoint(body->getTerminator()); SmallVector allOps; - for (Operation &op : body->without_terminator()) + for (Operation &op : body->without_terminator()) { allOps.push_back(&op); + } BackwardSliceOptions options; options.filter = [](Operation *op) { return true; }; @@ -96,13 +99,15 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { } } Value ifCond = eqZeroCmpVals.front(); - for (Value cmp : llvm::ArrayRef(eqZeroCmpVals).drop_front()) + for (Value cmp : llvm::ArrayRef(eqZeroCmpVals).drop_front()) { ifCond = arith::AndIOp::create(rewriter, loc, ifCond, cmp); + } SmallVector cloneOps; for (Operation *op : allOps) { - if (!padSizeOps.contains(op)) + if (!padSizeOps.contains(op)) { cloneOps.push_back(op); + } } // Build the scf.if op itself. Clone all ops other than those used for @@ -122,15 +127,17 @@ static void applyFastSlowPathConversion(mlir::FunctionOpInterface funcOp) { }; auto elseBuilder = [&](OpBuilder &builder, Location loc) { IRMapping bvm; - for (Operation *op : cloneOps) + for (Operation *op : cloneOps) { builder.clone(*op, bvm); + } scf::YieldOp::create(builder, loc); }; scf::IfOp::create(rewriter, padOp.getLoc(), ifCond, thenBuilder, elseBuilder); // All of these ops have been cloned to both regions. Erease them now. 
- for (Operation *op : llvm::reverse(cloneOps)) + for (Operation *op : llvm::reverse(cloneOps)) { rewriter.eraseOp(op); + } } struct GPUCreateFastSlowPathPass final diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp index e893d8a019ae..c5645046b66e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp @@ -25,8 +25,9 @@ replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, Block *parent, Value replacement, ArrayRef availableMappingSizes) { parent->walk([&](gpu::ThreadIdOp idOp) { - if (availableMappingSizes[static_cast(idOp.getDimension())] == 1) + if (availableMappingSizes[static_cast(idOp.getDimension())] == 1) { rewriter.replaceAllUsesWith(idOp.getResult(), replacement); + } }); } @@ -51,14 +52,17 @@ DiagnosedSilenceableFailure static mapNestedForallToThreadsImpl( diag = mlir::transform::gpu::mapOneForallToThreadsImpl( rewriter, std::nullopt, forallOp, blockDims, warpSize, syncAfterDistribute); - if (diag.isDefiniteFailure()) + if (diag.isDefiniteFailure()) { return WalkResult::interrupt(); - if (diag.succeeded()) + } + if (diag.succeeded()) { return WalkResult::skip(); + } return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return diag; + } // Replace ids of dimensions known to be 1 by 0 to simplify the IR. // Here, the result of mapping determines the available mapping sizes. 
@@ -96,16 +100,19 @@ struct GPUDistributePass final if (!hasWorkgroupMapping) { result = mapNestedForallToThreadsImpl( rewriter, forallOp, workgroupSize.value(), subgroupSize, false); - if (result.isDefiniteFailure()) + if (result.isDefiniteFailure()) { return WalkResult::interrupt(); - if (result.succeeded()) + } + if (result.succeeded()) { return WalkResult::skip(); + } } return WalkResult::advance(); }); - if (walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) { return signalPassFailure(); + } } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp index 504e2366b47a..67b5b1c680fe 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp @@ -40,8 +40,9 @@ struct DistributeLoop final : OpRewritePattern { // Only distribute if we see the marker attribute. auto numDimAttr = forOp->getAttrOfType(getGPUDistributeAttrName()); - if (!numDimAttr) + if (!numDimAttr) { return failure(); + } // Get workgroup sizes if not using gpu.block_dim SmallVector workgroupSize; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp index 9a4fca90eb3e..31b55b56a77d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp @@ -70,8 +70,9 @@ static LogicalResult tileCopyToWorkgroupMem(mlir::FunctionOpInterface funcOp, unsigned rank = dstMemRefType.getRank(); // Return empty tile size for zero dim tensor. 
- if (rank == 0) + if (rank == 0) { return tileSizesVal; + } int copyTileSize = copyVectorNumBits / dstMemRefType.getElementTypeBitWidth(); for (unsigned i = 0; i < rank - 1; i++) { @@ -145,8 +146,9 @@ getTileToDistributableSize(linalg::GenericOp copyOp, unroll.push_back(numThreads * numElementPerThread); assert(threadsAvailable % numThreads == 0); threadsAvailable = threadsAvailable / numThreads; - if (threadsAvailable == 1) + if (threadsAvailable == 1) { break; + } } assert(threadsAvailable == 1); unroll.resize(shape.size(), 1); @@ -162,8 +164,9 @@ static LogicalResult tileToUnroll(mlir::FunctionOpInterface funcOp, [flatWorkgroupSize](OpBuilder &builder, Operation *operation) { SmallVector tileSizesVal; auto copyOp = dyn_cast(operation); - if (!copyOp) + if (!copyOp) { return tileSizesVal; + } std::optional> staticSize = getTileToDistributableSize(copyOp, flatWorkgroupSize); for (int64_t dim : *staticSize) { @@ -235,8 +238,9 @@ static LogicalResult tileAndDistribute(mlir::FunctionOpInterface funcOp, [](OpBuilder &builder, Operation *operation) { SmallVector tileSizesVal; auto copyOp = dyn_cast(operation); - if (!copyOp) + if (!copyOp) { return tileSizesVal; + } SmallVector staticSize = getNativeDstShape(copyOp); for (int64_t dim : staticSize) { tileSizesVal.push_back(arith::ConstantIndexOp::create( @@ -308,8 +312,9 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp, static void hoistAlloc(mlir::FunctionOpInterface funcOp) { SmallVector allocs; funcOp.walk([&](memref::AllocOp alloc) { - if (alloc.getOperands().empty()) + if (alloc.getOperands().empty()) { allocs.push_back(alloc); + } }); for (memref::AllocOp alloc : allocs) { alloc->moveBefore(&(*funcOp.getBlocks().begin()), @@ -325,15 +330,17 @@ static void removeRedundantBarriers(mlir::FunctionOpInterface funcOp) { Operation *prevOp = copyOp->getPrevNode(); SmallVector redundantBarriers; while (prevOp) { - if (isa(prevOp)) + if (isa(prevOp)) { redundantBarriers.push_back(prevOp); - else + } else { 
break; + } prevOp = prevOp->getPrevNode(); } if (prevOp && hasMarker(prevOp, getCopyToWorkgroupMemoryMarker())) { - for (Operation *op : redundantBarriers) + for (Operation *op : redundantBarriers) { op->erase(); + } } } }); @@ -345,8 +352,9 @@ static int64_t numIteration(scf::ForOp forOp) { auto ubCstOp = forOp.getUpperBound().getDefiningOp(); auto stepCstOp = forOp.getStep().getDefiningOp(); if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 || - ubCstOp.value() < 0 || stepCstOp.value() < 0) + ubCstOp.value() < 0 || stepCstOp.value() < 0) { return 0; + } int64_t tripCount = llvm::divideCeil(ubCstOp.value() - lbCstOp.value(), stepCstOp.value()); return tripCount; @@ -358,8 +366,9 @@ unrollSharedMemoryLoops(mlir::FunctionOpInterface funcOp, const llvm::SmallDenseSet &loopsToIgnore) { SmallVector forOpsToUnroll; funcOp.walk([&](scf::ForOp forOp) { - if (!loopsToIgnore.count(forOp)) + if (!loopsToIgnore.count(forOp)) { forOpsToUnroll.push_back(forOp); + } }); for (scf::ForOp forOp : llvm::reverse(forOpsToUnroll)) { (void)loopUnrollByFactor(forOp, numIteration(forOp)); @@ -378,11 +387,13 @@ LogicalResult gpuDistributeSharedMemoryCopy(mlir::FunctionOpInterface funcOp) { MLIRContext *context = funcOp.getContext(); SmallVector copiesToWorkgroupMem; funcOp.walk([&](linalg::GenericOp copyOp) { - if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) + if (hasMarker(copyOp, getCopyToWorkgroupMemoryMarker())) { copiesToWorkgroupMem.push_back(copyOp); + } }); - if (copiesToWorkgroupMem.empty()) + if (copiesToWorkgroupMem.empty()) { return success(); + } // Step 0. First clean up the IR. 
hoistAlloc(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 07db6b4b82aa..3714e71e28ff 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -31,13 +31,15 @@ struct DistributeConstants final : OpDistributionPattern { DistributionSignature &signature, PatternRewriter &rewriter) const override { auto constant = dyn_cast(constantOp.getResult()); - if (!constant) + if (!constant) { return failure(); + } // Only handle splat values for now. auto attr = dyn_cast(constantOp.getValue()); - if (!attr) + if (!attr) { return failure(); + } VectorLayoutInterface layout = signature[constant]; @@ -62,8 +64,9 @@ struct DistributePoison final : OpDistributionPattern { PatternRewriter &rewriter) const override { auto poisonVal = dyn_cast(poisonOp.getResult()); - if (!poisonVal) + if (!poisonVal) { return failure(); + } SmallVector distributedShape = signature[poisonVal].getDistributedShape(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp index 290713161da5..bd27e31e8282 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp @@ -139,8 +139,9 @@ static void processRegion(RewriterBase &rewriter, Region *region) { if (auto tilableOp = dyn_cast(op)) { // Do not distribute to threads of an op wants to use DMA. 
if (auto useDMAConfig = - getLoweringConfig(op)) + getLoweringConfig(op)) { continue; + } tileToThreads(rewriter, tilableOp); continue; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp index 8f91170b2f37..006bb813003b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -207,8 +207,9 @@ static FailureOr fitScheduleInSharedMemory( auto decrementIfPossible = [](MutableArrayRef sizes) -> LogicalResult { for (int64_t &size : sizes) { - if (size <= 1) + if (size <= 1) { continue; + } --size; return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp index fc2a04e2d30d..192dce67de4f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMultiBuffering.cpp @@ -32,14 +32,16 @@ struct GPUMultiBufferingPass final SmallVector allocs; // Collect all the alloc operations. funcOp.walk([&](memref::AllocOp allocOp) { - if (hasSharedMemoryAddressSpace(allocOp.getType())) + if (hasSharedMemoryAddressSpace(allocOp.getType())) { allocs.push_back(allocOp); + } }); assert(funcOp.getBlocks().size() == 1); for (memref::AllocOp allocOp : allocs) { - if (allocOp->getParentOp() != funcOp) + if (allocOp->getParentOp() != funcOp) { allocOp->moveBefore(&*funcOp.begin()->begin()); + } } // Then perform multibuffering transformations. @@ -50,8 +52,9 @@ struct GPUMultiBufferingPass final // Skip allocations not used in a loop. 
for (Operation *user : allocOp->getUsers()) { auto loop = user->getParentOfType(); - if (!loop) + if (!loop) { return WalkResult::advance(); + } } allocs.push_back(allocOp); return WalkResult::advance(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index 6ae77aea687f..a4923e0dc3f3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -36,8 +36,9 @@ using namespace mlir::iree_compiler::IREE::VectorExt; using VectorValue = TypedValue; static bool isBroadcast(AffineExpr expr) { - if (auto constExpr = dyn_cast(expr)) + if (auto constExpr = dyn_cast(expr)) { return constExpr.getValue() == 0; + } return false; } @@ -81,8 +82,9 @@ static SmallVector getTransferIndicesFromNestedLayout( // a constant less than `elementCount`, we can do this, unlocking // potential optimizations. 
bool disjoint = false; - if (std::optional offsetConst = getConstantIntValue(offset)) + if (std::optional offsetConst = getConstantIntValue(offset)) { disjoint = *offsetConst < elementCount; + } slicedIndices[pos] = affine::AffineLinearizeIndexOp::create(b, loc, ids, sizes, disjoint); } @@ -222,8 +224,9 @@ static LogicalResult populateWarpAndThreadIndices( int64_t rank = vectorLayout.getRank(); SmallVector threadIds = vectorLayout.computeThreadIds(threadId, subgroupSize, rewriter); - if (threadIds.empty() && rank != 0) + if (threadIds.empty() && rank != 0) { return failure(); + } warpIndices = SmallVector(threadIds.begin(), threadIds.begin() + rank); threadIndices = SmallVector(threadIds.begin() + rank, threadIds.begin() + 2 * rank); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp index f90411db7f8f..8057b238432c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp @@ -131,8 +131,9 @@ struct PackDestinationForOp final : OpRewritePattern { // Get the enclosing scf.for op. auto parentOp = yieldOp->getParentOp(); auto forOp = dyn_cast(parentOp); - if (!forOp) + if (!forOp) { return failure(); + } linalg::UnPackOp unpackOp; linalg::PackOp packOp; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp index de963c3ba1fd..a38a177f9a42 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp @@ -54,11 +54,13 @@ struct FlattenTransferReadOp : public OpRewritePattern { Value source = transferReadOp.getBase(); MemRefType sourceType = dyn_cast(source.getType()); // Contiguity check is valid on tensors only. 
- if (!sourceType) + if (!sourceType) { return failure(); + } // Already 2D or lower nothing to do. - if (vectorType.getRank() < 3) + if (vectorType.getRank() < 3) { return failure(); + } // The innermost dim is always considered non-unit as it wont be dropped // Therefore, we initialize `numberOfNonUnitDims` to 1 and not 0 int numberOfNonUnitDims = 1; @@ -86,12 +88,15 @@ struct FlattenTransferReadOp : public OpRewritePattern { } int rankOfCollapsedVector = 2; // TODO: generalize this pattern, relax the requirements here. - if (transferReadOp.hasOutOfBoundsDim()) + if (transferReadOp.hasOutOfBoundsDim()) { return failure(); - if (!transferReadOp.getPermutationMap().isMinorIdentity()) + } + if (!transferReadOp.getPermutationMap().isMinorIdentity()) { return failure(); - if (transferReadOp.getMask()) + } + if (transferReadOp.getMask()) { return failure(); + } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr( SmallVector(rankOfCollapsedVector, true)); auto newidentityMap = @@ -113,8 +118,9 @@ struct FlattenTransferReadOp : public OpRewritePattern { SmallVector subViewOffsets, subViewSizes, subViewStrides; subViewSizes.append(sourceType.getRank() - vectorType.getRank(), rewriter.getIndexAttr(1)); - for (int64_t dim : vectorType.getShape()) + for (int64_t dim : vectorType.getShape()) { subViewSizes.push_back(rewriter.getIndexAttr(dim)); + } for (int i = 0; i < sourceType.getRank(); i++) { subViewOffsets.push_back(transferReadOp.getIndices()[i]); subViewStrides.push_back(rewriter.getIndexAttr(1)); @@ -136,8 +142,9 @@ struct FlattenTransferReadOp : public OpRewritePattern { rewriter, loc, vectorTypeBroadcast, readCollapse); SmallVector transposePermutation; for (int i = 0; i < vectorType.getRank(); i++) { - if (i == vectorType.getRank() - 2) + if (i == vectorType.getRank() - 2) { continue; + } transposePermutation.push_back(i); } transposePermutation.insert(transposePermutation.begin() + @@ -186,8 +193,9 @@ struct CombineTransferReadOpBroadcast final /// Returns true 
if op is appropriate contract for promotion. static LogicalResult contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return failure(); + } // Limit promotion to matmul and batch matmul, there may be generic // ops with more batch dimensions we didn't distribute and therefore // cannot find a higher bound. @@ -206,8 +214,9 @@ struct DropSharedMemoryDeallocOp : public OpRewritePattern { LogicalResult matchAndRewrite(memref::DeallocOp op, PatternRewriter &rewriter) const override { if (!hasSharedMemoryAddressSpace( - cast(op.getMemref().getType()))) + cast(op.getMemref().getType()))) { return failure(); + } rewriter.eraseOp(op); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp index b17a554b3684..d13f6b6ece7f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp @@ -48,22 +48,26 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // speculatively. if (!isa(op)) { // Return/execute the op if it is a side effect free. - if (mlir::isMemoryEffectFree(op)) + if (mlir::isMemoryEffectFree(op)) { return op; + } // Return/execute the op if it is barrier, commit group, or ldmatrix op. if (isa(op)) + nvgpu::DeviceAsyncWaitOp>(op)) { return op; + } // Return/execute the op if it is a shared memory load. 
if (auto loadOp = dyn_cast(op)) { auto loadBaseType = cast(loadOp.getBase().getType()); - if (hasSharedMemoryAddressSpace(loadBaseType)) + if (hasSharedMemoryAddressSpace(loadBaseType)) { return op; + } } if (auto loadOp = dyn_cast(op)) { auto loadBaseType = loadOp.getMemRefType(); - if (hasSharedMemoryAddressSpace(loadBaseType)) + if (hasSharedMemoryAddressSpace(loadBaseType)) { return op; + } } // If we are here that means the operation does not have predication support // and cannot be speculatively executed. Thus, unpeeled epilogue is not @@ -107,12 +111,14 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, /// set. static void addDepOps(llvm::SmallDenseSet &dep, Operation *op, Block *block) { - if (!dep.insert(op).second) + if (!dep.insert(op).second) { return; + } for (Value operand : op->getOperands()) { Operation *defOp = operand.getDefiningOp(); - if (defOp && defOp->getBlock() == block) + if (defOp && defOp->getBlock() == block) { addDepOps(dep, defOp, block); + } } } @@ -123,8 +129,9 @@ static void getPipelineStages(scf::ForOp forOp, std::vector> &ops, unsigned depth) { - if (!forOp->hasAttr(kPipeliningLoopMarker)) + if (!forOp->hasAttr(kPipeliningLoopMarker)) { return; + } // Track dependencies of stage 0 ops. llvm::SmallDenseSet loadDep; @@ -138,12 +145,14 @@ getPipelineStages(scf::ForOp forOp, // stage `maxDepth`. In order to have a correct scheduling even with back // edges we order stages in decreasing order. for (Operation &op : forOp.getBody()->getOperations()) { - if (!loadDep.count(&op) && !isa(op)) + if (!loadDep.count(&op) && !isa(op)) { ops.push_back(std::make_pair(&op, depth)); + } } for (Operation &op : forOp.getBody()->getOperations()) { - if (loadDep.count(&op)) + if (loadDep.count(&op)) { ops.push_back(std::make_pair(&op, 0)); + } } } @@ -156,8 +165,9 @@ static void setAsyncAnnotations(Operation *op, // copies in flight. 
bool copyBeforeLoad = schedule == PipeliningSchedulingStrategy::nvidiaTensorCore; - if (waitOp.getNumGroups()) + if (waitOp.getNumGroups()) { return; + } int numGroupInFlight = 0; if (part == scf::PipeliningOption::PipelinerPart::Kernel || part == scf::PipeliningOption::PipelinerPart::Prologue) { @@ -178,8 +188,9 @@ static void setAsyncAnnotations(Operation *op, schedule == PipeliningSchedulingStrategy::loadStoreStage0 ? 0 : 1; if (pipelineStoreStage != 0 || part != mlir::scf::PipeliningOption::PipelinerPart::Prologue || - iteration >= depth - 1) + iteration >= depth - 1) { return; + } OpBuilder b(op); barrierOp->setAttr(kPipeliningExtraBarrier, b.getUnitAttr()); } @@ -194,12 +205,14 @@ static bool setPipeliningMarkers(scf::ForOp forOp, bool pipelineStoreStage) { SmallVector barriers; for (Operation &op : forOp.getBody()->getOperations()) { // Pipeline the most inner for op that should be a flat region. - if (op.getNumRegions() > 0) + if (op.getNumRegions() > 0) { return false; + } if (isa(op)) { barriers.push_back(&op); - if (pipelineStoreStage == 0) + if (pipelineStoreStage == 0) { op.setAttr(kPipeliningFirstStage, builder.getUnitAttr()); + } } if (isa(op)) { copyToWorkgroupMemory = true; @@ -212,21 +225,26 @@ static bool setPipeliningMarkers(scf::ForOp forOp, bool pipelineStoreStage) { continue; } auto ld = dyn_cast(op); - if (!ld) + if (!ld) { continue; + } auto ldSrcType = cast(ld.getBase().getType()); - if (!hasGlobalMemoryAddressSpace(ldSrcType) || !ld->hasOneUse()) + if (!hasGlobalMemoryAddressSpace(ldSrcType) || !ld->hasOneUse()) { continue; + } auto st = dyn_cast(ld->use_begin()->getOwner()); - if (!st) + if (!st) { continue; + } auto stSrcType = cast(st.getBase().getType()); - if (!hasSharedMemoryAddressSpace(stSrcType)) + if (!hasSharedMemoryAddressSpace(stSrcType)) { continue; + } copyToWorkgroupMemory = true; ld->setAttr(kPipeliningFirstStage, builder.getUnitAttr()); - if (pipelineStoreStage == 0) + if (pipelineStoreStage == 0) { 
st->setAttr(kPipeliningFirstStage, builder.getUnitAttr()); + } } if (copyToWorkgroupMemory) { forOp->setAttr(kPipeliningLoopMarker, builder.getUnitAttr()); @@ -287,14 +305,16 @@ struct MainLoopInfo { // of some other op. void backwardSliceOfDependentOps(llvm::SetVector &dependentOps, Operation *op, Block *block) { - if (!seenDepOps.insert(op)) + if (!seenDepOps.insert(op)) { return; + } // Add the unseen op to the dependentOps and recurse on its operands. dependentOps.insert(op); for (Value operand : op->getOperands()) { Operation *defOp = operand.getDefiningOp(); - if (defOp && defOp->getBlock() == block) + if (defOp && defOp->getBlock() == block) { backwardSliceOfDependentOps(dependentOps, defOp, block); + } } } @@ -304,8 +324,9 @@ struct MainLoopInfo { void mmaOperandDefOperation(Operation *op, llvm::SetVector &defOperation, Block *block) { - if (!op) + if (!op) { return; + } // If the operations defining the mma.sync's operand is one of the // qualifying operations, add the operations to the current kgroup defining @@ -326,14 +347,16 @@ struct MainLoopInfo { void vistMmaSyncOp(Operation *op, int kgroup) { // if the operation in an `scf.yield`, we reached the end of MmaSyncOp chain // return. - if (seenMmaOps.count(op) || isa(op)) + if (seenMmaOps.count(op) || isa(op)) { return; + } seenMmaOps.insert(op); // If the kgroup is not in the vector, create a new WarpMmaOp. - if (warpOperations.size() < kgroup + 1) + if (warpOperations.size() < kgroup + 1) { warpOperations.push_back(WarpMmaOp()); + } mmaOperandDefOperation(op->getOperand(0).getDefiningOp(), warpOperations[kgroup].lhsOperations, @@ -426,8 +449,9 @@ struct MainLoopInfo { LDBG() << "-- missing warpOperations -> not schedulable"; isSchedulable = false; } - if (!isSchedulable) + if (!isSchedulable) { return; + } // Collect the dependent operations for `cp.async` in the mainloop order for // coarse-grained software pipeling. 
The deps are collected in stage order, @@ -552,8 +576,9 @@ static void getNvidiaAmpereTensorCorePipeline( // Issue mma.sync on previous loaded kgroup. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.warpOperations[kgroup].mmaOperations.count(&op)) + if (mainloop.warpOperations[kgroup].mmaOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 1)); + } } } @@ -565,8 +590,9 @@ static void getNvidiaAmpereTensorCorePipeline( // it at one place. // Schedule all cp.async and one cp.async.commit_group. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.copyGlobalToSharedOpDeps.count(&op)) + if (mainloop.copyGlobalToSharedOpDeps.count(&op)) { ops.push_back(std::make_pair(&op, 0 /*pipelineStage*/)); + } } ops.push_back( std::make_pair(mainloop.asyncCreateGroupOp[0], 0 /*pipelineStage*/)); @@ -585,14 +611,16 @@ static void getNvidiaAmpereTensorCorePipeline( // into one stage ahead. for (Operation &op : forOp.getBody()->getOperations()) { if (mainloop.warpOperations[0].lhsOperations.count(&op) || - mainloop.warpOperations[0].rhsOperations.count(&op)) + mainloop.warpOperations[0].rhsOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 2)); + } } // Issue mma.sync on for the last kgroup at the end of the mainloop. for (Operation &op : forOp.getBody()->getOperations()) { - if (mainloop.warpOperations[numKgroups - 1].mmaOperations.count(&op)) + if (mainloop.warpOperations[numKgroups - 1].mmaOperations.count(&op)) { ops.push_back(std::make_pair(&op, numStages - 1)); + } } // Prints the mainloop schedule generated for NVIDIA Ampere through native @@ -667,8 +695,9 @@ struct GPUPipeliningPass final // Remove extra barriers from the prologue assuming appropriate // multi-buffering. 
funcOp.walk([](gpu::BarrierOp barrierOp) { - if (barrierOp->hasAttr(kPipeliningExtraBarrier)) + if (barrierOp->hasAttr(kPipeliningExtraBarrier)) { barrierOp->erase(); + } }); } }; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp index d953725fd6a1..1625081df28b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp @@ -68,8 +68,9 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) { // TODO (nirvedhmeshram) : This is fairly special case. Instead we should // just promote results before doing padding which introduces the extract // slice. - if (!valToMakeShared.hasOneUse()) + if (!valToMakeShared.hasOneUse()) { return; + } valueToReplace = extractSliceOp.getResult(); for (auto user : extractSliceOp->getUsers()) { opsToReplaceUseIn.insert(user); @@ -120,8 +121,9 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) { void promoteOperand(OpBuilder &builder, Operation *op, unsigned index, IREE::GPU::PromotionAttr promotionAttr) { auto dpsOp = dyn_cast(op); - if (!dpsOp) + if (!dpsOp) { return; + } // We use the convention that if we are passing an index beyond the inputs // then we promote the result of the corresponding dps init. 
if (index >= dpsOp.getNumDpsInputs()) { diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp index 05ddd0cc95a4..d8a98dbecd33 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp @@ -40,16 +40,19 @@ static bool hasCollapseShapeUser(memref::AllocOp allocOp) { static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, unsigned paddingSizeBits) { auto allocOpShape = allocOp.getType().getShape(); - if (allocOpShape.empty()) + if (allocOpShape.empty()) { return; + } int64_t innerDim = allocOpShape.back(); - if (ShapedType::isDynamic(innerDim)) + if (ShapedType::isDynamic(innerDim)) { return; + } // Return if we have CollapseShape op as an user as padding in that case is // unsupported. - if (hasCollapseShapeUser(allocOp)) + if (hasCollapseShapeUser(allocOp)) { return; + } Type elType = allocOp.getType().getElementType(); unsigned bitwidth = @@ -125,8 +128,9 @@ static unsigned computeEffectiveExtraBytes(mlir::FunctionOpInterface funcOp, MemRefType allocType = cast(allocOp.getType()); ArrayRef shape = allocType.getShape(); - if (shape.empty()) + if (shape.empty()) { return; + } int outerProduct = 1; for (std::size_t i = 0; i < shape.size() - 1; ++i) { @@ -181,8 +185,9 @@ struct GPUReduceBankConflictsPass final return; } - if (failed(reduceSharedMemoryBankConflicts(funcOp, paddingBits))) + if (failed(reduceSharedMemoryBankConflicts(funcOp, paddingBits))) { signalPassFailure(); + } } }; @@ -198,8 +203,9 @@ LogicalResult reduceSharedMemoryBankConflicts(mlir::FunctionOpInterface funcOp, sharedMemAllocs.push_back(allocOp); } }); - for (memref::AllocOp alloc : sharedMemAllocs) + for (memref::AllocOp alloc : sharedMemAllocs) { padAlloc(funcOp->getContext(), alloc, paddingSize); + } // In the current form this always succeeds. 
return success(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp index fdf549f9d10c..9c3708e0a06b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp @@ -29,8 +29,9 @@ constexpr int copyVectorNumBits = 128; /// Filter to decide which contract ops need allocations. static bool contractOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return false; + } if (!linalg::isaContractionOpInterface(linalgOp)) { return false; @@ -39,8 +40,9 @@ static bool contractOpFilter(Operation *op) { // The workgroup specialization already makes static shapes available for the // main tile part and makes the partial tile computation small, so promoting // to shared memory for the partial tile actually hurts the performance. - if (linalgOp.hasDynamicShape()) + if (linalgOp.hasDynamicShape()) { return false; + } // Check if the shape is tile-distributable. The leading dimension must be a // multiple of the target vector size, which is 128b / the element bit width. @@ -76,8 +78,9 @@ static bool contractOpFilter(Operation *op) { /// Filter to decide which transpose ops need allocations. 
static bool transposeOpFilter(Operation *op) { auto linalgOp = dyn_cast(op); - if (!linalgOp) + if (!linalgOp) { return false; + } LinalgOpInfo opInfo(linalgOp, sharedMemTransposeFilter); return opInfo.isTranspose(); } @@ -101,18 +104,21 @@ struct SwapAllocTensorPattern final LogicalResult matchAndRewrite(bufferization::AllocTensorOp allocOp, PatternRewriter &rewriter) const override { - if (!allocOp.getCopy()) + if (!allocOp.getCopy()) { return failure(); + } auto linalgOp = allocOp.getCopy().getDefiningOp(); - if (!linalgOp) + if (!linalgOp) { return failure(); + } // Make sure we don't use the initial values for the linalg output we are // copying during the tensor allocation. unsigned resultNumber = cast(allocOp.getCopy()).getResultNumber(); OpOperand *initOperand = linalgOp.getDpsInitOperand(resultNumber); - if (linalgOp.payloadUsesValueFromOperand(initOperand)) + if (linalgOp.payloadUsesValueFromOperand(initOperand)) { return failure(); + } rewriter.setInsertionPoint(linalgOp); std::optional memorySpace = allocOp.getMemorySpace(); @@ -148,12 +154,14 @@ struct GPUTensorAllocPass final funcOp.walk([&](Operation *op) { switch (promoteSharedMemPattern) { case GPUPromoteSharedMemPattern::ContractionOpPattern: - if (contractOpFilter(op)) + if (contractOpFilter(op)) { opsToPromote.push_back(op); + } break; case GPUPromoteSharedMemPattern::TransposeOpPattern: - if (transposeOpFilter(op)) + if (transposeOpFilter(op)) { opsToPromote.push_back(op); + } break; } }); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp index 0f597ccba08b..328c8924297b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorTile.cpp @@ -45,8 +45,9 @@ class TileConsumerAndFuseInputProducer final LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { - if 
(failed(filter.checkAndNotify(rewriter, op))) + if (failed(filter.checkAndNotify(rewriter, op))) { return failure(); + } // Make sure we have a PartitionableLoopInterface op here and query the tile // sizes from the partitionable loops. @@ -63,8 +64,9 @@ class TileConsumerAndFuseInputProducer final } // Mask out non reduction dimensions. for (unsigned depth : partitionedLoops) { - if (depth < tileSizes.size()) + if (depth < tileSizes.size()) { tileSizes[depth] = 0; + } } // Make sure we have a tile size for each dimension. @@ -120,11 +122,13 @@ class TileConsumerAndFuseInputProducer final return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); } - if (!fuseInputProducer) + if (!fuseInputProducer) { return tilingResult; + } // If there are no generated loops generated, fusion is immaterial. - if (tilingResult->loops.empty()) + if (tilingResult->loops.empty()) { return tilingResult; + } // Collect immediate input operands that are fusable into the tiled loop. // We have tensor extract slice ops taking slices of the untiled op. 
@@ -135,15 +139,18 @@ class TileConsumerAndFuseInputProducer final assert(tilingResult->tiledOps.size() == 1); Operation *tiledOp = tilingResult->tiledOps.front(); auto dsOp = dyn_cast(tiledOp); - if (!dsOp) + if (!dsOp) { return tilingResult; + } for (OpOperand *operand : dsOp.getDpsInputOperands()) { auto sliceOp = operand->get().getDefiningOp(); - if (!sliceOp) + if (!sliceOp) { continue; + } auto tilingOp = sliceOp.getSource().getDefiningOp(); - if (!tilingOp) + if (!tilingOp) { continue; + } if (isa(sliceOp.getSource().getDefiningOp())) { continue; } @@ -248,13 +255,15 @@ static LogicalResult tileParallelDims(mlir::FunctionOpInterface funcOp, for (TilingInterface tilingOp : computeOps) { auto attr = tilingOp->getAttr(LinalgTransforms::kLinalgTransformMarker); - if (attr == marker) + if (attr == marker) { continue; + } size_t numLoops = 0; for (auto type : tilingOp.getLoopIteratorTypes()) { - if (type == utils::IteratorType::parallel) + if (type == utils::IteratorType::parallel) { numLoops++; + } } IRRewriter rewriter(tilingOp->getContext()); rewriter.setInsertionPoint(tilingOp); @@ -263,8 +272,9 @@ static LogicalResult tileParallelDims(mlir::FunctionOpInterface funcOp, auto partitionedLoops = interfaceOp.getPartitionableLoops(kNumMaxParallelDims); // If there are no dimensions to tile skip the transformation. 
- if (partitionedLoops.empty()) + if (partitionedLoops.empty()) { continue; + } SmallVector numThreads(numLoops, rewriter.getIndexAttr(0)); int64_t id = 0, threadDim = 0; SmallVector idDims; @@ -307,8 +317,9 @@ static LogicalResult tileAndUnrollConv(mlir::FunctionOpInterface funcOp) { IRRewriter rewriter(funcOp.getContext()); SmallVector tileSizes = getAsIndexOpFoldResult( funcOp.getContext(), getTileSizes(consumerOp, 1)); - if (tileSizes.empty()) + if (tileSizes.empty()) { return success(); + } FailureOr tileAndFuseResult = scf::tileConsumerAndFuseProducersUsingSCF( @@ -375,8 +386,9 @@ struct GPUTensorTilePass final // Tile to serial loops to the wg tile size to handle reductions and other // dimension that have not been distributed. if (failed(tileReductionToSerialLoops(funcOp, /*fuseInputProducer=*/false, - /*coalesceLoops=*/false))) + /*coalesceLoops=*/false))) { return signalPassFailure(); + } LLVM_DEBUG({ llvm::dbgs() << "// --- After tile reductions:\n"; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp index 2fdedf111a5f..8d2397649aaa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp @@ -64,8 +64,9 @@ collectComputeOps(mlir::FunctionOpInterface funcOp, computeOps = getComputeOps(funcOp); for (Operation *op : computeOps) { if (auto config = - getLoweringConfig(op)) + getLoweringConfig(op)) { configs.push_back(config); + } } if (computeOps.size() > 1) { // Only keep the last compute ops. 
@@ -80,8 +81,9 @@ collectComputeOps(mlir::FunctionOpInterface funcOp, ifOps.front()->walk([&configs](Operation *op) { if (isa(op)) { if (auto config = - getLoweringConfig(op)) + getLoweringConfig(op)) { configs.push_back(config); + } } }); @@ -276,8 +278,9 @@ struct GPUTilePass final : impl::GPUTilePassBase { SmallVector computeOps; FailureOr loweringConfig = collectComputeOps(funcOp, computeOps); - if (failed(loweringConfig)) + if (failed(loweringConfig)) { return signalPassFailure(); + } assert(computeOps.size() <= 2); // Now tile the last computation op to invocations and fuse all operand @@ -286,8 +289,9 @@ struct GPUTilePass final : impl::GPUTilePassBase { for (Operation *computeOp : computeOps) { auto consumerOp = dyn_cast(computeOp); if (!consumerOp || - failed(tileAndDistributeToThreads(consumerOp, threadTileSizes))) + failed(tileAndDistributeToThreads(consumerOp, threadTileSizes))) { return signalPassFailure(); + } } LLVM_DEBUG({ diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp index 8c19c39c02ff..19be905eb779 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp @@ -57,8 +57,9 @@ void static removeUnitExtentDimsfromMaps(linalg::LinalgOp linalgOp, return; } SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - if (indexingMaps.empty()) + if (indexingMaps.empty()) { return; + } AffineMap inputMap = indexingMaps[0]; AffineMap filterMap = indexingMaps[1]; AffineMap outputMap = indexingMaps[2]; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp index 4231691da23a..a186ae6e17ba 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileReduction.cpp @@ -25,12 +25,14 @@ static LogicalResult tileReduction(linalg::LinalgOp op) { SmallVector dims; op.getReductionDims(dims); SmallVector tileSize = getTileSizes(op, 1); - if (tileSize.empty()) + if (tileSize.empty()) { return success(); + } // Make sure reduction dimensions are the innermost ones. for (int i = 0; i < dims.size(); ++i) { - if (dims[dims.size() - 1 - i] != op.getNumLoops() - 1 - i) + if (dims[dims.size() - 1 - i] != op.getNumLoops() - 1 - i) { return success(); + } } IRRewriter rewriter(op.getContext()); SmallVector sizes; @@ -40,8 +42,9 @@ static LogicalResult tileReduction(linalg::LinalgOp op) { rewriter.setInsertionPoint(op); FailureOr results = scf::tileReductionUsingScf( rewriter, cast(op.getOperation()), sizes); - if (failed(results)) + if (failed(results)) { return failure(); + } rewriter.replaceOp(op, results->replacements); return success(); } @@ -50,14 +53,16 @@ static LogicalResult tileFusedOps(linalg::LinalgOp op) { IRRewriter rewriter(op.getContext()); rewriter.setInsertionPoint(op); SmallVector tileSizes = getTileSizes(op, 1); - if (tileSizes.empty()) + if (tileSizes.empty()) { return success(); + } linalg::LinalgTilingOptions tileOption; tileOption.setTileSizes(tileSizes); FailureOr tiledOps = linalg::tileLinalgOp(rewriter, op, tileOption); - if (failed(tiledOps)) + if (failed(tiledOps)) { return failure(); + } rewriter.replaceOp(op, tiledOps->tensorResults); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index f7ab3452fbce..da283d2f3248 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -321,8 +321,9 @@ static void applyVectorDistribution(Operation *root, while (!worklist.empty()) { Operation *op = worklist.front(); 
worklist.pop_front(); - if (op == nullptr) + if (op == nullptr) { continue; + } LLVM_DEBUG(llvm::dbgs() << "Distributing: "); LLVM_DEBUG(op->print(llvm::dbgs(), OpPrintingFlags().skipRegions())); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp index 3f3f2a0c649b..0946e7374918 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp @@ -62,14 +62,18 @@ static bool isUniformLoad(Operation *op) { using namespace IREE::HAL; auto loadOp = dyn_cast(op); - if (!loadOp) + if (!loadOp) { return false; - if (!hasGlobalMemoryAddressSpace(loadOp.getMemRefType())) + } + if (!hasGlobalMemoryAddressSpace(loadOp.getMemRefType())) { return false; + } auto space = loadOp.getMemRefType().getMemorySpace(); auto descTypeAttr = dyn_cast_if_present(space); - if (descTypeAttr && descTypeAttr.getValue() == DescriptorType::UniformBuffer) + if (descTypeAttr && + descTypeAttr.getValue() == DescriptorType::UniformBuffer) { return true; + } auto subspan = loadOp.getMemRef().getDefiningOp(); if (auto fatBufferCast = @@ -77,16 +81,20 @@ static bool isUniformLoad(Operation *op) { subspan = fatBufferCast.getSource().getDefiningOp(); } - if (!subspan) + if (!subspan) { return false; + } descTypeAttr = dyn_cast_if_present( cast(subspan.getResult().getType()).getMemorySpace()); - if (descTypeAttr && descTypeAttr.getValue() == DescriptorType::UniformBuffer) + if (descTypeAttr && + descTypeAttr.getValue() == DescriptorType::UniformBuffer) { return true; + } if (auto flags = subspan.getDescriptorFlags()) { - if (bitEnumContainsAll(*flags, IREE::HAL::DescriptorFlags::ReadOnly)) + if (bitEnumContainsAll(*flags, IREE::HAL::DescriptorFlags::ReadOnly)) { return true; + } } return false; } @@ -97,18 +105,24 @@ static void moveScalarAndBindingUniformCode(gpu::WarpExecuteOnLane0Op warpOp) { /// Hoist 
ops without side effect as well as special binding ops. auto canBeHoisted = [](Operation *op, function_ref definedOutside) { - if (op->getNumRegions() != 0) + if (op->getNumRegions() != 0) { return false; - if (!llvm::all_of(op->getOperands(), definedOutside)) + } + if (!llvm::all_of(op->getOperands(), definedOutside)) { return false; - if (isMemoryEffectFree(op)) + } + if (isMemoryEffectFree(op)) { return true; + } if (isa(op)) + IREE::HAL::InterfaceConstantLoadOp, memref::AssumeAlignmentOp>( + op)) { return true; - if (isUniformLoad(op)) + } + if (isUniformLoad(op)) { return true; + } // Shared memory is already scoped to the workgroup and can safely be // hoisted out of the the warp op. if (auto allocOp = dyn_cast(op)) { @@ -144,8 +158,9 @@ static void moveScalarAndBindingUniformCode(gpu::WarpExecuteOnLane0Op warpOp) { } // Move all the ops marked as uniform outside of the region. - for (Operation *op : opsToMove) + for (Operation *op : opsToMove) { op->moveBefore(warpOp); + } } /// Pattern to convert single element vector.insert to broadcast, this is a @@ -155,8 +170,9 @@ struct InsertToBroadcast final : OpRewritePattern { LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { - if (insertOp.getDestVectorType().getNumElements() != 1) + if (insertOp.getDestVectorType().getNumElements() != 1) { return failure(); + } rewriter.replaceOpWithNewOp( insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore()); return success(); @@ -173,8 +189,9 @@ struct WarpOpBarrier final : OpRewritePattern { warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); Operation *lastNode = yield->getPrevNode(); auto barrierOp = dyn_cast_if_present(lastNode); - if (!barrierOp) + if (!barrierOp) { return failure(); + } rewriter.setInsertionPointAfter(warpOp); (void)gpu::BarrierOp::create(rewriter, barrierOp.getLoc()); @@ -274,8 +291,9 @@ struct VectorReductionToGPUPass final }; auto distributionFn = [](Value val) { auto vecType = 
dyn_cast(val.getType()); - if (!vecType) + if (!vecType) { return AffineMap::get(val.getContext()); + } // Create an identity dim map of rank |vecRank|. This greedily divides // threads along the outermost vector dimensions to the innermost ones. int64_t vecRank = vecType.getRank(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp index c567745392b3..3faa7a244368 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp @@ -139,20 +139,23 @@ struct ReorderWorkgroupsPass final .Case("", ReorderWorkgroupsStrategy::None) .Case("transpose", ReorderWorkgroupsStrategy::Transpose) .Default(failure()); - if (failed(selectedStrategy)) + if (failed(selectedStrategy)) { return failure(); + } reorderingStrategy = *selectedStrategy; return success(); } void runOnOperation() override { - if (reorderingStrategy == ReorderWorkgroupsStrategy::None) + if (reorderingStrategy == ReorderWorkgroupsStrategy::None) { return; + } FunctionOpInterface funcOp = getOperation(); - if (filterFn && failed(filterFn(funcOp))) + if (filterFn && failed(filterFn(funcOp))) { return; + } LLVM_DEBUG({ llvm::dbgs() << "--- Before reorder workgroups with workgroup counts ---"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp index 0a281311227f..c60a7a2b0a37 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp @@ -137,11 +137,13 @@ void LoweringConfigTilingLevelAttr::print(mlir::AsmPrinter &printer) const { [&](auto pair) { auto [tileSize, isScalable] = pair; // Wrap scalable sizes in square brackets. 
- if (isScalable) + if (isScalable) { printer << '['; + } printer << tileSize; - if (isScalable) + if (isScalable) { printer << ']'; + } }); } printer << ']'; @@ -163,8 +165,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, auto parseListOfSizes = [&](SmallVector *scalableFlags = nullptr, bool prefixChecked = false) -> FailureOr> { - if (!prefixChecked && parser.parseLSquare()) + if (!prefixChecked && parser.parseLSquare()) { return failure(); + } if (parser.parseOptionalRSquare().succeeded()) { // Empty list. return SmallVector(); @@ -177,15 +180,18 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, expectScalableSizes && parser.parseOptionalLSquare().succeeded(); int64_t size = 0; if (parser.parseInteger(size) || - (isScalable && parser.parseRSquare())) + (isScalable && parser.parseRSquare())) { return failure(); + } sizes.push_back(size); - if (scalableFlags) + if (scalableFlags) { scalableFlags->push_back(isScalable); + } return success(); }); - if (failed(listParse) || parser.parseRSquare()) + if (failed(listParse) || parser.parseRSquare()) { return failure(); + } return sizes; }; SmallVector scalableFlags; @@ -193,8 +199,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, // Case 1: Simple list of tile sizes, e.g.: // [0, [32], 16] auto tileSizes = parseListOfSizes(&scalableFlags, /*prefixChecked=*/true); - if (failed(tileSizes)) + if (failed(tileSizes)) { return {}; + } return parser.getChecked( loc, parser.getContext(), *tileSizes, ArrayRef{}, scalableFlags); @@ -202,15 +209,18 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, // Case 2: sizes and interchange, e.g.: // {sizes = [0, [32], 16], interchange = [0, 1, 2]} if (parser.parseLBrace() || parser.parseKeyword("sizes") || - parser.parseEqual()) + parser.parseEqual()) { return {}; + } auto tileSizes = parseListOfSizes(&scalableFlags); if (failed(tileSizes) || parser.parseComma() || - 
parser.parseKeyword("interchange") || parser.parseEqual()) + parser.parseKeyword("interchange") || parser.parseEqual()) { return {}; + } auto tileInterchange = parseListOfSizes(); - if (failed(tileInterchange) || parser.parseRBrace()) + if (failed(tileInterchange) || parser.parseRBrace()) { return {}; + } return parser.getChecked( loc, parser.getContext(), *tileSizes, *tileInterchange, scalableFlags); } @@ -218,8 +228,9 @@ Attribute LoweringConfigTilingLevelAttr::parse(mlir::AsmParser &parser, LogicalResult LoweringConfigTilingLevelAttr::verify( function_ref emitError, ArrayRef tileSizes, ArrayRef tileInterchange, ArrayRef scalableFlags) { - if (!scalableFlags.empty() && scalableFlags.size() != tileSizes.size()) + if (!scalableFlags.empty() && scalableFlags.size() != tileSizes.size()) { return emitError() << "scalable flags length does not match tile sizes"; + } return success(); } @@ -254,29 +265,33 @@ LoweringConfigAttr::get(MLIRContext *context, TileSizesListTypeRef tileSizes, TileSizesListType LoweringConfigAttr::getTileSizeVals() const { TileSizesListType tileSizes; - for (auto &level : getTilingLevels()) + for (auto &level : getTilingLevels()) { tileSizes.push_back(SmallVector(level.getSizes())); + } return tileSizes; } SmallVector LoweringConfigAttr::getTileSizeVals(unsigned level) const { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } return SmallVector(levels[level].getSizes()); } ScalableTileFlagsListType LoweringConfigAttr::getScalableTileFlagVals() { ScalableTileFlagsListType scalableFlags; - for (auto &level : getTilingLevels()) + for (auto &level : getTilingLevels()) { scalableFlags.push_back(SmallVector(level.getScalableFlags())); + } return scalableFlags; } SmallVector LoweringConfigAttr::getScalableTileFlagVals(unsigned level) { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } SmallVector 
scalableFlags(levels[level].getScalableFlags()); // Extend the scalable flags with `false` to match the length of the sizes. scalableFlags.resize(levels[level].getSizes().size()); @@ -286,8 +301,9 @@ SmallVector LoweringConfigAttr::getScalableTileFlagVals(unsigned level) { SmallVector LoweringConfigAttr::getTileInterchangeVals(unsigned level) const { auto levels = getTilingLevels(); - if (level >= levels.size()) + if (level >= levels.size()) { return {}; + } return SmallVector(levels[level].getInterchange()); } @@ -338,8 +354,9 @@ bool LoweringConfigAttr::hasWorkgroupTilingLevel() const { LogicalResult LoweringConfigAttr::verify(function_ref emitError, LoweringConfigTilingLevelsAttr levels) { - if (!levels) + if (!levels) { return emitError() << "missing lowering config levels"; + } return success(); } @@ -516,23 +533,27 @@ static OpFoldResult getMinimumConstantOffsetValue(OpBuilder &b, Location loc, OpFoldResult offset, int64_t rotationInvariant) { auto value = dyn_cast_if_present(offset); - if (!value) + if (!value) { return offset; + } auto add = value.getDefiningOp(); - if (!add) + if (!add) { return offset; + } llvm::APInt constant; - if (!matchPattern(add.getRhs(), m_ConstantInt(&constant))) + if (!matchPattern(add.getRhs(), m_ConstantInt(&constant))) { return offset; + } int64_t constantOffset = constant.getSExtValue(); int64_t baseMod = constantOffset % rotationInvariant; // Skip constructing the new apply if it's not needed (c < rotationInvariant). 
- if (baseMod == constantOffset) + if (baseMod == constantOffset) { return offset; + } Value modOffset = arith::ConstantIndexOp::create(b, loc, baseMod); // If the original add is nsw/nuw, then the new add must also be given we're @@ -798,14 +819,16 @@ void eraseTranslationInfo(FunctionOpInterface funcOp) { SmallVector getTileSizes(Operation *op, unsigned level) { IREE::Codegen::LoweringConfigAttrInterface configAttr = getLoweringConfig(op); - if (!configAttr) + if (!configAttr) { return {}; + } return configAttr.getStaticTilingLevelSizes(level, op); } SmallVector getTileSizes(OpBuilder &b, Operation *op, unsigned level) { IREE::Codegen::LoweringConfigAttrInterface configAttr = getLoweringConfig(op); - if (!configAttr) + if (!configAttr) { return {}; + } return llvm::map_to_vector(configAttr.getTilingLevelSizes(b, level, op), [&](OpFoldResult s) -> Value { return getValueOrCreateConstantIndexOp( @@ -856,8 +879,9 @@ bool hasRootOpInfo(Operation *op) { IREE::Codegen::UKernelProviderInterface getUKernelProviderFromTarget(DictionaryAttr dict) { - if (!dict) + if (!dict) { return {}; + } return dict.getAs( kUKernelProviderName); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h index 1f2325226948..acc87a0cb1d0 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h @@ -103,8 +103,9 @@ template FailureOr getLoweringConfigCarryingOp(ArrayRef computeOps) { for (Operation *op : computeOps) { - if (getLoweringConfig(op)) + if (getLoweringConfig(op)) { return op; + } } return failure(); } @@ -117,8 +118,9 @@ getLoweringConfigCarryingOp(ArrayRef computeOps) { template FailureOr getFirstLoweringConfig(ArrayRef computeOps) { FailureOr op = getLoweringConfigCarryingOp(computeOps); - if (failed(op)) + if (failed(op)) { return failure(); + } return 
getLoweringConfig(*op); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp index 9c22481e8663..48b5b02f8dbb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp @@ -211,8 +211,9 @@ IREECodegenDialect::verifyOperationAttribute(Operation *op, } } - if (symbol != kTuningSpecEntrypointAttrName) + if (symbol != kTuningSpecEntrypointAttrName) { return success(); + } const std::string requiredByEntrypointMessage = " (required by '" + std::string(kTuningSpecEntrypointAttrName) + "')"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp index cf1e47557a39..98e6571b4cfc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp @@ -42,8 +42,9 @@ LogicalResult ExtractStridedMetadataOp::inferReturnTypes( ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { auto sourceType = dyn_cast(adaptor.getSource().getType()); - if (!sourceType) + if (!sourceType) { return failure(); + } unsigned sourceRank = sourceType.getRank(); IndexType indexType = IndexType::get(context); @@ -55,8 +56,9 @@ LogicalResult ExtractStridedMetadataOp::inferReturnTypes( // Offset. inferredReturnTypes.push_back(indexType); // Sizes and strides. - for (unsigned i = 0; i < sourceRank * 2; ++i) + for (unsigned i = 0; i < sourceRank * 2; ++i) { inferredReturnTypes.push_back(indexType); + } return success(); } @@ -282,8 +284,9 @@ LogicalResult InnerTiledOp::verify() { SmallVector indexingMaps = getIndexingMapsArray(); // Verify that an indexing map was specified for each operand. 
- if (indexingMaps.size() != expectedNumIns + expectedNumOuts) + if (indexingMaps.size() != expectedNumIns + expectedNumOuts) { return emitOpError("expected an indexing map for each operand"); + } // Verify that each index map has 'numIterators' inputs, no symbols, and // that the number of map outputs equals the rank of its associated @@ -292,9 +295,10 @@ LogicalResult InnerTiledOp::verify() { for (const auto &it : llvm::enumerate(indexingMaps)) { auto index = it.index(); auto map = it.value(); - if (map.getNumSymbols() != 0) + if (map.getNumSymbols() != 0) { return emitOpError("expected indexing map ") << index << " to have no symbols"; + } auto shapedType = opTypes[index]; unsigned rank = shapedType.getRank(); // Verify that the map has the right number of inputs, outputs, and indices. @@ -370,9 +374,11 @@ LogicalResult InnerTiledOp::verify() { } static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) - if (targetExpr == map.getResult(i)) + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + if (targetExpr == map.getResult(i)) { return i; + } + } return -1; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp index e9c4c751a0cd..d437d0771740 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp @@ -329,8 +329,9 @@ struct UKernelOpsBufferizationInterface SmallVector nonTensorResultValues; for (OpResult result : op->getResults()) { Type resultType = result.getType(); - if (isa(resultType)) + if (isa(resultType)) { continue; + } nonTensorResultTypes.push_back(resultType); nonTensorResultValues.push_back(result); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp index 
8a2ed597c344..cd55d18c0efb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp @@ -177,8 +177,9 @@ SmallVector deriveThreadTileSizes(Operation *op) { .Case( [&](IREE::LinalgExt::MapScatterOp scatterOp) -> SmallVector { ShapedType inputType = scatterOp.getInputType(); - if (!inputType.hasStaticShape()) + if (!inputType.hasStaticShape()) { return {}; + } ArrayRef loopBounds = inputType.getShape(); int64_t elemBits = inputType.getElementTypeBitWidth(); int64_t vectorSize = kPreferredCopyNumBits / elemBits; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 6d548c86c46a..45595fcc1fbb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -700,8 +700,9 @@ Attribute MMAAttr::getDistributionMappingKind() const { OpFoldResult MMAAttr::getDistributionWorkerCount(OpBuilder &, Location, Operation *) const { - if (!getDistributionMappingKind()) + if (!getDistributionMappingKind()) { return OpFoldResult(); + } return getAsIndexOpFoldResult(getContext(), getSubgroupSize()); } @@ -1832,8 +1833,9 @@ DataTiledScaledMMAAttr::verifyIndexingMaps(ArrayRef maps) const { std::optional TargetAttr::getCUDAComputeCapability() const { StringRef arch = getArch(); - if (!arch.starts_with("sm_")) + if (!arch.starts_with("sm_")) { return false; + } APInt version; if (arch.substr(3).getAsInteger(10, version)) { return false; @@ -1844,14 +1846,16 @@ std::optional TargetAttr::getCUDAComputeCapability() const { bool TargetAttr::supportsTF32InputMMAOps() const { // TODO: scan the list of MMA ops to decude after plumbing through support // for NVIDIA TensorCore MMA ops. 
- if (auto cc = getCUDAComputeCapability()) + if (auto cc = getCUDAComputeCapability()) { return cc >= 80; + } return false; } bool TargetAttr::supportsSyncMMAOps() const { - if (auto cc = getCUDAComputeCapability()) + if (auto cc = getCUDAComputeCapability()) { return cc >= 80; + } return false; } @@ -1984,8 +1988,9 @@ getLoopBounds(ArrayRef loopRanges, for (auto [loopRange, givenTileSize] : llvm::zip_equal(loopRanges, givenTileSizes)) { // No loop if the tile size is 0. - if (isZeroInteger(givenTileSize)) + if (isZeroInteger(givenTileSize)) { continue; + } lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); steps.push_back(givenTileSize); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 555635a3e3dc..cd573602e77e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -348,10 +348,12 @@ static std::optional getMmaScheduleFromProblemAndTarget( for (IREE::GPU::ScaledMMAAttr smma : target.getWgp().getScaledMma()) { // Intrinsics that do not specify a distribution kind cannot be // distributed. - if (!smma.getDistributionMappingKind()) + if (!smma.getDistributionMappingKind()) { continue; - if (smma.getSubgroupSize() != targetSubgroupSize) + } + if (smma.getSubgroupSize() != targetSubgroupSize) { continue; + } auto [m, n, k, kB] = smma.getScaledMNKShape(); SmallVector elementTypes; @@ -365,10 +367,12 @@ static std::optional getMmaScheduleFromProblemAndTarget( for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { // Intrinsics that do not specify a distribution kind cannot be // distributed. 
- if (!mma.getDistributionMappingKind()) + if (!mma.getDistributionMappingKind()) { continue; - if (mma.getSubgroupSize() != targetSubgroupSize) + } + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } auto [mSize, nSize, kSize] = mma.getMNKShape(); auto [aType, bType, cType] = mma.getABCElementTypes(); @@ -1554,8 +1558,9 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target, int64_t lossFactor = 32; for (; lossFactor >= 1; lossFactor >>= 1) { - if (distributeToThreads(numThreads, lossFactor) == 1) + if (distributeToThreads(numThreads, lossFactor) == 1) { break; + } } } @@ -1733,8 +1738,9 @@ setDirectConvolutionLoweringConfig(IREE::GPU::TargetAttr target, return failure(); } - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); const int64_t splitReductionTripCnt = getSplitReductionTripCount(entryPoint); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index d8f738e88c73..4093391acd77 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -110,8 +110,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, SmallVector mmaAttrs; mmaAttrs.reserve(wgp->mmaCount); - for (int i = 0; i < wgp->mmaCount; ++i) + for (int i = 0; i < wgp->mmaCount; ++i) { mmaAttrs.push_back(MMAAttr::get(context, wgp->mmaOps[i])); + } SmallVector scaledMmaAttrs; scaledMmaAttrs.reserve(wgp->scaledMmaCount); @@ -814,10 +815,12 @@ std::optional getARMGPUTargetDetails(StringRef target) { } StringRef normalizeARMGPUTarget(StringRef target) { - if (target == "valhall") + if (target == "valhall") { return "valhall1"; - if (target.starts_with("valhall")) + } + if (target.starts_with("valhall")) 
{ return target; + } return llvm::StringSwitch(target.lower()) .Cases({"mali-g715", "mali-g615"}, "valhall4") @@ -954,15 +957,19 @@ std::optional getNVIDIAGPUTargetDetails(StringRef target) { } StringRef normalizeNVIDIAGPUTarget(StringRef target) { - if (target.starts_with("sm_")) + if (target.starts_with("sm_")) { return target; + } - if (target.starts_with("rtx40")) + if (target.starts_with("rtx40")) { return "sm_89"; - if (target.starts_with("rtx30")) + } + if (target.starts_with("rtx30")) { return "sm_86"; - if (target.starts_with("rtx20")) + } + if (target.starts_with("rtx20")) { return "sm_75"; + } return llvm::StringSwitch(target.lower()) .Case("a100", "sm_80") @@ -1002,22 +1009,26 @@ const WgpDetails *getAdrenoWgpDetails() { } bool verifyQualcommGPUTarget(StringRef target) { - if (target == "adreno") + if (target == "adreno") { return true; + } StringRef t = target; - if (!t.consume_front("adreno-")) + if (!t.consume_front("adreno-")) { return false; + } // The can exist an optional L at the end. - if (t.ends_with("l")) + if (t.ends_with("l")) { t = t.drop_back(); + } // Check whether we have a product number unsigned number = 0; // StringRef::consumeInteger() returns true to signify errors. 
- if (t.size() != 3 || t.consumeInteger(10, number)) + if (t.size() != 3 || t.consumeInteger(10, number)) { return false; + } return true; } @@ -1036,8 +1047,9 @@ std::optional getQualcommGPUTargetDetails(StringRef target) { // Adreno-750: https://vulkan.gpuinfo.org/displayreport.php?id=27414 // Adreno-740: https://vulkan.gpuinfo.org/displayreport.php?id=19218 // Adreno-730: https://vulkan.gpuinfo.org/displayreport.php?id=19382 - if (verifyQualcommGPUTarget(target)) + if (verifyQualcommGPUTarget(target)) { return TargetDetails{adrenoWgp, nullptr}; + } return std::nullopt; } @@ -1103,9 +1115,11 @@ TargetAttr getMetalTargetDetails(MLIRContext *context) { TargetAttr getCUDATargetDetails(StringRef target, StringRef features, MLIRContext *context) { - if (std::optional details = getNVIDIAGPUTargetDetails(target)) + if (std::optional details = + getNVIDIAGPUTargetDetails(target)) { return createTargetAttr(*details, normalizeNVIDIAGPUTarget(target), features, context); + } return nullptr; } @@ -1147,8 +1161,9 @@ StringRef normalizeHIPTarget(StringRef target) { StringRef normalizeVulkanAMDGPUTarget(StringRef target) { // We cannot accept rdnaN as a target for LLVM AMDGPU backend; so the // following is only meant for Vulkan but not HIP. 
- if (target.starts_with("rdna")) + if (target.starts_with("rdna")) { return target; + } return normalizeAMDGPUTarget(target); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp index 78828768eedf..242accb8bdcd 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp @@ -128,13 +128,15 @@ static LogicalResult checkSingleCombiner(linalg::LinalgOp op) { SmallVector combinerOps; if (matchReduction(op.getRegionOutputArgs(), index, combinerOps) && combinerOps.size() == 1) { - if (foundSingleReductionOutput) + if (foundSingleReductionOutput) { return failure(); + } foundSingleReductionOutput = true; continue; } - if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity()) + if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity()) { return failure(); + } } if (!foundSingleReductionOutput) { return failure(); @@ -666,8 +668,9 @@ LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, } } - if (subgroupSize == 0) + if (subgroupSize == 0) { return failure(); + } FailureOr bitWidth = getBitWidth(op); if (failed(bitWidth)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp index d6da7774b767..8d4d88cfccf5 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BufferizationInterfaces.cpp @@ -38,8 +38,9 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, if (isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options, state); - if (failed(resultBuffer)) + if (failed(resultBuffer)) { return failure(); 
+ } result.push_back(*resultBuffer); } else { result.push_back(opOperand.get()); @@ -121,8 +122,9 @@ struct BarrierRegionOpBufferizationInterface memrefType = bufferization::getBufferType( barrierOp.getOperand(argNum), options, state, invocationStack); } - if (failed(memrefType)) + if (failed(memrefType)) { return failure(); + } return cast(*memrefType); } @@ -207,8 +209,9 @@ struct ValueBarrierOpBufferizationInterface auto srcMemrefType = bufferization::getBufferType( barrierOp.getInputs()[cast(value).getResultNumber()], options, state, invocationStack); - if (failed(srcMemrefType)) + if (failed(srcMemrefType)) { return failure(); + } return cast(*srcMemrefType); } @@ -280,8 +283,9 @@ struct YieldOpBufferizationInterface if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options, state); - if (failed(maybeBuffer)) + if (failed(maybeBuffer)) { return failure(); + } newResults.push_back(*maybeBuffer); } else { newResults.push_back(value); @@ -443,8 +447,9 @@ struct BufferResourceCastOpBufferizationInterface assert(value.getDefiningOp() == castOp && "invalid value"); auto srcMemrefType = bufferization::getBufferType( castOp.getInput(), options, state, invocationStack); - if (failed(srcMemrefType)) + if (failed(srcMemrefType)) { return failure(); + } auto baseMemrefType = cast(srcMemrefType.value()); if (!hasStorageBufferMemSpace(baseMemrefType)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp index 18f45ed12e30..42f96aca3611 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/DistributeInnerTiledToLanes.cpp @@ -63,8 +63,9 @@ LogicalResult fuseProducersGreedily(RewriterBase &rewriter, // Materialize the slice of the producer in place. 
std::optional fusedProducer = scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops); - if (!fusedProducer) + if (!fusedProducer) { continue; + } // We have no way to know whether a multi-use value can be yielded from the // parallel loop so never yield a replacement. @@ -73,8 +74,9 @@ LogicalResult fuseProducersGreedily(RewriterBase &rewriter, for (auto tiledOp : fusedProducer->tiledOps) { for (OpOperand &operand : tiledOp->getOpOperands()) { auto sliceOp = operand.get().getDefiningOp(); - if (!sliceOp) + if (!sliceOp) { continue; + } candidates.push_back(sliceOp); } } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index d67ced9ed05c..5f53373f4cb9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -1281,8 +1281,9 @@ convertScaledContractionToInnerTiledMma( lhsInnerPerm, rhsInnerPerm, sc1InnerPerm, sc2InnerPerm, accInnerPerm}; SmallVector identityPerm = {0, 1}; if (lhsInnerPerm == identityPerm && rhsInnerPerm == identityPerm && - accInnerPerm == identityPerm) + accInnerPerm == identityPerm) { perms = std::nullopt; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(linalgOp); @@ -1424,8 +1425,9 @@ FailureOr convertContractionToInnerTiledMma( SmallVector identityPerm = {0, 1}; if (lhsInnerPerm == identityPerm && rhsInnerPerm == identityPerm && - accInnerPerm == identityPerm) + accInnerPerm == identityPerm) { perms = std::nullopt; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(linalgOp); @@ -1875,12 +1877,15 @@ void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns) { //===---------------------------------------------------------------------===// static bool isLaneMappableForall(scf::ForallOp forallOp) { - if (forallOp.getNumResults() > 0) + 
if (forallOp.getNumResults() > 0) { return false; - if (forallOp.getRank() != 1) + } + if (forallOp.getRank() != 1) { return false; - if (!forallOp.getMapping().has_value()) + } + if (!forallOp.getMapping().has_value()) { return false; + } Attribute mapping = *forallOp.getMapping()->getValue().begin(); if (mapping != IREE::GPU::LaneIdAttr::get(forallOp.getContext(), 0)) { return false; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp index 16a9ff3b5e30..1a2e62edb373 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.cpp @@ -98,10 +98,11 @@ void ShapedRefType::print(AsmPrinter &printer) const { ArrayRef shape = getShape(); for (int64_t dim : shape) { - if (ShapedType::isDynamic(dim)) + if (ShapedType::isDynamic(dim)) { printer << '?'; - else + } else { printer << dim; + } printer << 'x'; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp index 74de12f6eaf6..8d0e5d04c5af 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertSRefToMemRef.cpp @@ -1115,8 +1115,10 @@ struct ConvertWhileOp final : OpConversionPattern { auto newOp = scf::WhileOp::create(rewriter, op.getLoc(), resultTypes, inits); for (auto i : {0u, 1u}) { - if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) + if (failed( + rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) { return failure(); + } auto &dstRegion = newOp.getRegion(i); rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); } @@ -1159,17 +1161,20 @@ void ConvertSRefToMemRefPass::runOnOperation() { // only implements context specific conversions. 
auto isLegallyTypedOp = [&](Operation *op) -> bool { for (Type type : op->getResultTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } for (Region ®ion : op->getRegions()) { for (Type type : region.getArgumentTypes()) { - if (isIllegalType(type)) + if (isIllegalType(type)) { return false; + } } } if (auto funcInterface = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp index 25d9d5c3ab61..cffd8cf605c4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp @@ -435,11 +435,13 @@ NestedLayoutAttr::computeThreadIds(Value threadId, int64_t subgroupSize, SmallVector subgroupDimToResult, threadDimToResult; if (failed(basisFromSizesStrides(getSubgroupTile(), getSubgroupStrides(), - subgroupBasis, subgroupDimToResult))) + subgroupBasis, subgroupDimToResult))) { return {}; + } if (failed(basisFromSizesStrides(getThreadTile(), getThreadStrides(), - threadBasis, threadDimToResult))) + threadBasis, threadDimToResult))) { return {}; + } // Add the subgroup_size to the end of the subgroup delinearization basis. 
subgroupBasis.push_back(subgroupSize); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp index f8f422d0c334..b9a8c076015a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp @@ -68,8 +68,9 @@ static ParseResult parseIndexVecs(OpAsmParser &parser, SmallVectorImpl &indexVecs, SmallVectorImpl &indexVecTypes, ArrayAttr &indexed) { - if (parser.parseLSquare()) + if (parser.parseLSquare()) { return failure(); + } SMLoc loc; SmallVector indexedArr; @@ -127,11 +128,13 @@ static void printIndexVecs(OpAsmPrinter &p, Operation *op, static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op, SmallVector elidedAttrs = {}) { elidedAttrs.push_back(TransferGatherOp::getOperandSegmentSizeAttr()); - if (op.getPermutationMap().isMinorIdentity()) + if (op.getPermutationMap().isMinorIdentity()) { elidedAttrs.push_back(op.getPermutationMapAttrName()); + } // Elide in_bounds attribute if all dims are out-of-bounds. 
- if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; })) + if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; })) { elidedAttrs.push_back(op.getInBoundsAttrName()); + } p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -140,8 +143,9 @@ void TransferGatherOp::print(OpAsmPrinter &p) { printIndexVecs(p, *this, getIndexVecs(), getIndexVecs().getTypes(), getIndexedAttr()); p << ", " << getPadding(); - if (getMask()) + if (getMask()) { p << ", " << getMask(); + } printTransferAttrs(p, *this, {"indexed"}); p << " : " << getShapedType() << ", " << getType(); } @@ -151,9 +155,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds) { - if (!isa(shapedType)) + if (!isa(shapedType)) { return op->emitOpError( "requires source to be a memref or ranked tensor type"); + } Type elementType = shapedType.getElementType(); DataLayout dataLayout = DataLayout::closest(op); @@ -272,30 +277,36 @@ LogicalResult TransferGatherOp::verify() { : VectorType(); auto sourceElementType = shapedType.getElementType(); - if (static_cast(getIndices().size()) != shapedType.getRank()) + if (static_cast(getIndices().size()) != shapedType.getRank()) { return emitOpError("requires ") << shapedType.getRank() << " indices"; + } if (failed(verifyTransferOp(cast(getOperation()), shapedType, vectorType, maskType, - inferredMaskType, permutationMap, getInBounds()))) + inferredMaskType, permutationMap, + getInBounds()))) { return failure(); + } if (auto sourceVectorElementType = dyn_cast(sourceElementType)) { // Source has vector element type. // Check that 'sourceVectorElementType' and 'paddingType' types match. 
- if (sourceVectorElementType != paddingType) + if (sourceVectorElementType != paddingType) { return emitOpError( "requires source element type and padding type to match."); + } } else { // Check that 'paddingType' is valid to store in a vector type. - if (!VectorType::isValidElementType(paddingType)) + if (!VectorType::isValidElementType(paddingType)) { return emitOpError("requires valid padding vector elemental type"); + } // Check that padding type and vector element types match. - if (paddingType != sourceElementType) + if (paddingType != sourceElementType) { return emitOpError( "requires formal padding and source of the same elemental type"); + } } if (failed(verifyPermutationMap(permutationMap, @@ -353,28 +364,34 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, OpAsmParser::UnresolvedOperand maskInfo; // Parsing with support for paddingValue. if (parser.parseOperand(sourceInfo) || - parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square)) { return failure(); + } SmallVector indexVecTypes; ArrayAttr indexed; - if (parseIndexVecs(parser, indexVecInfo, indexVecTypes, indexed)) + if (parseIndexVecs(parser, indexVecInfo, indexVecTypes, indexed)) { return failure(); + } result.addAttribute("indexed", indexed); - if (parser.parseComma() || parser.parseOperand(paddingInfo)) + if (parser.parseComma() || parser.parseOperand(paddingInfo)) { return failure(); + } ParseResult hasMask = parser.parseOptionalComma(); if (hasMask.succeeded()) { - if (parser.parseOperand(maskInfo)) + if (parser.parseOperand(maskInfo)) { return failure(); + } } // Parse attributes and types. if (parser.parseOptionalAttrDict(result.attributes) || - parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) + parser.getCurrentLocation(&typesLoc) || + parser.parseColonTypeList(types)) { return failure(); + } // Check if number of types given are correct. 
int64_t nRequiredTypes = 2; @@ -387,10 +404,12 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, // sourceTy, resultTy auto shapedType = dyn_cast(types[0]); VectorType vectorType = dyn_cast(types[1]); - if (!shapedType || !isa(shapedType)) + if (!shapedType || !isa(shapedType)) { return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - if (!vectorType) + } + if (!vectorType) { return parser.emitError(typesLoc, "requires vector type"); + } auto permMapAttrName = TransferGatherOp::getPermutationMapAttrName(result.name); Attribute permMapAttr = result.attributes.get(permMapAttrName); @@ -414,12 +433,14 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, parser.resolveOperands(indexVecInfo, indexVecTypes, typesLoc, result.operands) || parser.resolveOperand(paddingInfo, shapedType.getElementType(), - result.operands)) + result.operands)) { return failure(); + } if (hasMask.succeeded()) { - if (dyn_cast(shapedType.getElementType())) + if (dyn_cast(shapedType.getElementType())) { return parser.emitError( maskInfo.location, "does not support masks with vector element type"); + } if (vectorType.getRank() != permMap.getNumResults()) { return parser.emitError(typesLoc, "expected the same rank for the vector and the " @@ -428,8 +449,9 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, // Instead of adding the mask type as an op type, compute it based on the // vector type and the permutation map (to keep the type signature small). 
auto maskType = vector::inferTransferOpMaskType(vectorType, permMap); - if (parser.resolveOperand(maskInfo, maskType, result.operands)) + if (parser.resolveOperand(maskInfo, maskType, result.operands)) { return failure(); + } } result.addAttribute(TransferGatherOp::getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp index d54ba2e05d06..c0df96ac24e6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/BufferizationInterfaces.cpp @@ -55,8 +55,9 @@ struct TransferGatherOpInterface "only tensor types expected"); FailureOr buffer = getBuffer(rewriter, gatherOp.getBase(), options, state); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } replaceOpWithNewBufferizedOp( rewriter, gatherOp, gatherOp.getVectorType(), *buffer, gatherOp.getIndices(), gatherOp.getIndexVecs(), gatherOp.getIndexed(), diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp index d8f6caf063a5..766aa6aac6da 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp @@ -116,8 +116,9 @@ struct ConvertHALEntryPointFuncOp LogicalResult matchAndRewrite(func::FuncOp stdFuncOp, func::FuncOpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - if (!stdFuncOp.isPublic()) + if (!stdFuncOp.isPublic()) { return failure(); + } FunctionType fnType = stdFuncOp.getFunctionType(); if (fnType.getNumInputs() != 0 || fnType.getNumResults() != 0) { stdFuncOp->emitWarning() @@ -773,8 +774,9 @@ struct RewriteCallOpABI : public OpRewritePattern { PatternRewriter &rewriter) const override { auto symbol = 
dyn_cast(callOp.getCallableForCallee()); auto flatSymbol = dyn_cast_if_present(symbol); - if (!flatSymbol) + if (!flatSymbol) { return failure(); + } // Ensure the target function is extern. // To support conversion inserting calls in local patterns that can't add @@ -821,8 +823,9 @@ struct RewriteExternCallOpToDynamicImportCallOp // Ignore indirect calls (they're probably already converted imports). auto symbol = dyn_cast(callOp.getCallableForCallee()); auto flatSymbol = dyn_cast_if_present(symbol); - if (!flatSymbol) + if (!flatSymbol) { return failure(); + } // Ensure the target function is extern. // To support conversion inserting calls in local patterns that can't add @@ -1139,8 +1142,9 @@ void ConvertToLLVMPass::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.insert(abi, typeConverter); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { return signalPassFailure(); + } } // Post conversion patterns. 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp index d82b1b6ec9ea..bbd2f35d5af9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp @@ -339,8 +339,9 @@ HALDispatchABI::getProcessorType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified(context, "iree_hal_processor_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint64Type = IntegerType::get(context, 64); SmallVector fieldTypes; @@ -365,8 +366,9 @@ HALDispatchABI::getEnvironmentType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_environment_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto opaquePtrType = LLVM::LLVMPointerType::get(context); SmallVector fieldTypes; @@ -399,8 +401,9 @@ HALDispatchABI::getDispatchStateType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_dispatch_state_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint8Type = IntegerType::get(context, 8); auto uint16Type = IntegerType::get(context, 16); @@ -453,8 +456,9 @@ HALDispatchABI::getWorkgroupStateType(MLIRContext *context, llvm::sys::ScopedLock lock(sMutex); auto structType = LLVM::LLVMStructType::getIdentified( context, "iree_hal_executable_workgroup_state_v0_t"); - if (structType.isInitialized()) + if (structType.isInitialized()) { return structType; + } auto uint16Type = IntegerType::get(context, 16); auto uint32Type = IntegerType::get(context, 32); @@ -583,8 +587,9 @@ static StringRef getDimName(int32_t dim) { // the ops if MLIR or LLVM is likely to 
reject them. static bool isLocationValidForDI(Location loc) { // Unknown locations are passed as null and DI doesn't like that. - if (isa(loc)) + if (isa(loc)) { return false; + } // MLIR currently can't handle name-only locations. We do this check to ensure // there's at least one real location MLIR can pass along. if (auto callLoc = dyn_cast(loc)) { @@ -604,11 +609,13 @@ static bool isLocationValidForDI(Location loc) { static Value buildArgDI(Operation *forOp, int argNum, Value value, Twine name, LLVM::DITypeAttr type, OpBuilder &builder) { - if (!clVerboseDebugInfo) + if (!clVerboseDebugInfo) { return value; + } auto loc = forOp->getLoc(); - if (!isLocationValidForDI(loc)) + if (!isLocationValidForDI(loc)) { return value; + } auto scopeAttr = getLocalScopeAttr(forOp); LLVM::DbgValueOp::create(builder, loc, value, LLVM::DILocalVariableAttr::get( @@ -621,11 +628,13 @@ static Value buildArgDI(Operation *forOp, int argNum, Value value, Twine name, static Value buildValueDI(Operation *forOp, Value value, Twine name, LLVM::DITypeAttr type, OpBuilder &builder) { - if (!clVerboseDebugInfo) + if (!clVerboseDebugInfo) { return value; + } auto loc = forOp->getLoc(); - if (!isLocationValidForDI(loc)) + if (!isLocationValidForDI(loc)) { return value; + } auto scopeAttr = getLocalScopeAttr(forOp); LLVM::DbgValueOp::create(builder, loc, value, LLVM::DILocalVariableAttr::get( @@ -1379,8 +1388,9 @@ Value HALDispatchABI::getIndexValue(Location loc, int64_t value, Value HALDispatchABI::castValueToType(Location loc, Value value, Type resultType, OpBuilder &builder) { // NOTE: we should handle more cases here (and proper sign extension). 
- if (value.getType() == resultType) + if (value.getType() == resultType) { return value; + } return builder.createOrFold(loc, resultType, value); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 7dbb37bda123..ec7957be9d21 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -201,8 +201,9 @@ static void getRangeBounds(TilingInterface op, SmallVectorImpl &lb, SmallVector loopRange = op.getIterationDomain(builder); auto getStaticValue = [](OpFoldResult ofr) -> int64_t { std::optional intVal = getConstantIntValue(ofr); - if (!intVal) + if (!intVal) { return ShapedType::kDynamic; + } return intVal.value(); }; lb = llvm::map_to_vector(loopRange, @@ -332,8 +333,9 @@ static int64_t getVectorSize(mlir::FunctionOpInterface entryPointFn, static int64_t getVectorSize(mlir::FunctionOpInterface entryPointFn, ShapedType shapedType) { Type elementType = shapedType.getElementType(); - if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return 1; + } unsigned byteWidth = IREE::Util::getRoundedElementByteWidth(elementType); return getVectorSize(entryPointFn, byteWidth); } @@ -385,12 +387,14 @@ getMinTilingSizesForEachDim(mlir::FunctionOpInterface entryPointFn, for (auto [index, map] : llvm::enumerate(op.getIndexingMapsArray())) { // Check the fastest varying dimension of the operand. Set the vector size // of the corresponding loop to the vector size. - if (map.getNumResults() == 0) + if (map.getNumResults() == 0) { continue; + } auto fastestVaryingDimExpr = dyn_cast(map.getResults().back()); - if (!fastestVaryingDimExpr) + if (!fastestVaryingDimExpr) { continue; + } unsigned fastestVaryingDim = fastestVaryingDimExpr.getPosition(); // If the indexing map has result it has to be a shaped type. 
@@ -923,8 +927,9 @@ getDefaultDistributedLevelTileSizes(Operation *op, // Final fix up of the tile sizes to make sure that they divide the problem // size to make it vectorizable. for (auto i : llvm::seq(0, distributedTileSizes.size())) { - if (!distributedTileSizes[i]) + if (!distributedTileSizes[i]) { continue; + } distributedTileSizes[i] = getMaxDistributionTileSize( lbs[i], ubs[i], distributedTileSizes[i], adjustedMinTileSizes[i], config.allowIncompleteTile); @@ -950,12 +955,14 @@ static void splitParallelAndReductionTiles( llvm::enumerate(tilingOp.getLoopIteratorTypes())) { if (iteratorType == utils::IteratorType::parallel) { reductionSizes[index] = 0; - if (reductionScalableFlags) + if (reductionScalableFlags) { (*reductionScalableFlags)[index] = false; + } } else { parallelSizes[index] = 0; - if (parallelScalableFlags) + if (parallelScalableFlags) { (*parallelScalableFlags)[index] = false; + } } } } @@ -965,8 +972,9 @@ static void setAlwaysVectorizeSizes(linalg::LinalgOp op, SmallVector staticLoopRanges = op.getStaticLoopRanges(); for (auto [index, size, iterType] : llvm::enumerate(staticLoopRanges, op.getIteratorTypesArray())) { - if (ShapedType::isStatic(size)) + if (ShapedType::isStatic(size)) { continue; + } vecTileSizes[index] = 1; } LDBG() << "Set always-vectorize sizes: " << vecTileSizes; @@ -1372,8 +1380,9 @@ setMatmulPeelingRootConfig(mlir::FunctionOpInterface entryPointFn, // The LLVM backend struggles to legalize non-power-of-two scalable vectors, // hence the extra rounding up. 
for (auto [index, size] : llvm::enumerate(roundedVecTileSizes)) { - if (!size) + if (!size) { continue; + } roundedVecTileSizes[index] = roundUpToPow2(size, /*predicate=*/inputVecScalableTileFlags[index]); @@ -1501,8 +1510,9 @@ static FailureOr nonWideningLinalgElementType(linalg::LinalgOp op) { } assert(!inputAndOutputElementTypes.empty() && "expected linalg op to have input and output types"); - if (!llvm::all_equal(inputAndOutputElementTypes)) + if (!llvm::all_equal(inputAndOutputElementTypes)) { return failure(); + } return inputAndOutputElementTypes[0]; } @@ -1522,8 +1532,9 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics( mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp op, int64_t vectorSize, SmallVectorImpl &sizes, SmallVectorImpl &scalableSizeFlags) { - if (sizes.empty()) + if (sizes.empty()) { getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags); + } // Find the smallest type size in the matmul. SmallVector matmulTypes; @@ -1534,11 +1545,13 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics( int64_t minSize = std::numeric_limits::max(); for (Type mmType : matmulTypes) { - if (auto shType = dyn_cast(mmType)) + if (auto shType = dyn_cast(mmType)) { mmType = shType.getElementType(); + } - if (mmType.isSignlessIntOrFloat()) + if (mmType.isSignlessIntOrFloat()) { minSize = std::min(minSize, int64_t{mmType.getIntOrFloatBitWidth()}); + } } LDBG() << "Smallest type found: " << minSize << " bits"; @@ -1567,14 +1580,16 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, linalg::LinalgOp op, int64_t vectorSize, SmallVectorImpl &sizes, SmallVectorImpl &scalableSizeFlags) { - if (sizes.empty()) + if (sizes.empty()) { getDefaultMatmulVectorSizes(op, vectorSize, sizes, scalableSizeFlags); + } // TODO: support widening matmul. // Determines n dimension tile size with VLEN for // nonWideningLinalgElementType. 
FailureOr elementType = nonWideningLinalgElementType(op); - if (failed(elementType)) + if (failed(elementType)) { return; + } // nativeVectorSize is cacluated with VLEN and LMUL=2. int64_t nativeVectorSize = getNativeVectorSizeInBytes(entryPointFn); @@ -1591,8 +1606,9 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, } FailureOr cDims = linalg::inferContractionDims(op); - if (failed(cDims) || cDims->m.size() != 1) + if (failed(cDims) || cDims->m.size() != 1) { return; + } // Use 7 x lmul4 to fully utilize vector registers. sizes[0] = 7; // Calculate tile size for the main vector dimension (N). @@ -1620,12 +1636,14 @@ getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op, // Double-check the operation is one that is supported for lowering to ArmSME. Operation *rawOp = op.getOperation(); if (!(IREE::LinalgExt::isPureMatmul(rawOp) || - isa(rawOp))) + isa(rawOp))) { return; + } auto elementType = nonWideningLinalgElementType(op); - if (failed(elementType)) + if (failed(elementType)) { return; + } // TODO(macdue): Come up with some heuristics to pick the appropriate tiling // for SME, i.e. optimal layout based on static sizes. @@ -2023,9 +2041,10 @@ getMmt4dLoweringConfig(linalg::LinalgOp op, DictionaryAttr targetConfig) { bool scalableTilesFound = false; // If scalable vectorization is enabled, adjust the vector tile sizes and the // corresponding scalable flags. - if (targetConfig && isScalableVectorizationEnabled()) + if (targetConfig && isScalableVectorizationEnabled()) { scalableTilesFound = adjustVectorSizesForScalableVectorization( op, targetConfig, M0, N0, vecTileSizes, vecScalableTileFlags); + } // In the existence of scalable tiles, we do not yet support limiting vector // sizes as this assumes static tile sizes. // TODO: extend this mechanism to handle _scalable_ tile sizes as well. @@ -2133,8 +2152,9 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn, // but it does not know that it is working on packed domain. 
We need to take // inner tile sizes into account and adjust the distribution tile sizes. for (auto [pos, size] : llvm::zip_equal(dimPos, innerTiles)) { - if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) + if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) { continue; + } distTileSizes[pos] = distTileSizes[pos] / size; distTileSizes[pos] = std::max(distTileSizes[pos], int64_t{1}); } @@ -2191,8 +2211,9 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn, ArrayRef dimPos = op.getInnerDimsPos(); for (auto [pos, size, scalable] : llvm::zip_equal(dimPos, innerTiles, scalableFlags)) { - if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) + if (distTileSizes[pos] == 0 || ShapedType::isDynamic(size)) { continue; + } int64_t alignedTileSize = llvm::alignTo(distTileSizes[pos], size); distTileSizes[pos] = roundUpToPow2(alignedTileSize, scalable); } @@ -2520,8 +2541,9 @@ static void getTransposeX86VectorSizes( linalg::GenericOp genericOp, IREE::HAL::ExecutableTargetAttr targetAttr, ArrayRef minTileSizes, SmallVectorImpl &sizes) { if (!targetAttr || !hasAVX2Feature(targetAttr.getConfiguration()) || - !x86TransposeLoweringPrecondition(genericOp)) + !x86TransposeLoweringPrecondition(genericOp)) { return; + } if (llvm::count_if(minTileSizes, [](int64_t tileSize) { return tileSize > 1; }) != 2) { @@ -2561,12 +2583,14 @@ static void getTransposeX86VectorSizes( static void getTransposeAArch64VectorSizes( linalg::GenericOp genericOp, IREE::HAL::ExecutableTargetAttr targetAttr, SmallVectorImpl &sizes, SmallVectorImpl &scalableFlags) { - if (!targetAttr || !isLinalgGeneric2DTranspose(genericOp)) + if (!targetAttr || !isLinalgGeneric2DTranspose(genericOp)) { return; + } auto elementType = nonWideningLinalgElementType(genericOp); - if (failed(elementType)) + if (failed(elementType)) { return; + } if (hasSMEFeature(targetAttr.getConfiguration()) && isScalableVectorizationEnabled() && !clDisableArmSMETiling) { @@ -2599,12 +2623,14 @@ 
getTransposeVectorSizes(mlir::FunctionOpInterface entryPointFn, scalableFlags); } - if (tileSizes.empty()) + if (tileSizes.empty()) { return std::nullopt; + } // If scalable flags are empty, assume target doesn't care about scalability. - if (scalableFlags.empty()) + if (scalableFlags.empty()) { scalableFlags = SmallVector(tileSizes.size(), false); + } LDBG() << "Transpose vector sizes: " << tileSizes; LDBG() << "Transpose vector scalable flags: " << scalableFlags; @@ -2621,15 +2647,17 @@ setTransposeLikeOpRootConfig(mlir::FunctionOpInterface entryPointFn, assert(!getLoweringConfig(genericOp) && "expected lowering_config is not set"); - if (!linalgOpInfo.isTranspose()) + if (!linalgOpInfo.isTranspose()) { return failure(); + } LDBG() << "Setting transpose-like op root configuration"; std::optional vecDims = getTransposeVectorSizes( entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo); - if (!vecDims) + if (!vecDims) { return failure(); + } auto [vecSizes, vecScalableDims] = *vecDims; @@ -2667,10 +2695,12 @@ static LogicalResult setElementwiseGenericOpRootConfig( LDBG() << "Setting elementwise generic op root configuration"; unsigned numLoops = genericOp.getNumLoops(); - if (numLoops == 0) + if (numLoops == 0) { return failure(); - if (!linalg::isElementwise(genericOp)) + } + if (!linalg::isElementwise(genericOp)) { return failure(); + } DistributionHeuristicConfig distConfig; distConfig.allowIncompleteTile = true; @@ -2797,13 +2827,15 @@ enum class Conv2DDimOrder { static Conv2DDimOrder getConv2DDimOrder(Operation *op) { if (isa(op)) + linalg::PoolingNchwMaxOp>(op)) { return Conv2DDimOrder::Nchw; + } if (isa(op)) + linalg::DepthwiseConv2DNhwcHwcOp>(op)) { return Conv2DDimOrder::Nhwc; + } llvm::llvm_unreachable_internal("unsupported conv op"); } @@ -2890,42 +2922,54 @@ getNhwcConvVectorSizes(mlir::FunctionOpInterface entryPointFn, if (targetAttr) { DictionaryAttr targetConfig = targetAttr.getConfiguration(); if (isX86(targetConfig)) { - if (is2DConvOp(op)) + if 
(is2DConvOp(op)) { return {1, 1, 8, vectorSize, 1, 1, 8}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 8, vectorSize, 1, 3}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 8, vectorSize, 1, 8}; + } llvm_unreachable("unsupported conv"); } if (isRISCV(targetConfig)) { - if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, 8, vectorSize * 2, 1, 1, 8}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 8, vectorSize, 1, 3}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 8, vectorSize * 2, 1, 8}; + } llvm_unreachable("unsupported conv"); } if (isAArch64(targetConfig)) { - if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, 32, 64, 1, 1, 16}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, 4, 4, 1, 4}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, 32, 64, 1, 16}; + } llvm_unreachable("unsupported conv"); } } // Get default hard-coded tile sizes if we couldn't compute anything // better. 
- if (is2DConvOp(op)) + if (is2DConvOp(op)) { return {1, 1, vectorSize, vectorSize, 1, 1, vectorSize}; - if (is2DDepthConvOp(op)) + } + if (is2DDepthConvOp(op)) { return {1, 1, vectorSize, vectorSize, 1, vectorSize}; - if (is2DPoolingOp(op)) + } + if (is2DPoolingOp(op)) { return {1, 1, vectorSize, vectorSize, 1, vectorSize}; + } llvm_unreachable("unsupported conv"); } @@ -3713,8 +3757,9 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, linalgOp.getNumLoops(), false); for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) { auto unpackOp = opOperand->get().getDefiningOp(); - if (!unpackOp) + if (!unpackOp) { continue; + } foundUnPackOp = true; auto idxMap = linalgOp.getMatchingIndexingMap(opOperand); @@ -3732,19 +3777,22 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, ArrayRef dimPos = unpackOp.getInnerDimsPos(); for (auto [pos, size, scalable] : llvm::zip_equal(dimPos, innerTiles, scalableFlags)) { - if (ShapedType::isDynamic(size)) + if (ShapedType::isDynamic(size)) { continue; + } auto dimExpr = dyn_cast(idxMap.getResult(pos)); - if (!dimExpr) + if (!dimExpr) { return failure(); + } int mappedPos = dimExpr.getPosition(); alignedSizes[mappedPos] = std::lcm(alignedSizes[mappedPos], size); vecParallelScalableTileFlags[mappedPos] = scalable; } } - if (!foundUnPackOp) + if (!foundUnPackOp) { return success(); + } LDBG() << "The tile sizes for each dimension should be aligned to " << alignedSizes; @@ -3758,8 +3806,9 @@ adjustTileSizesForRootUnPackOp(mlir::FunctionOpInterface entryPointFn, for (IREE::CPU::LoweringConfigLevelInfo &info : tilingInfo) { SmallVector &tileSizes = info.sizes; for (auto idx : llvm::seq(0, tileSizes.size())) { - if (tileSizes[idx] == 0) + if (tileSizes[idx] == 0) { continue; + } int64_t alignedTileSize = llvm::alignTo(tileSizes[idx], alignedSizes[idx]); tileSizes[idx] = roundUpToPow2( @@ -3938,13 +3987,15 @@ setTranslationInfoAndRootConfig(mlir::FunctionOpInterface entryPointFn, ArrayRef 
computeOps) { // Make sure that lowering_config is not preset on any compute ops. for (auto computeOp : computeOps) { - if (getLoweringConfig(computeOp)) + if (getLoweringConfig(computeOp)) { return failure(); + } } FailureOr rootOp = getRootOperation(computeOps); - if (failed(rootOp)) + if (failed(rootOp)) { return failure(); + } Operation *rootOperation = rootOp.value(); // Handle the case with no known root operation. diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp index 48e404cea3ab..1961fe1e6125 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPU2DScalableTo1DScalable.cpp @@ -91,8 +91,9 @@ class LLVMCPU2DScalableTo1DScalablePass }; static bool opKnownToSupport2DScalableVectorizationWithArmSME(Operation *op) { - if (auto genericOp = dyn_cast(op)) + if (auto genericOp = dyn_cast(op)) { return isLinalgGeneric2DTranspose(genericOp); + } return isa(op); } @@ -206,16 +207,18 @@ dropScalabilityFromUnsupportedOperations(mlir::FunctionOpInterface funcOp, scf::SCFTilingOptions options; setSCFTileSizes(options, tilingOp, loopTileSizes, /*tileScalableFlags=*/{}); auto tilingResult = scf::tileUsingSCF(rewriter, tilingOp, options); - if (failed(tilingResult)) + if (failed(tilingResult)) { return failure(); + } // Update the lowering config of the new tiled operations. 
IREE::CPU::LoweringConfigAttr newLoweringConfig = getLoweringConfigWithNewVectorSizes(loweringConfigAttr, *vectorSizes, newScalableFlags); for (auto *newOp : tilingResult->tiledOps) { - if (isa(newOp)) + if (isa(newOp)) { setLoweringConfig(newOp, newLoweringConfig); + } } rewriter.replaceOp(tilingOp, tilingResult->replacements); @@ -225,8 +228,9 @@ dropScalabilityFromUnsupportedOperations(mlir::FunctionOpInterface funcOp, void LLVMCPU2DScalableTo1DScalablePass::runOnOperation() { if (failed(dropScalabilityFromUnsupportedOperations(getOperation(), - assumeArmSME))) + assumeArmSME))) { signalPassFailure(); + } } } // namespace diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp index e3eeb49d3ed0..815fd344b0e8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignConstantOrdinals.cpp @@ -23,8 +23,9 @@ struct LLVMCPUAssignConstantOrdinalsPass // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually MLIR or LLVM will inline them. 
@@ -33,8 +34,9 @@ struct LLVMCPUAssignConstantOrdinalsPass llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp index 4827f7c3ef08..b7e6726d1560 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUAssignImportOrdinals.cpp @@ -37,13 +37,15 @@ struct LLVMCPUAssignImportOrdinalsPass for (auto globalOp : llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttrOfType(importKeyAttr); - if (!keyAttr) + if (!keyAttr) { continue; + } uniqueKeys.insert(keyAttr); ordinalGlobals[keyAttr].push_back(globalOp); } - if (uniqueKeys.empty()) + if (uniqueKeys.empty()) { return; + } auto sortedKeys = uniqueKeys.takeVector(); llvm::stable_sort(sortedKeys, [](auto lhs, auto rhs) { return lhs.getValue() < rhs.getValue(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp index 95bd62f46f81..d35cc4c2df47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUCheckIRBeforeLLVMConversion.cpp @@ -40,8 +40,9 @@ struct LLVMCPUCheckIRBeforeLLVMConversionPass /// defined for HAL LLVMCPU target). static LogicalResult checkStackAllocationSize(mlir::FunctionOpInterface funcOp) { - if (funcOp.getFunctionBody().empty()) + if (funcOp.getFunctionBody().empty()) { return success(); + } // In rare cases where the attribute is not present in the module, a value of // 32KB will be taken. 
@@ -73,8 +74,9 @@ checkStackAllocationSize(mlir::FunctionOpInterface funcOp) { int allocaSize = 1; auto allocaType = cast(allocaOp.getType()); for (auto dimSize : allocaType.getShape()) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { continue; + } allocaSize *= dimSize; } for (auto operand : allocaOp.getDynamicSizes()) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp index 3e66355ad9a6..1d38122ce7ed 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp @@ -48,13 +48,15 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() { std::optional numLoops; funcOp.walk([&](vector::ContractionOp op) { - if (numLoops) + if (numLoops) { return signalPassFailure(); + } numLoops = op.getIndexingMapsArray()[0].getNumDims(); }); // No vector.contract op to optimize. - if (!numLoops) + if (!numLoops) { return; + } { // Fold consumer add ops into the contraction op itself. diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp index 2a688c74523e..1edcb02b504e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp @@ -30,8 +30,9 @@ namespace { // stages. 
void collectLoopsToPeel(Operation *op, llvm::SmallSetVector &loopsToPeel) { - if (!iree_compiler::getLoweringConfig(op)) + if (!iree_compiler::getLoweringConfig(op)) { return; + } int maxNumLoopsToPeel = TypeSwitch(op) .Case([](auto linalgOp) { @@ -44,8 +45,9 @@ void collectLoopsToPeel(Operation *op, for (int i = 0; i < maxNumLoopsToPeel; ++i) { op = op->getParentOfType(); auto loop = cast_or_null(op); - if (!loop || iree_compiler::isTiledAndDistributedLoop(loop)) + if (!loop || iree_compiler::isTiledAndDistributedLoop(loop)) { break; + } LDBG() << "Loop to peel\n " << *op; loopsToPeel.insert(loop); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp index 4f5555e841fb..e6d8e35d45c0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp @@ -223,8 +223,9 @@ static LogicalResult verifyLoweringConfiguration(FunctionOpInterface funcOp, return WalkResult::advance(); } auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return WalkResult::advance(); + } return verificationFn(op, loweringConfig); }); return failure(walkResult.wasInterrupted()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp index b02963a6dcd2..14bd061c545f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTile.cpp @@ -57,15 +57,17 @@ void LLVMCPUTilePass::runOnOperation() { SmallVector computeOps = getComputeOps(funcOp); for (auto computeOp : computeOps) { auto op = dyn_cast(computeOp); - if (!op || op.getLoopIteratorTypes().empty()) + if (!op || op.getLoopIteratorTypes().empty()) { continue; + } // For now do not tile `tensor.pad` operations. 
The `tensor.pad` // operations might be those introduced by the padding-based // codegeneration strategy. Those are not meant to be tiled again. // Need a better way for handling this, but this works for now. - if (isa(computeOp)) + if (isa(computeOp)) { continue; + } IREE::Codegen::LoweringConfigAttrInterface maybeLoweringConfig = getLoweringConfig(op); @@ -104,8 +106,9 @@ void LLVMCPUTilePass::runOnOperation() { std::move(tileScalableFlags)); FailureOr tiledResults = scf::tileUsingSCF(rewriter, op, options); - if (failed(tiledResults)) + if (failed(tiledResults)) { continue; + } rewriter.replaceOp(op, tiledResults->replacements); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp index 263c336d894a..83768362d7ec 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorTransposeLowering.cpp @@ -24,8 +24,9 @@ static bool has16x16Transpose(mlir::FunctionOpInterface funcOp) { bool res = false; funcOp.walk([&](vector::TransposeOp op) { auto srcGtOneDims = isTranspose2DSlice(op); - if (failed(srcGtOneDims)) + if (failed(srcGtOneDims)) { return WalkResult::advance(); + } VectorType srcType = op.getSourceVectorType(); int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp index 3c53e1ad239d..9ba70c0ce6a4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp @@ -104,36 +104,43 @@ bool hasI8mmFeature(DictionaryAttr targetConfig) { bool isLinalgGeneric2DTranspose(linalg::GenericOp genericOp) { // Check op has 2 dimensions. 
- if (genericOp.getNumLoops() != 2) + if (genericOp.getNumLoops() != 2) { return false; + } // Check op has single input and output. - if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) + if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) { return false; + } // Check all iterators are parallel. - if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) + if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { return false; + } // Check that the two indexing maps are a permutation of each other. SmallVector indexingMaps = genericOp.getIndexingMapsArray(); bool isTranspose = (indexingMaps[0].isPermutation() && indexingMaps[1].isIdentity()) || (indexingMaps[1].isPermutation() && indexingMaps[0].isIdentity()); - if (!isTranspose) + if (!isTranspose) { return false; + } // Make sure the region only contains a yield op. Block &body = genericOp.getRegion().front(); - if (!llvm::hasSingleElement(body)) + if (!llvm::hasSingleElement(body)) { return false; + } auto yieldOp = cast(body.getTerminator()); // The yield op should return the block argument corresponding to the input. auto yieldArg = dyn_cast(yieldOp.getValues()[0]); - if (!yieldArg || yieldArg.getArgNumber() != 0 || yieldArg.getOwner() != &body) + if (!yieldArg || yieldArg.getArgNumber() != 0 || + yieldArg.getOwner() != &body) { return false; + } return true; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index a5f09bba5a1b..974527d74fed 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -853,10 +853,11 @@ class MMTKernelGenerator { // the constraints string. Not confusing at all! 
inputs.append(lhs.begin(), lhs.end()); for (const auto &v : rhs) { - if (cast(v.getType()).getNumElements() == 1) + if (cast(v.getType()).getNumElements() == 1) { inputs.push_back(extract(rewriter, loc, v, 0)); - else + } else { inputs.push_back(v); + } } inputs.append(acc.begin(), acc.end()); // Create the inline asm op. @@ -1039,8 +1040,9 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); - if (!inLhs || !inRhs) + if (!inLhs || !inRhs) { return failure(); + } auto loc = contractionOp.getLoc(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index b0995cafd0a8..4fa209a2aa7c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -37,11 +37,13 @@ void ConvertToDynamicSharedMemory(ModuleOp moduleOp) { moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) { // Check that the global associated with this addressOfOp has shared memory // space. 
- if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3) + if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3) { addressOfOps.push_back(addressOfOp); + } }); - if (addressOfOps.size() == 0) + if (addressOfOps.size() == 0) { return; + } OpBuilder builder(moduleOp); builder.setInsertionPoint(&moduleOp.front()); auto type = @@ -118,8 +120,9 @@ struct ScalarizeMathOp : public OpRewritePattern { LogicalResult matchAndRewrite(MathOpTy mathOp, PatternRewriter &rewriter) const override { auto vecType = dyn_cast(mathOp.getType()); - if (!vecType) + if (!vecType) { return failure(); + } Location loc = mathOp.getLoc(); Value newVector = arith::ConstantOp::create(rewriter, loc, vecType, rewriter.getZeroAttr(vecType)); @@ -151,8 +154,9 @@ struct ConvertSharedMemAllocOp : public OpRewritePattern { LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter &rewriter) const override { - if (!hasSharedMemoryAddressSpace(allocOp.getType())) + if (!hasSharedMemoryAddressSpace(allocOp.getType())) { return failure(); + } ArrayRef shape = allocOp.getType().getShape(); if (ShapedType::isDynamicShape(shape)) { return failure(); @@ -164,15 +168,16 @@ struct ConvertSharedMemAllocOp : public OpRewritePattern { } else { // If no alignment specified align at least to the size of an element. Type elType = allocOp.getType().getElementType(); - if (auto shapeType = dyn_cast(elType)) + if (auto shapeType = dyn_cast(elType)) { alignement = shapeType.getNumElements() * shapeType.getElementTypeBitWidth() / 8; - else if (elType.isIndex()) { + } else if (elType.isIndex()) { auto mod = allocOp->getParentOfType(); LowerToLLVMOptions options(mod.getContext(), DataLayout(mod)); alignement = options.getIndexBitwidth() / 8; - } else + } else { alignement = elType.getIntOrFloatBitWidth() / 8; + } } // In CUDA workgroup memory is represented by a global variable. 
MemRefType allocType = allocOp.getType(); @@ -262,8 +267,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { FunctionType fnType = funcOp.getFunctionType(); (void)fnType; - if (!funcOp.isPublic()) + if (!funcOp.isPublic()) { return failure(); + } // illegal FuncOp must have 0 inputs. assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0); @@ -296,8 +302,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { FailureOr> maybeBindingsInfo = analyzeSubspans(subspans, numBindings, getTypeConverter()); - if (failed(maybeBindingsInfo)) + if (failed(maybeBindingsInfo)) { return failure(); + } auto bindingsInfo = std::move(*maybeBindingsInfo); SmallVector llvmInputTypes; @@ -309,8 +316,9 @@ class ConvertFunc : public ConvertOpToLLVMPattern { // All the push constants are i32 and go at the end of the argument list. llvmInputTypes.resize(numBindings + numConstants, rewriter.getI32Type()); - if (!llvmInputTypes.empty()) + if (!llvmInputTypes.empty()) { signatureConverter.addInputs(llvmInputTypes); + } // Construct newFunc with all attributes except return type & symbol name. SmallVector funcAttrs; @@ -384,8 +392,9 @@ struct ConvertIREEBindingSubspanOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } assert(llvmFuncOp.getNumArguments() > 0); Location loc = op->getLoc(); @@ -489,8 +498,9 @@ struct ConvertIREEConstantOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. 
auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } assert(llvmFuncOp.getNumArguments() > 0); auto ireeConstantOp = cast(op); @@ -572,8 +582,9 @@ struct HALInterfaceWorkgroupOpsConverter final gpu::Dimension::z}; NewOpTy newOp = rewriter.replaceOpWithNewOp(op, op.getType(), dimAttr[index]); - if (IntegerAttr bound = op.getUpperBoundAttr()) + if (IntegerAttr bound = op.getUpperBoundAttr()) { newOp.setUpperBoundAttr(bound); + } return success(); } }; @@ -602,23 +613,26 @@ struct ConvertIREEUtilAssumeIntOp final ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); - if (!llvmFuncOp) + if (!llvmFuncOp) { return failure(); + } Location loc = op.getLoc(); auto updateConds = [&](std::optional &conds, Value cond) { - if (!conds) + if (!conds) { conds = cond; - else + } else { conds = LLVM::AndOp::create(rewriter, loc, *conds, cond); + } }; // Materialize the assumptions that aren't atteched directly to arguments // in order to account for the fact that i64 inputs get passed in as a pair // of i32 constants. 
for (auto [idx, mlirVal, llvmVal] : llvm::enumerate(op.getOperands(), adaptor.getOperands())) { - if (mlirVal.getDefiningOp()) + if (mlirVal.getDefiningOp()) { continue; + } std::optional conds; Type type = llvmVal.getType(); auto [min, max] = op.getUnionedUnsignedRange(idx); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index a9c25fd7d343..9688fc8597cf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -224,8 +224,9 @@ static LogicalResult setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, linalg::LinalgOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -303,15 +304,17 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Skip adding any virtual intrinsics since they are not tested for // convolutions. } - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } // TODO: Replace the below with algorithm described in // https://github.com/iree-org/iree/discussions/21506. 
@@ -429,9 +432,11 @@ debugPrintContractionInfo(StringRef label, unsigned numLoops, contractionDims.n, contractionDims.k}; std::string dimSymbols(numLoops, '*'); for (auto [idx, val] : llvm::enumerate(dimSymbols)) { - for (auto [letter, dim] : llvm::zip_equal(StringRef("bmnk"), dimVals)) - if (llvm::is_contained(dim, idx)) + for (auto [letter, dim] : llvm::zip_equal(StringRef("bmnk"), dimVals)) { + if (llvm::is_contained(dim, idx)) { val = letter; + } + } } DBGS() << "Contraction dims: " << llvm::interleaved_array(dimSymbols) << "\n"; DBGS() << label << ": " << llvm::interleaved_array(sizes) << "\n"; @@ -441,8 +446,9 @@ static LogicalResult setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, linalg::LinalgOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -538,14 +544,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Skip adding any virtual intrinsics since they are not tested for matmuls. 
} - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } GPUMMAHeuristicSeeds seeds; @@ -704,8 +712,9 @@ setAttentionPipelineAttributes(IREE::GPU::TargetAttr target, static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, IREE::LinalgExt::AttentionOp op) { - if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); @@ -786,8 +795,9 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( intrinsics.reserve(target.getWgp().getMma().size()); MLIRContext *context = op.getContext(); for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getSubgroupSize() != targetSubgroupSize) + if (mma.getSubgroupSize() != targetSubgroupSize) { continue; + } storeMmaInfo(mma, intrinsics); // Store info on virtual intrinsics based on current mma if any for (IREE::GPU::VirtualMMAIntrinsic virtualIntrinsic : @@ -798,8 +808,9 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig( } } - if (intrinsics.empty()) + if (intrinsics.empty()) { return failure(); + } // We assume that P uses the element type of V for input // and both matmuls have f32 as output. It is possible to use other element @@ -1344,8 +1355,9 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target, Operation *computeOp) { // We haven't properly plumbed through MMA op layouts and conversions for CUDA // to target NVIDIA GPUs. So disable the vector distribution pass for it. 
- if (!isROCmBackend(target)) + if (!isROCmBackend(target)) { return failure(); + } if (!clGPUEnableVectorDistribution) { LDBG() << "Vector Distribution not enabled, skipping..."; @@ -1409,8 +1421,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, staticNonUnitParallelDimCount += bounds[nDim] != 1 && ShapedType::isStatic(bounds[nDim]); } - if (staticNonUnitParallelDimCount <= 1) + if (staticNonUnitParallelDimCount <= 1) { return failure(); + } // Don't consider operations that don't have a broadcast, those should go // through reductions. @@ -1470,8 +1483,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, } std::optional subgroupSize = std::nullopt; - if (!subgroupSizes.empty()) + if (!subgroupSizes.empty()) { subgroupSize = subgroupSizes.front(); + } // For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes // for workgroup, thread, and reduction. @@ -1599,8 +1613,9 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target, int64_t tileK = config.tileSize[2]; // Since specialization doesn't work for K loop and peeling is not enabled yet // we pick a tileK size that is aligned on the K size. - if (ShapedType::isDynamic(sizeK)) + if (ShapedType::isDynamic(sizeK)) { tileK = 1; + } while (sizeK % tileK != 0) { tileK >>= 1; } @@ -1780,8 +1795,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, shape.back() % (workgroupSize[0] * vectorSize) != 0) { vectorSize /= 2; } - if (vectorSize == 1) // assume there is fastpath + slowpath + if (vectorSize == 1) { // assume there is fastpath + slowpath vectorSize = 4; + } int64_t problemSize = llvm::product_of(shape); if ((problemSize / (preferredSubgroupSize * vectorSize)) < 64) { vectorSize = 1; @@ -1795,8 +1811,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, int64_t id = 0; for (int64_t dim : llvm::reverse(shape)) { // Unit loops are already skipped. 
- if (dim == 1) + if (dim == 1) { continue; + } if (dim < flatWG) { skipInnerTiling++; workgroupSize[id] = dim; @@ -1806,8 +1823,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, } flatWG = flatWG / dim; id++; - if (flatWG <= 1 || id >= workgroupSize.size()) + if (flatWG <= 1 || id >= workgroupSize.size()) { break; + } } break; } @@ -1838,8 +1856,9 @@ static LogicalResult setRootDefaultConfig(IREE::GPU::TargetAttr target, workgroupTileSizes[depth - 1] = 0; skipInnerTiling--; id++; - if (id >= workgroupSize.size()) + if (id >= workgroupSize.size()) { break; + } continue; } workgroupTileSizes[depth - 1] = workgroupSize[id] * vectorSize; @@ -1880,12 +1899,14 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) { // TODO: Allow for matvec with fused dequantization. FailureOr dims = linalg::inferContractionDims(linalgOp); - if (failed(dims)) + if (failed(dims)) { return false; + } // TODO: Support batch matvec. - if (!dims->batch.empty()) + if (!dims->batch.empty()) { return false; + } if (dims->m.size() >= 2 || dims->n.size() >= 2 || !llvm::hasSingleElement(dims->k)) { @@ -2007,8 +2028,9 @@ static LogicalResult setArgmaxUkernelConfig( op.getReductionDims(reductionDims); // Currently Argmax UKernel only support 1 reduction dim. - if (reductionDims.size() != 1) + if (reductionDims.size() != 1) { return failure(); + } // Make sure reduction dimensions are static and innermost ones. SmallVector bounds = op.getStaticLoopRanges(); @@ -2082,14 +2104,16 @@ static bool distributeToOneDim(const int64_t inputDim, // Handle 4 elements per thread for the innermost dimension. We need // this for vectorized load. 
chosenTileSize = 4; - if (inputDim % (dim * chosenTileSize) != 0) + if (inputDim % (dim * chosenTileSize) != 0) { continue; + } } else { - for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) + for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) { if (inputDim % (dim * t) == 0) { chosenTileSize = t; break; } + } } if (chosenTileSize) { wgDimSize = dim; @@ -2192,8 +2216,9 @@ static LogicalResult setConvolutionConfig( // OC -> x if (!distributeToOneDim(oc, /*isInnerMostDim=*/true, residualThreads, residualTilingFactor, workgroupSize[0], - workgroupTileSizes[3])) + workgroupTileSizes[3])) { return failure(); + } // Deduce the configruation for the OW and OH dimension. Try to make them // even if possible given we typically have images with the same height @@ -2219,10 +2244,11 @@ static LogicalResult setConvolutionConfig( auto pipeline = CodeGenPipeline::LLVMGPUVectorize; TileSizesListType tileSizes; // Add reduction tile sizes. - if (isNCHW) + if (isNCHW) { workgroupTileSizes.append({4, 1, 1}); - else if (isNHWC) + } else if (isNHWC) { workgroupTileSizes.append({1, 1, 4}); + } tileSizes.push_back(workgroupTileSizes); // Tile along OH by size 1 to enable downsizing 2-D convolution to 1-D. 
@@ -2371,8 +2397,9 @@ static void propagateLoweringConfig(Operation *rootOperation, if (IREE::Codegen::LoweringConfigAttrInterface config = getLoweringConfig(rootOperation)) { for (auto op : computeOps) { - if (op == rootOperation) + if (op == rootOperation) { continue; + } setLoweringConfig(op, config); } } @@ -2383,8 +2410,9 @@ static void propagateLoweringConfig(Operation *rootOperation, //===----------------------------------------------------------------------===// LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) { IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) + if (!target) { return funcOp.emitError("missing GPU target in #hal.executable.target"); + } auto exportOp = getEntryPoint(funcOp); if (!getTranslationInfo(funcOp) && exportOp) { @@ -2507,8 +2535,9 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) { return success(); } - if (failed(setRootConfig(target, funcOp, rootOperation))) + if (failed(setRootConfig(target, funcOp, rootOperation))) { return funcOp.emitOpError("failed to set root config"); + } if (IREE::Codegen::TranslationInfoAttr translationInfo = getTranslationInfo(funcOp)) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp index 6428cf2bad99..a7841657b4f2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp @@ -23,8 +23,9 @@ struct LLVMGPUAssignConstantOrdinalsPass // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually MLIR or LLVM will inline them. 
@@ -33,8 +34,9 @@ struct LLVMGPUAssignConstantOrdinalsPass llvm::make_early_inc_range(moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp index 3164c15f210c..77a691ce3cdb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp @@ -66,8 +66,9 @@ struct LLVMGPUCastAddressSpaceFunctionPass final SymbolTable::lookupSymbolIn(moduleOp, callee)); if (fnDecl) { SmallVector callArgumentTypes; - for (auto op : newOperands) + for (auto op : newOperands) { callArgumentTypes.push_back(op.getType()); + } FunctionType functionType = rewriter.getFunctionType( callArgumentTypes, fnDecl->getResultTypes()); fnDecl.setType(functionType); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp index 6fa62b9d5e33..f09091de6ffe 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp @@ -71,8 +71,9 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() { FunctionOpInterface funcOp = getOperation(); IREE::Codegen::TranslationInfoAttr translationInfo = getTranslationInfo(funcOp); - if (!translationInfo) + if (!translationInfo) { return; + } std::optional maybePipeline = getFunctionOpInterfacePassManager(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp 
index 04be0ad56e92..804f4d88ca2f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp @@ -43,8 +43,9 @@ static LogicalResult verifyLoweringConfiguration( IREE::Codegen::TranslationInfoAttr translationInfo) { auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult { auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return success(); + } if (translationInfo.getDispatchLoweringPassPipeline() == IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp index 968879feacb2..546ce726be7b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp @@ -56,13 +56,15 @@ static void populateVectorUnrollPatterns(RewritePatternSet &patterns, bool useMmaSyncShape) { auto unrollOrder = [](Operation *op) -> std::optional> { auto contract = dyn_cast(op); - if (!contract) + if (!contract) { return std::nullopt; + } return gpuMmaUnrollOrder(contract); }; auto getNativeShape = [useMmaSyncShape](Operation *op) { - if (useMmaSyncShape) + if (useMmaSyncShape) { return getMmaNativeVectorSize(op); + } return getWmmaNativeVectorSize(op); }; vector::populateVectorUnrollPatterns( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp index d546b1f426b7..852238b5be91 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp @@ -115,13 +115,15 @@ calculateDistributedTileSize(ArrayRef numElements, OpBuilder &builder, unsigned 
idIdx = 0; std::reverse(distributedDim.begin(), distributedDim.end()); for (unsigned depth : partitionedLoops) { - if (depth >= blockTileSize.size()) + if (depth >= blockTileSize.size()) { continue; + } tileSizesVal[depth] = arith::ConstantIndexOp::create( builder, operation->getLoc(), llvm::divideCeil(blockTileSize[depth], distributedDim[idIdx++])); - if (idIdx == kNumMaxParallelDims) + if (idIdx == kNumMaxParallelDims) { break; + } } return tileSizesVal; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp index 95bca2146d31..d69e8fe68a10 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp @@ -66,8 +66,9 @@ struct PromoteContractOperands final Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v, Type dstElementType) const { Type elementType = getElementTypeOrSelf(v.getType()); - if (elementType == dstElementType) + if (elementType == dstElementType) { return v; + } // vector.contract only allows extension on operands. assert(elementType.getIntOrFloatBitWidth() <= @@ -75,11 +76,13 @@ struct PromoteContractOperands final "vector.contract does not allow truncation of operands"); Type promotedType = dstElementType; - if (auto vecType = dyn_cast(v.getType())) + if (auto vecType = dyn_cast(v.getType())) { promotedType = vecType.clone(promotedType); + } - if (isa(dstElementType)) + if (isa(dstElementType)) { return arith::ExtFOp::create(rewriter, loc, promotedType, v); + } // For integer types, vector.contract only supports signless integer types // and promotion happens via sign extension. 
return arith::ExtSIOp::create(rewriter, loc, promotedType, v); @@ -409,8 +412,9 @@ struct ContractToChainFMA final : OpRewritePattern { static std::optional getDimPosition(AffineMap map, unsigned dim) { for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (map.getDimPosition(i) == dim) + if (map.getDimPosition(i) == dim) { return i; + } } return std::nullopt; } @@ -419,8 +423,9 @@ struct ContractToChainFMA final : OpRewritePattern { ArrayAttr iteratorTypes) { SmallVector dimsIdx; for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (vector::isReductionIterator(iteratorTypes[map.getDimPosition(i)])) + if (vector::isReductionIterator(iteratorTypes[map.getDimPosition(i)])) { dimsIdx.push_back(i); + } } return dimsIdx; } @@ -506,8 +511,10 @@ struct UnrollElementwiseOps final : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) + if (!OpTrait::hasElementwiseMappableTraits(op) || + op->getNumResults() != 1) { return failure(); + } Location loc = op->getLoc(); VectorType dstVecTy = dyn_cast(op->getResult(0).getType()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 2f97bcabe869..98d047592d4f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -171,8 +171,9 @@ static LogicalResult gpuCopyFn(OpBuilder &builder, Location loc, Value from, if (hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } Operation *copy = memref::CopyOp::create(builder, loc, from, to); if (needsBarrier) { setMarker(copy, getCopyToWorkgroupMemoryMarker()); @@ -188,14 +189,16 @@ static LogicalResult canReorderWorkgroups(FunctionOpInterface funcOp) { if (!target) { return failure(); } 
- if (target.getBackend() != "rocm") + if (target.getBackend() != "rocm") { return success(); + } // Workgroup reordering on ROCm currently requires all workgrup counts to be // static. SmallVector workgroupCounts = getStaticNumWorkgroups(funcOp); - if (llvm::any_of(workgroupCounts, ShapedType::isDynamic)) + if (llvm::any_of(workgroupCounts, ShapedType::isDynamic)) { return failure(); + } // This is further restricted to 2D+ grids as we reorder along the X and Y // workgroup IDs. @@ -690,8 +693,9 @@ static LogicalResult gpuVectorCopyFn(OpBuilder &builder, Location loc, if (hasSharedMemoryAddressSpace(cast(to.getType()))) { needsBarrier = true; } - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } VectorType vectorType = VectorType::get(fromType.getShape(), fromType.getElementType()); Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp index 75fcc0f758fa..4520aa619026 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp @@ -102,16 +102,19 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp, // attribute. 
FailureOr chipset = getChipsetVersion(builder.getContext(), targetAttr); - if (failed(chipset)) + if (failed(chipset)) { return variantOp.emitError() << "failed to parse amdgpu chipset"; + } - if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) + if (chipset->majorVersion != 9 || *chipset < amdgpu::Chipset(9, 4, 0)) { return success(); + } auto inRegAttrName = builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName()); - for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { funcOp.setArgAttr(i, inRegAttrName, unitAttr); + } return success(); } @@ -142,8 +145,9 @@ struct ROCDLAnnotateKernelForTranslationPass final // Un-exported functions are library functions or otherwise not kernels, so // don't need these annotations. - if (!exportOp) + if (!exportOp) { return; + } if (failed(annotateKernelForTranslation(funcOp, variantOp, exportOp))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp index 9bc90f177f54..9b69bf72063c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLBufferInstructionsOptimization.cpp @@ -82,13 +82,15 @@ void simplifyMaskOps(RewriterBase &rewriter, vector::CreateMaskOp maskOp) { for (Operation *user : maskOp.getResult().getUsers()) { auto readOp = dyn_cast(user); // Only TransferReadOps are supported. - if (!readOp) + if (!readOp) { continue; + } auto sourceType = dyn_cast(readOp.getBase().getType()); // only supported for fat raw buffers. - if (!sourceType || !hasAMDGPUFatRawBufferAddressSpace(sourceType)) + if (!sourceType || !hasAMDGPUFatRawBufferAddressSpace(sourceType)) { continue; + } SmallVector inBounds = readOp.getInBoundsValues(); // Only supported for reads that are fully in_bounds. 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp index 7d9f8df3488e..584d433bddff 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLConfigureBufferInstructions.cpp @@ -46,8 +46,9 @@ static Value stripIntegerCasts(Value val) { /// loads, which is a conservative approximatino for workgroup-uniformity that /// can be made more extensive if needed. static bool isDefinitelyWorkgroupUniform(Value arg) { - if (!arg) + if (!arg) { return true; + } SetVector dependencies; BackwardSliceOptions opts; arg = stripIntegerCasts(arg); @@ -60,8 +61,9 @@ static bool isDefinitelyWorkgroupUniform(Value arg) { getBackwardSlice(arg, &dependencies, opts); assert(result.succeeded()); return llvm::all_of(dependencies, [&](Operation *op) { - if (matchPattern(op, m_Constant())) + if (matchPattern(op, m_Constant())) { return true; + } if (isa(op)) { return true; } @@ -116,13 +118,15 @@ struct ROCDLConfigureBufferInstructionsPass final : impl::ROCDLConfigureBufferInstructionsPassBase< ROCDLConfigureBufferInstructionsPass> { void runOnOperation() override { - if (!clROCDLlEnableBufferInstructions) + if (!clROCDLlEnableBufferInstructions) { return; + } mlir::FunctionOpInterface funcOp = getOperation(); // Is this really the best way to skip this pass on non-rocdl targets? IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target || !target.isAMD()) + if (!target || !target.isAMD()) { return; + } // Initialize the DataFlowSolver with IntegerRangeAnalysis. 
DataFlowSolver solver; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 166b61e87f12..e6ea60f915f2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -96,8 +96,9 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne( mlir::transform::gpu::mapNestedForallToThreadsImpl( rewriter, transformOp, target, getWorkgroupDims(), getSubgroupSize(), getSyncAfterDistribution()); - if (!diag.succeeded()) + if (!diag.succeeded()) { return diag; + } IREE::Codegen::TranslationInfoAttr updatedTranslationInfo = IREE::Codegen::TranslationInfoAttr::get( @@ -161,8 +162,9 @@ replaceAllUsesOfLaneWithin(RewriterBase &b, Value laneId = executeOp.getLaneid(); bool applied = false; for (Operation *user : llvm::make_early_inc_range(laneId.getUsers())) { - if (!executeOp->isProperAncestor(user)) + if (!executeOp->isProperAncestor(user)) { continue; + } b.startOpModification(user); user->replaceUsesOfWith(laneId, zero); b.finalizeOpModification(user); @@ -179,47 +181,61 @@ replaceAllUsesOfLaneWithin(RewriterBase &b, static FailureOr isThreadIdxxZeroPredicate(scf::IfOp ifOp) { if (!ifOp || ifOp.getNumResults() > 0 || ifOp.getThenRegion().getBlocks().size() != 1 || - !ifOp.getElseRegion().empty()) + !ifOp.getElseRegion().empty()) { return failure(); + } auto pred = ifOp.getCondition().getDefiningOp(); - if (!pred) + if (!pred) { return failure(); + } auto EQ = arith::CmpIPredicate::eq; auto SLT = arith::CmpIPredicate::slt; auto SLE = arith::CmpIPredicate::sle; auto ULT = arith::CmpIPredicate::ult; auto ULE = arith::CmpIPredicate::ule; if (auto threadIdOp = pred.getLhs().getDefiningOp()) { - if (threadIdOp.getDimension() != gpu::Dimension::x) + if (threadIdOp.getDimension() != gpu::Dimension::x) { return 
failure(); - if (pred.getPredicate() == EQ && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == EQ && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == SLE && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == SLE && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == ULE && isZeroInteger(pred.getRhs())) + } + if (pred.getPredicate() == ULE && isZeroInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == SLT && isOneInteger(pred.getRhs())) + } + if (pred.getPredicate() == SLT && isOneInteger(pred.getRhs())) { return threadIdOp; - if (pred.getPredicate() == ULT && isOneInteger(pred.getRhs())) + } + if (pred.getPredicate() == ULT && isOneInteger(pred.getRhs())) { return threadIdOp; + } } auto SGT = arith::CmpIPredicate::sgt; auto SGE = arith::CmpIPredicate::sge; auto UGT = arith::CmpIPredicate::ugt; auto UGE = arith::CmpIPredicate::uge; if (auto threadIdOp = pred.getRhs().getDefiningOp()) { - if (threadIdOp.getDimension() != gpu::Dimension::x) + if (threadIdOp.getDimension() != gpu::Dimension::x) { return failure(); - if (pred.getPredicate() == EQ && isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == EQ && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == SGE && isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == SGE && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == UGE && isZeroInteger(pred.getLhs())) + } + if (pred.getPredicate() == UGE && isZeroInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == SGT && isOneInteger(pred.getLhs())) + } + if (pred.getPredicate() == SGT && isOneInteger(pred.getLhs())) { return threadIdOp; - if (pred.getPredicate() == UGT && isOneInteger(pred.getLhs())) + } + if (pred.getPredicate() == UGT && isOneInteger(pred.getLhs())) { return threadIdOp; + } } return failure(); } @@ -235,8 +251,9 @@ 
rewriteScfIfAsWarpExecuteOnLane0(RewriterBase &rewriter, Location loc, // Bail if cond is not `if (threadIdx.x == 0)`. FailureOr maybeThreadIdxxOp = isThreadIdxxZeroPredicate(ifOp); - if (failed(maybeThreadIdxxOp)) + if (failed(maybeThreadIdxxOp)) { return failure(); + } // All the code below will be executed on a single warp given a // fixed (threadIdxy, threadIdxz). Note, we reuse @@ -384,8 +401,9 @@ static OpOperand *getWarpResult(gpu::WarpExecuteOnLane0Op warpOp, Value yieldValues = yieldOperand.get(); Operation *definedOp = yieldValues.getDefiningOp(); if (definedOp && fn(definedOp)) { - if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) + if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { return &yieldOperand; + } } } return {}; @@ -414,15 +432,17 @@ struct WarpOpLoad : public OpRewritePattern { LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); - if (!operand) + if (!operand) { return failure(); + } auto load = operand->get().getDefiningOp(); unsigned operandIndex = operand->getOperandNumber(); Value distributedVal = warpOp.getResult(operandIndex); auto indices = llvm::to_vector_of(load.getIndices()); - if (!indices.empty()) + if (!indices.empty()) { return failure(); + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(warpOp); @@ -458,17 +478,20 @@ struct HoistSharedMemoryAlloc : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(memref::AllocOp alloc, PatternRewriter &rewriter) const override { - if (!iree_compiler::hasSharedMemoryAddressSpace(alloc.getType())) + if (!iree_compiler::hasSharedMemoryAddressSpace(alloc.getType())) { return failure(); + } auto warpParent = alloc->getParentOfType(); - if (!warpParent) + if (!warpParent) { return failure(); + } alloc->moveBefore(warpParent); // Conservatively move the dealloc after the warpOp. 
This may // extend the liverange of the allocation but is always correct. for (Operation *user : alloc->getUsers()) { - if (isa(user)) + if (isa(user)) { user->moveAfter(warpParent); + } } return success(); } @@ -488,8 +511,9 @@ static void populateMultiReductionLoweringPatterns(Operation *target, static AffineMap simpleDistributionFunction(Value val) { AffineMap map = AffineMap::get(val.getContext()); auto vecType = dyn_cast(val.getType()); - if (!vecType) + if (!vecType) { return map; + } // Create a map (d0, d1) -> (d1) to distribute along the inner // dimension. Once we support n-d distribution we can add more // complex cases. @@ -673,15 +697,17 @@ transform_dialect::VectorToMMAConversionOp::applyToOne( auto diag = DiagnosedSilenceableFailure::success(); if (getUseWmma()) { - if (failed(convertVectorToMMAOps(rewriter, target))) + if (failed(convertVectorToMMAOps(rewriter, target))) { return mlir::emitDefiniteFailure( target, "vector to wmma patterns failed to apply"); + } return listener.checkAndResetError(); } - if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) + if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) { return mlir::emitDefiniteFailure(target, "vector to mma patterns failed to apply"); + } DEBUG_WITH_TYPE(DEBUG_VECTOR_TO_MMA, { @@ -694,10 +720,11 @@ transform_dialect::VectorToMMAConversionOp::applyToOne( RewritePatternSet f32ToTF32patterns(funcOp.getContext()); nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32patterns, nvgpu::MmaSyncF32Lowering::TF32); - if (failed( - applyPatternsGreedily(funcOp, std::move(f32ToTF32patterns), config))) + if (failed(applyPatternsGreedily(funcOp, std::move(f32ToTF32patterns), + config))) { return mlir::emitDefiniteFailure( target, "vector to mma F32ToTF32 patterns failed to apply"); + } return listener.checkAndResetError(); } @@ -826,8 +853,9 @@ static bool isKnownNoEffectsOpWithoutInterface(Operation *op) { /// Returns `true` if the op is defines the parallel region that is 
subject to /// barrier synchronization. static bool isParallelRegionBoundary(Operation *op) { - if (op->hasAttr("__parallel_region_boundary_for_test")) + if (op->hasAttr("__parallel_region_boundary_for_test")) { return true; + } // We consider functions inside executable variants . return isa(op); @@ -871,12 +899,14 @@ collectEffects(Operation *op, bool ignoreBarriers = true) { // Skip over barriers to avoid infinite recursion (those barriers would ask // this barrier again). - if (ignoreBarriers && isa(op)) + if (ignoreBarriers && isa(op)) { return true; + } // Skip over ops that we know have no effects. - if (isKnownNoEffectsOpWithoutInterface(op)) + if (isKnownNoEffectsOpWithoutInterface(op)) { return true; + } // Collect effect instances the operation. Note that the implementation of // getEffects erases all effect instances that have the type other than the @@ -891,9 +921,11 @@ collectEffects(Operation *op, if (op->hasTrait()) { for (auto ®ion : op->getRegions()) { for (auto &block : region) { - for (auto &innerOp : block) - if (!collectEffects(&innerOp, effects, ignoreBarriers)) + for (auto &innerOp : block) { + if (!collectEffects(&innerOp, effects, ignoreBarriers)) { return false; + } + } } } return true; @@ -915,8 +947,9 @@ static bool getEffectsBefore(Operation *op, SmallVectorImpl &effects, bool stopAtBarrier) { - if (!op->getBlock()) + if (!op->getBlock()) { return true; + } // If there is a non-structured control flow, bail. Region *region = op->getBlock()->getParent(); @@ -930,23 +963,27 @@ getEffectsBefore(Operation *op, for (Operation *it = op->getPrevNode(); it != nullptr; it = it->getPrevNode()) { if (isa(it)) { - if (stopAtBarrier) + if (stopAtBarrier) { return true; - else + } else { continue; + } } - if (!collectEffects(it, effects)) + if (!collectEffects(it, effects)) { return false; + } } } // Stop if reached the parallel region boundary. 
- if (isParallelRegionBoundary(op->getParentOp())) + if (isParallelRegionBoundary(op->getParentOp())) { return true; + } // Otherwise, keep collecting above the parent operation. - if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier)) + if (!getEffectsBefore(op->getParentOp(), effects, stopAtBarrier)) { return false; + } // If the op is loop-like, collect effects from the trailing operations until // we hit a barrier because they can executed before the current operation by @@ -971,16 +1008,18 @@ getEffectsBefore(Operation *op, // If the parent operation is not guaranteed to execute its (single-block) // region once, walk the block. bool conservative = false; - if (!hasSingleExecutionBody(op->getParentOp())) + if (!hasSingleExecutionBody(op->getParentOp())) { op->getParentOp()->walk([&](Operation *in) { - if (conservative) + if (conservative) { return WalkResult::interrupt(); + } if (!collectEffects(in, effects)) { conservative = true; return WalkResult::interrupt(); } return WalkResult::advance(); }); + } return !conservative; } @@ -995,8 +1034,9 @@ static bool getEffectsAfter(Operation *op, SmallVectorImpl &effects, bool stopAtBarrier) { - if (!op->getBlock()) + if (!op->getBlock()) { return true; + } // If there is a non-structured control flow, bail. Region *region = op->getBlock()->getParent(); @@ -1006,25 +1046,30 @@ getEffectsAfter(Operation *op, } // Collect all effects after the op. - if (op != &op->getBlock()->back()) + if (op != &op->getBlock()->back()) { for (Operation *it = op->getNextNode(); it != nullptr; it = it->getNextNode()) { if (isa(it)) { - if (stopAtBarrier) + if (stopAtBarrier) { return true; + } continue; } - if (!collectEffects(it, effects)) + if (!collectEffects(it, effects)) { return false; + } } + } // Stop if reached the parallel region boundary. - if (isParallelRegionBoundary(op->getParentOp())) + if (isParallelRegionBoundary(op->getParentOp())) { return true; + } // Otherwise, keep collecting below the parent operation. 
- if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier)) + if (!getEffectsAfter(op->getParentOp(), effects, stopAtBarrier)) { return false; + } // If the op is loop-like, collect effects from the leading operations until // we hit a barrier because they can executed after the current operation by @@ -1041,8 +1086,9 @@ getEffectsAfter(Operation *op, // operation `op2` at iteration `i-1` and the side effects must be ordered // appropriately. if (isSequentialLoopLike(op->getParentOp())) { - if (isa(op->getBlock()->front())) + if (isa(op->getBlock()->front())) { return true; + } bool exact = collectEffects(&op->getBlock()->front(), effects); return getEffectsAfter(&op->getBlock()->front(), effects, @@ -1053,16 +1099,18 @@ getEffectsAfter(Operation *op, // If the parent operation is not guaranteed to execute its (single-block) // region once, walk the block. bool conservative = false; - if (!hasSingleExecutionBody(op->getParentOp())) + if (!hasSingleExecutionBody(op->getParentOp())) { op->getParentOp()->walk([&](Operation *in) { - if (conservative) + if (conservative) { return WalkResult::interrupt(); + } if (!collectEffects(in, effects)) { conservative = true; return WalkResult::interrupt(); } return WalkResult::advance(); }); + } return !conservative; } @@ -1071,8 +1119,9 @@ getEffectsAfter(Operation *op, static Value getBase(Value v) { while (true) { Operation *definingOp = v.getDefiningOp(); - if (!definingOp) + if (!definingOp) { break; + } bool shouldContinue = TypeSwitch(v.getDefiningOp()) @@ -1090,8 +1139,9 @@ static Value getBase(Value v) { return true; }) .Default([](Operation *) { return false; }); - if (!shouldContinue) + if (!shouldContinue) { break; + } } return v; } @@ -1163,8 +1213,9 @@ static bool maybeCaptured(Value v) { } std::optional knownCaptureStatus = getKnownCapturingStatus(user, v); - if (!knownCaptureStatus || *knownCaptureStatus) + if (!knownCaptureStatus || *knownCaptureStatus) { return true; + } } } @@ -1227,20 +1278,24 @@ static 
bool mayAlias(Value first, Value second) { // Non-equivalent distinct bases and globals cannot alias. At this point, we // have already filtered out based on values being equal and global name being // equal. - if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1])) + if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1])) { return false; + } bool isArg[] = {isFunctionArgument(first), isFunctionArgument(second)}; // Distinct bases (allocations) cannot have been passed as an argument. - if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0])) + if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0])) { return false; + } // Non-captured base distinct values cannot conflict with another base value. - if (isDistinct[0] && !maybeCaptured(first)) + if (isDistinct[0] && !maybeCaptured(first)) { return false; - if (isDistinct[1] && !maybeCaptured(second)) + } + if (isDistinct[1] && !maybeCaptured(second)) { return false; + } // Otherwise, conservatively assume aliasing. DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n"); @@ -1263,8 +1318,9 @@ static bool mayAlias(MemoryEffects::EffectInstance a, Value v2) { /// cannot alias. static bool mayAlias(MemoryEffects::EffectInstance a, MemoryEffects::EffectInstance b) { - if (a.getResource()->getResourceID() != b.getResource()->getResourceID()) + if (a.getResource()->getResourceID() != b.getResource()->getResourceID()) { return false; + } if (Value v2 = b.getValue()) { return mayAlias(a, v2); } else if (Value v = a.getValue()) { @@ -1287,8 +1343,9 @@ haveConflictingEffects(ArrayRef beforeEffects, for (const MemoryEffects::EffectInstance &before : beforeEffects) { for (const MemoryEffects::EffectInstance &after : afterEffects) { // If cannot alias, definitely no conflict. - if (!mayAlias(before, after)) + if (!mayAlias(before, after)) { continue; + } // Read/read is not a conflict. 
if (isa(before.getEffect()) && @@ -1313,8 +1370,9 @@ haveConflictingEffects(ArrayRef beforeEffects, // conflicts. // 2. either the program is ill-formed and we are in undefined behavior // territory. - if (isa(before.getEffect())) + if (isa(before.getEffect())) { continue; + } // Other kinds of effects create a conflict, e.g. read-after-write. LLVM_DEBUG( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp index 14f462c45f17..2c476276990b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp @@ -83,8 +83,9 @@ struct MaskResult { }; static MaskResult getMask(Operation *op) { auto transferRead = dyn_cast(op); - if (!transferRead || !transferRead.getMask()) + if (!transferRead || !transferRead.getMask()) { return MaskResult{}; + } vector::ExtractOp maybeExtractOp = transferRead.getMask().getDefiningOp(); auto maskOp = @@ -111,8 +112,9 @@ static MaskResult getMask(Operation *op) { static Value getMaskValue(RewriterBase &rewriter, Operation *op) { MaskResult maskResult = getMask(op); - if (!maskResult.maskOp) + if (!maskResult.maskOp) { return Value(); + } Value count = maskResult.maskOp->getOperands().back(); vector::ExtractOp maybeExtractOp = maskResult.maybeExtractOp; if (maybeExtractOp) { @@ -142,14 +144,18 @@ static Value getValueStored(Operation *writeOp) { } static Operation::operand_range getIndices(Operation *op) { - if (auto vectorReadOp = dyn_cast(op)) + if (auto vectorReadOp = dyn_cast(op)) { return vectorReadOp.getIndices(); - if (auto vectorStoreOp = dyn_cast(op)) + } + if (auto vectorStoreOp = dyn_cast(op)) { return vectorStoreOp.getIndices(); - if (auto transferReadOp = dyn_cast(op)) + } + if (auto transferReadOp = dyn_cast(op)) { return transferReadOp.getIndices(); - if (auto transferWriteOp = dyn_cast(op)) + } + if (auto transferWriteOp = dyn_cast(op)) { return 
transferWriteOp.getIndices(); + } llvm_unreachable("unsupported op type"); } @@ -196,8 +202,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, llvm::SmallSetVector copyToSharedMem; // Look for all the copy that can be converted to async copy ops. funcOp.walk([&](Operation *writeOp) { - if (!isContiguousStore(writeOp)) + if (!isContiguousStore(writeOp)) { return WalkResult::advance(); + } LDBG() << "--candidate writeOp: " << *writeOp; Value vectorVal = getValueStored(writeOp); if (cast(vectorVal.getType()).getRank() != 1) { @@ -242,8 +249,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, if (!resultsInSupportedAsyncCopy(cast(loadBase.getType()), getIndices(readOp), vecType) || !resultsInSupportedAsyncCopy(cast(storeBase.getType()), - getIndices(writeOp), vecType)) + getIndices(writeOp), vecType)) { return WalkResult::advance(); + } LDBG() << "--writeOp can be made async -> SUCCESS"; copyToSharedMem.insert(writeOp); @@ -263,8 +271,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, // Ignore ops without side effects auto memInterface = dyn_cast(nextNode); if (memInterface && memInterface.hasNoEffect() && - !nextNode->hasTrait()) + !nextNode->hasTrait()) { continue; + } // ignore read from a different address space. if (isa(nextNode)) { Operation *readOp = nextNode; @@ -315,8 +324,9 @@ void createAsyncGroups(RewriterBase &rewriter, mlir::FunctionOpInterface funcOp, nvgpu::DeviceAsyncWaitOp::create(rewriter, funcOp.getLoc(), groupToken, nullptr); // Clean up old stores. 
- for (Operation *writeOp : group) + for (Operation *writeOp : group) { rewriter.eraseOp(writeOp); + } } } @@ -360,8 +370,9 @@ void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc, needBarrier = true; } else { for (Operation &op : entryBlock->getOperations()) { - if (&op == alloc) + if (&op == alloc) { break; + } if (op.getNumRegions() != 0) { needBarrier = true; break; @@ -372,8 +383,9 @@ void addBarrier(mlir::FunctionOpInterface funcOp, Operation *alloc, } } } - if (!needBarrier) + if (!needBarrier) { return; + } OpBuilder builder(alloc); // TODO: make it a option if needed. if (hasAsyncCopies) { @@ -400,8 +412,9 @@ void packSharedMemoryAlloc(mlir::FunctionOpInterface funcOp) { SmallVector aliasGroups; analyseAllocsForPacking(funcOp, allocs, aliasGroups); // If there is 1 or less alias group there is nothing to do. - if (aliasGroups.size() <= 1) + if (aliasGroups.size() <= 1) { return; + } // Pack all the allocations into one i8 alloc. // We may need to add extra barriers to make sure we are done writting or diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp index 9f831452610f..fcae85d1525b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp @@ -277,14 +277,17 @@ static LogicalResult classifyOperationsIntoStages( LDBG() << "\n=== Final Stage Classification ==="; LDBG() << "--- Read Stage (" << result.readStage.size() << " ops) ---"; - for (Operation *op : result.readStage) + for (Operation *op : result.readStage) { LDBG() << *op; + } LDBG() << "--- Write Stage (" << result.writeStage.size() << " ops) ---"; - for (Operation *op : result.writeStage) + for (Operation *op : result.writeStage) { LDBG() << *op; + } LDBG() << "--- Compute Stage (" << result.computeStage.size() << " ops) ---"; - for 
(Operation *op : result.computeStage) + for (Operation *op : result.computeStage) { LDBG() << *op; + } return success(); } @@ -365,28 +368,35 @@ populateOpToStageMap(const StageClassification &stages, scf::ForOp forOp, unsigned numStages, llvm::DenseMap &opToStage) { auto assignOp = [&](Operation *op, unsigned stage) { - if (!op || isa(op)) + if (!op || isa(op)) { return; + } opToStage[op] = stage; }; if (numStages == 2) { // Two-stage pipelining: read+write in stage 0, compute in stage 1. - for (Operation *op : stages.readStage) + for (Operation *op : stages.readStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.writeStage) + } + for (Operation *op : stages.writeStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.computeStage) + } + for (Operation *op : stages.computeStage) { assignOp(op, /*stage=*/1); + } } else { // Three-stage pipelining: read in stage 0, write in stage 1, compute in // stage 2. - for (Operation *op : stages.readStage) + for (Operation *op : stages.readStage) { assignOp(op, /*stage=*/0); - for (Operation *op : stages.writeStage) + } + for (Operation *op : stages.writeStage) { assignOp(op, /*stage=*/1); - for (Operation *op : stages.computeStage) + } + for (Operation *op : stages.computeStage) { assignOp(op, /*stage=*/2); + } } } @@ -513,8 +523,9 @@ invokePipelineForLoop(scf::ForOp forOp, const scf::PipeliningOption &options) { // Helper to check for shared memory. 
static bool hasSharedMemory(Value val) { auto memrefType = dyn_cast(val.getType()); - if (!memrefType) + if (!memrefType) { return false; + } auto addrSpace = dyn_cast_if_present(memrefType.getMemorySpace()); return addrSpace && addrSpace.getValue() == gpu::AddressSpace::Workgroup; @@ -587,10 +598,12 @@ static SharedBarrierState insertBarriersInRange(RewriterBase &rewriter, state.needBarrierBeforeWrite = false; } - if (hasSharedRead) + if (hasSharedRead) { state.needBarrierBeforeWrite = true; - if (hasSharedWrite) + } + if (hasSharedWrite) { state.needBarrierBeforeRead = true; + } } return state; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp index 743718c3d89b..fcd725f7d171 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp @@ -34,8 +34,9 @@ static LogicalResult setAMDMatmulConfig(linalg::LinalgOp op, if (succeeded(setCooperativeMatrixConfig( target, op, AMDNumSubgroupsPerWorkgroup, AMDNumMNTilesPerSubgroup, AMDCoopMatrixSoftwarePipelineDepth, - AMDCoopMatrixSoftwarePipelineStoreStage))) + AMDCoopMatrixSoftwarePipelineStoreStage))) { return success(); + } int subgroupSize = target.getPreferredSubgroupSize(); const std::array workgroupXY = {subgroupSize / 2, 8}; @@ -69,16 +70,18 @@ LogicalResult setAMDCodeGenConfig(IREE::GPU::TargetAttr target, int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAMDMatmulConfig(linalgOp, target); + } } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. 
auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; bool hasPaddedInput = convOp.image().getDefiningOp(); const int bestTilingFactor = (hasPaddedInput ? 16 : 32) * multipler; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp index 5921c4d7612d..b99b0d8b7f22 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp @@ -40,24 +40,28 @@ LogicalResult setAdrenoCodeGenConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { int subgroupSize = target.getPreferredSubgroupSize(); - if (!isa(rootOp)) + if (!isa(rootOp)) { return failure(); + } auto linalgOp = cast(rootOp); - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAdrenoMatmulConfig(linalgOp, target); + } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const int bestTilingFactor = (convDimsOrFailure->depth.empty() ? 
32 : 16) * multipler; return setConvOpConfig(cast(rootOp), subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp index 1977157cca8c..091e00aad644 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp @@ -40,16 +40,18 @@ LogicalResult setAppleCodeGenConfig(IREE::GPU::TargetAttr target, int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setAppleMatmulConfig(linalgOp, target); + } } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; const int bestTilingFactor = 16 * multipler; return setConvOpConfig(cast(rootOp), subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index 4b012027be90..e296200c1f53 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -101,8 +101,9 @@ createResourceVariable(Location loc, const SubspanResourceInfo &resource, llvm::formatv("__resource_var_{}_{}_", resource.set, resource.binding); variable = spirv::GlobalVariableOp::create( builder, loc, globalVariableType, name, resource.set, resource.binding); - if (resource.aliased) + if (resource.aliased) { variable->setAttr("aliased", builder.getUnitAttr()); + } } else { std::string name = llvm::formatv("__resource_var_indirect_{}_", resource.set); @@ -543,8 +544,9 @@ class ConvertToSPIRVPass final LogicalResult initializeOptions( StringRef options, 
function_ref errorHandler) override { - if (failed(Pass::initializeOptions(options, errorHandler))) + if (failed(Pass::initializeOptions(options, errorHandler))) { return failure(); + } indexBits = indexBitsOption; return success(); } @@ -561,17 +563,20 @@ void ConvertToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } bool useIndirectBindings = usesIndirectBindingsAttr(moduleOp); for (auto funcOp : moduleOp.getOps()) { auto exportOp = getEntryPoint(funcOp); - if (!exportOp) + if (!exportOp) { continue; - if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) + } + if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) { continue; + } std::optional workgroupSize = exportOp->getWorkgroupSize(); if (!workgroupSize) { exportOp->emitOpError( @@ -757,8 +762,9 @@ void ConvertToSPIRVPass::runOnOperation() { SmallVector functions; for (auto fn : moduleOp.getOps()) { - if (!fn.isPublic()) + if (!fn.isPublic()) { continue; + } functions.push_back(fn); } @@ -770,8 +776,9 @@ void ConvertToSPIRVPass::runOnOperation() { } auto addressingModel = spirv::AddressingModel::Logical; - if (useIndirectBindings) + if (useIndirectBindings) { addressingModel = spirv::AddressingModel::PhysicalStorageBuffer64; + } // Collect all SPIR-V ops into a spirv.module. OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); @@ -781,10 +788,12 @@ void ConvertToSPIRVPass::runOnOperation() { Dialect *spvDialect = spvModule->getDialect(); for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) { // Skip the newly created spirv.module itself. 
- if (&op == spvModule) + if (&op == spvModule) { continue; - if (op.getDialect() == spvDialect) + } + if (op.getDialect() == spvDialect) { op.moveBefore(body, body->end()); + } } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index a4289924ff6f..d0ea809f68e0 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -54,8 +54,9 @@ using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline; // Check if the given linalg op is fused with another op that may result // in too much shared memory usage. static bool fusedOpMayUseExtraSharedMemory(linalg::LinalgOp matmul) { - if (matmul->getNumResults() != 1) + if (matmul->getNumResults() != 1) { return true; + } auto entryPoint = matmul->getParentOfType(); @@ -105,14 +106,16 @@ static bool tileConvOneDim(const int64_t inputDim, const bool isInnerMostDim, // Handle `vectorSize` elements per thread for the innermost dimension. // We need this for the best utilization of memory. chosenTileSize = vectorSize; - if (inputDim % (dim * chosenTileSize) != 0) + if (inputDim % (dim * chosenTileSize) != 0) { continue; + } } else { - for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) + for (int64_t t = residualTilingFactor; t >= 1; t >>= 1) { if (inputDim % (dim * t) == 0) { chosenTileSize = t; break; } + } } if (chosenTileSize) { wgDimSize = dim; @@ -168,12 +171,14 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // Restrict to pure 4-D input/output shapes for now. This excludes convolution // ops with 1- or 3-D window sizes. It also excludes 2-D-window convolution // ops like `linalg.depthwise_conv_2d_nhwc_hwcm`. 
- if (inputShape.size() != 4 || outputShape.size() != 4) + if (inputShape.size() != 4 || outputShape.size() != 4) { return failure(); + } auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure; LLVM_DEBUG(llvm::dbgs() << "conv: " << linalgOp << "\n" << "conv batch dim: " @@ -231,8 +236,9 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // We use `vectorSize` as the tile size along IC dimension. If smaller than // 4, it will be unrolled into size 1. - if (ic && !(*ic % vectorSize == 0 || *ic < 4)) + if (ic && !(*ic % vectorSize == 0 || *ic < 4)) { return failure(); + } // The core idea is to distribute the convolution dimensions to the workgroup // Z/Y/X dimensions, with each thread in a workgroup handling multiple vector @@ -263,8 +269,9 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp, // OC -> x if (!tileConvOneDim(oc, /*isInnerMostDim=*/true, vectorSize, residualThreads, residualTilingFactor, workgroupSize[0], - workgroupTileSizes[3])) + workgroupTileSizes[3])) { return failure(); + } // Deduce the configruation for the OW and OH dimension. Try to make them // even if possible given we typically have images with the same height @@ -362,18 +369,21 @@ std::tuple getMatmulBMNKIndex(linalg::LinalgOp op, } else if (inLHS) { // For cases where we have two parallel dimensions only accessed by // the LHS, treat the outer one of them as the batch dimension. - if (mIndex >= 0 && bIndex < 0) + if (mIndex >= 0 && bIndex < 0) { bIndex = mIndex; + } mIndex = i; } else if (inRHS) { // For cases where we have two parallel dimensions only accessed by // the RHS, treat the outer one of them as the batch dimension. 
- if (nIndex >= 0 && bIndex < 0) + if (nIndex >= 0 && bIndex < 0) { bIndex = nIndex; + } nIndex = i; } - if (lastParallelDim) + if (lastParallelDim) { *lastParallelDim = i; + } } LLVM_DEBUG({ @@ -459,15 +469,17 @@ int64_t getTileBytes(int64_t mTileSize, int64_t nTileSize, int64_t kTileSize, int64_t elementBits, bool promoteC) { int64_t paddingBits = detail::bankConflictReductionPaddingBits / elementBits; int64_t count = (mTileSize + nTileSize) * (kTileSize + paddingBits); - if (promoteC) + if (promoteC) { count += mTileSize * (nTileSize + paddingBits); + } return (elementBits / 8) * count; } int64_t getMultiBufferMemoryUsage(int64_t singleBufferBytes, unsigned depth, unsigned storeStage) { - if (depth == 0) + if (depth == 0) { return singleBufferBytes; + } return singleBufferBytes * (storeStage == 1 ? depth : depth + 1); }; @@ -479,8 +491,9 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, const int64_t subgroupSize, int64_t vectorSize) { const int64_t totalThreads = wgSize[0] * wgSize[1] * wgSize[2]; LLVM_DEBUG(llvm::dbgs() << "initial total thread = " << totalThreads << "\n"); - if (totalThreads <= subgroupSize) + if (totalThreads <= subgroupSize) { return false; + } const bool canVectorLoadLHS = canPerformVectorAccessUsingAllThreads( {mTileSize, kTileSize}, totalThreads, vectorSize); @@ -490,8 +503,9 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, LLVM_DEBUG(llvm::dbgs() << "RHS vector load: " << canVectorLoadRHS << "\n"); // If we can perform vector load of neither, just don't use shared memory. - if (!canVectorLoadLHS && !canVectorLoadRHS) + if (!canVectorLoadLHS && !canVectorLoadRHS) { return false; + } // If we can only perform vector load of one operands, adjust the tiling // scheme to see if we can make both work. 
Increase K to load more data for @@ -499,15 +513,18 @@ static bool adjustToVectorLoad(ArrayRef dimMNKSize, int64_t &mTileSize, if (canVectorLoadLHS && !canVectorLoadRHS) { for (const int scale : {2, 4}) { const int64_t newKTileSize = kTileSize * scale; - if (dimMNKSize[2] % newKTileSize != 0) + if (dimMNKSize[2] % newKTileSize != 0) { continue; + } const int64_t newMTileSize = mTileSize / scale; const int64_t newWgMDim = wgSize[1] / scale; - if (newMTileSize == 0 || newWgMDim == 0) + if (newMTileSize == 0 || newWgMDim == 0) { continue; + } const int64_t newCount = wgSize[0] * newWgMDim * wgSize[2]; - if (newCount <= subgroupSize) + if (newCount <= subgroupSize) { continue; + } if (!canPerformVectorAccessUsingAllThreads({newMTileSize, newKTileSize}, newCount, vectorSize) || !canPerformVectorAccessUsingAllThreads({newKTileSize, nTileSize}, @@ -542,8 +559,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, LLVM_DEBUG(llvm::dbgs() << "subgroup size = " << subgroupSize << "\n"); const int vectorSize = kMaxVectorNumBits / elementBits; if (!adjustToVectorLoad(dimMNKSize, mTileSize, nTileSize, kTileSize, wgSize, - subgroupSize, vectorSize)) + subgroupSize, vectorSize)) { return false; + } // Don't do multibuffering if the inner reduction loop is folded out. if (dimMNKSize[2] == kTileSize) { @@ -563,8 +581,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, // possible. do { if (getMultiBufferMemoryUsage(usedBytes, pipelineDepth, storeStage) <= - maxBytes) + maxBytes) { return true; + } } while (pipelineDepth-- > 1); // If we can't fit in workgroup memory, don't multibuffer. @@ -573,8 +592,9 @@ static bool adjustToPromote(ArrayRef dimMNKSize, int64_t &mTileSize, if (storeStage == 0) { storeStage = 1; if (getMultiBufferMemoryUsage(usedBytes, pipelineDepth, storeStage) <= - maxBytes) + maxBytes) { return true; + } } // Using too much workgroup memory. 
Try to reduce the tile size for X/Y once @@ -609,23 +629,27 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, auto rhsType = cast(rhs->get().getType()); auto elementBits = static_cast(IREE::Util::getTypeBitWidth(lhsType.getElementType())); - if (!llvm::is_contained({8, 16, 32}, elementBits)) + if (!llvm::is_contained({8, 16, 32}, elementBits)) { return failure(); + } ArrayRef lhsShape = lhsType.getShape(); ArrayRef rhsShape = rhsType.getShape(); - if (llvm::any_of(lhsShape, ShapedType::isDynamic)) + if (llvm::any_of(lhsShape, ShapedType::isDynamic)) { return failure(); - if (llvm::any_of(rhsShape, ShapedType::isDynamic)) + } + if (llvm::any_of(rhsShape, ShapedType::isDynamic)) { return failure(); + } assert(llvm::is_contained({2u, 3u}, op.getNumParallelLoops())); int lastParallelDim = -1; const auto [bIndex, mIndex, nIndex, kIndex] = getMatmulBMNKIndex(op, &lastParallelDim); - if (mIndex < 0 || nIndex < 0 || kIndex < 0) + if (mIndex < 0 || nIndex < 0 || kIndex < 0) { return failure(); + } const bool isBM = bIndex >= 0; SmallVector loopRanges = op.getStaticLoopRanges(); @@ -669,8 +693,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, SmallVector workgroupTileSizes(numLoops, 0); SmallVector reductionTileSizes(numLoops, 0); - if (isBM) + if (isBM) { workgroupTileSizes[bIndex] = 1; + } if (!tileMatmulNToWorkgroupX(dimN, bestThreadN, residualThreads, bestX, residualTilingFactor, workgroupSize[0], @@ -722,8 +747,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, // Tile all additional reduction dimensions with size 1 to materialize loops. 
for (auto [i, it] : llvm::enumerate(op.getIteratorTypesArray())) { - if (linalg::isReductionIterator(it) && reductionTileSizes[i] == 0) + if (linalg::isReductionIterator(it) && reductionTileSizes[i] == 0) { reductionTileSizes[i] = 1; + } } TileSizesListType tileSizes; @@ -733,8 +759,9 @@ LogicalResult setMatmulOpConfig(IREE::GPU::TargetAttr target, // Merge reductionTileSizes into workgroupTileSizes--this is needed by the // pipeline passes shared between SPIR-V and LLVMGPU. for (auto [i, it] : llvm::enumerate(op.getIteratorTypesArray())) { - if (linalg::isReductionIterator(it)) + if (linalg::isReductionIterator(it)) { workgroupTileSizes[i] = reductionTileSizes[i]; + } } tileSizes.push_back(workgroupTileSizes); @@ -787,8 +814,9 @@ static LogicalResult setTilingAndMatmulOpConfig(linalg::LinalgOp op, //===----------------------------------------------------------------------===// bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { - if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) + if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) { return false; + } // Look at fused elementwise ops to make sure they are allowed by the // cooperative matrix spec. @@ -802,8 +830,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { arith::UIToFPOp, // Special cases of these ops are directly allowed to sue // cooperative matrix types. Other cases can use a loop. - arith::MulFOp>(op)) + arith::MulFOp>(op)) { return false; + } } // Look at operands to make sure we don't have inlined constants. Cooperative @@ -811,8 +840,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { // classes. 
for (Value input : genericOp.getInputs()) { if (isa(input.getType())) { - if (matchPattern(input, m_Constant())) + if (matchPattern(input, m_Constant())) { return false; + } continue; } @@ -822,8 +852,9 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { input = subviewOp.getViewSource(); } if (auto toMemrefOp = input.getDefiningOp()) { - if (matchPattern(toMemrefOp.getTensor(), m_Constant())) + if (matchPattern(toMemrefOp.getTensor(), m_Constant())) { return false; + } } } @@ -833,11 +864,13 @@ bool isCooperativeMatrixFusable(linalg::GenericOp genericOp) { bool needToPrmoteCForCooperativeMatrix(linalg::LinalgOp matmulOp) { assert(matmulOp.hasPureTensorSemantics()); Value result = matmulOp.getOperation()->getResult(0); - if (!result.hasOneUse()) + if (!result.hasOneUse()) { return true; // Be conservative. + } Operation *user = *result.getUsers().begin(); - if (isa(user)) + if (isa(user)) { return false; + } if (auto genericOp = dyn_cast(user)) { return !isCooperativeMatrixFusable(genericOp); } @@ -854,11 +887,13 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, unsigned softwarePipelineStoreStage) { LLVM_DEBUG(llvm::dbgs() << "trying to matmul cooperative matrix config...\n"); // This configuration is only for cooperative matrix. 
- if (target.getWgp().getMma().empty()) + if (target.getWgp().getMma().empty()) { return failure(); + } - if (op.hasDynamicShape()) + if (op.hasDynamicShape()) { return failure(); + } Value lhs = op.getDpsInputOperand(0)->get(); Value rhs = op.getDpsInputOperand(1)->get(); @@ -867,8 +902,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, int lastParallelDim = -1; const auto [bIndex, mIndex, nIndex, kIndex] = getMatmulBMNKIndex(op, &lastParallelDim); - if (mIndex < 0 || nIndex < 0 || kIndex < 0) + if (mIndex < 0 || nIndex < 0 || kIndex < 0) { return failure(); + } const bool isBM = bIndex >= 0; SmallVector loopRanges = op.getStaticLoopRanges(); @@ -929,8 +965,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, FailureOr schedule = deduceMMASchedule( problem, intrinsics, seeds, sharedMemoryLimitInBytes, subgroupSize, /*cuCount=*/std::nullopt, op.getLoc(), transposedLhs, transposedRhs); - if (failed(schedule)) + if (failed(schedule)) { return failure(); + } assert(schedule->hasSingleDimensions() && "expected single M/N/K dimension"); auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; @@ -940,21 +977,24 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, schedule->mSubgroupCounts[0], 1}; SmallVector vectorSizes(kIndex + 1, 0); - if (isBM) + if (isBM) { vectorSizes[bIndex] = 1; + } vectorSizes[mIndex] = schedule->mSizes[0]; vectorSizes[nIndex] = schedule->nSizes[0]; vectorSizes[kIndex] = schedule->kSizes[0]; SmallVector subgroupTileSizes(lastParallelDim + 1, 0); - if (isBM) + if (isBM) { subgroupTileSizes[bIndex] = 1; + } subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex]; subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); - if (isBM) + if (isBM) { workgroupTileSizes[bIndex] = 1; + } workgroupTileSizes[mIndex] = schedule->mSubgroupCounts[0] * 
subgroupTileSizes[mIndex]; workgroupTileSizes[nIndex] = @@ -1156,8 +1196,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, // This pipeline eventually generates non-uniform group shuffle ops, which // requires special capability. - if (!target.supportsSubgroupShuffle()) + if (!target.supportsSubgroupShuffle()) { return failure(); + } SmallVector parallelDims; SmallVector reductionDims; @@ -1168,8 +1209,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, int64_t numParallelDims = op.getNumParallelLoops(); // We should have reduction dimensions. - if (reductionDims.empty()) + if (reductionDims.empty()) { return failure(); + } // Make sure reduction dimensions are static and innermost ones. int64_t numDynamicReductionDims = 0; @@ -1188,8 +1230,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, return failure(); } - if (op.getRegionOutputArgs().size() != 1) + if (op.getRegionOutputArgs().size() != 1) { return failure(); + } // Only support projected permutation for now. This could be extended to // projected permutated with broadcast. 
@@ -1205,8 +1248,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, SmallVector combinerOps; if (matchReduction(op.getRegionOutputArgs(), i, combinerOps) && combinerOps.size() == 1) { - if (foundSingleReductionOutput) + if (foundSingleReductionOutput) { return failure(); + } foundSingleReductionOutput = true; continue; } @@ -1214,8 +1258,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, return failure(); } } - if (!foundSingleReductionOutput) + if (!foundSingleReductionOutput) { return failure(); + } int subgroupSize = target.getPreferredSubgroupSize(); @@ -1253,24 +1298,29 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, } int64_t reductionSize = 1; - for (int64_t dim : reductionDims) + for (int64_t dim : reductionDims) { reductionSize *= bounds[dim]; - if (reductionSize % subgroupSize != 0) + } + if (reductionSize % subgroupSize != 0) { return failure(); + } const Type elementType = cast(op.getDpsInits()[0].getType()).getElementType(); - if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return failure(); + } unsigned bitWidth = IREE::Util::getTypeBitWidth(elementType); // Reduction distribution only supports 8/16/32 bit types now. - if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) + if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) { return failure(); + } // Let each thread handle `vectorSize` elements. unsigned vectorSize = kMaxVectorNumBits / bitWidth; - while ((reductionSize / vectorSize) % subgroupSize != 0) + while ((reductionSize / vectorSize) % subgroupSize != 0) { vectorSize /= 2; + } // Deduce the workgroup size we should use for reduction. Currently a // workgroup processes all elements in reduction dimensions. 
Need to make sure @@ -1295,8 +1345,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, int64_t parallelSize = 1; for (int64_t dim : parallelDims) { - if (ShapedType::isStatic(bounds[dim])) + if (ShapedType::isStatic(bounds[dim])) { parallelSize *= bounds[dim]; + } } // Total parallel size that can fill the GPU with enough workgorups. // TODO: query from the target device; roughly 2x hardware compute unit. @@ -1316,8 +1367,9 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, // First, do warp reductions along multiple subgroups. // Second, reduce results from multiple subgroups using single warp reduce. // The final warp reduce requires subgroup count <= subgroup size to work. - if ((groupSize / subgroupSize) > subgroupSize) + if ((groupSize / subgroupSize) > subgroupSize) { return failure(); + } if (hasIncompatibleConsumer(op, groupSize)) { LDBG() << "Reduction has incompatible consumer, limiting workgroup size " @@ -1332,13 +1384,15 @@ static LogicalResult setReductionConfig(IREE::GPU::TargetAttr target, for (int i = reductionDims.size() - 1; i >= 0; --i) { int64_t dim = reductionDims[i]; int64_t bound = bounds[dim]; - if (i == reductionDims.size() - 1) + if (i == reductionDims.size() - 1) { bound /= vectorSize; + } APInt size = GreatestCommonDivisor(APInt(64, uint64_t(remaingGroupSize)), APInt(64, uint64_t(bound))); reductionTileSizes[dim] = size.getSExtValue(); - if (i == reductionDims.size() - 1) + if (i == reductionDims.size() - 1) { reductionTileSizes[dim] *= vectorSize; + } remaingGroupSize /= size.getSExtValue(); } @@ -1461,10 +1515,12 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, // dimensions to 1 for extra dimensions. 
if (isa(linalgOp.getOperation())) { for (int64_t i = 0, e = workgroupTileSizes.size(); i < e; i++) { - if (workgroupTileSizes[i] != 0) + if (workgroupTileSizes[i] != 0) { break; - if (loopBounds[i] != 1) + } + if (loopBounds[i] != 1) { workgroupTileSizes[i] = 1; + } } } // Scan from the innermost shape dimension and try to deduce the @@ -1473,8 +1529,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, for (auto shapeDim : llvm::reverse(partitionedLoops)) { int64_t loopBound = loopBounds[shapeDim]; // Skip dynamic dimensions. - if (ShapedType::isDynamic(loopBound)) + if (ShapedType::isDynamic(loopBound)) { continue; + } // Try to find some power of two that can devide the current shape dim // size. This vector keeps the candidate tile sizes. @@ -1495,12 +1552,14 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, for (int64_t candidate : candidates) { int64_t scaledTileSize = candidate * scaleToByte; if (loopBound % scaledTileSize != 0) { - if (!lossFactor) + if (!lossFactor) { continue; + } // Skip this candidate if it causes many threads to be idle. int64_t idleThreads = candidate - (loopBound % scaledTileSize); - if (idleThreads > candidate / *lossFactor) + if (idleThreads > candidate / *lossFactor) { continue; + } } // If the workload is too small and we cannot distribute to more than 2 // workgroups, try a smaller tile size to increase parallelism. @@ -1526,8 +1585,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, assert(numThreads % (candidate / vectorSize) == 0); numThreads /= candidate / vectorSize; } else { - if (wgDim == 0) + if (wgDim == 0) { vectorizable = false; + } threadTileSizes[shapeDim] = scaleToByte; workgroupSize[wgDim] = candidate; assert(numThreads % candidate == 0); @@ -1538,8 +1598,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, } // Stop if we have distributed all threads. 
- if (numThreads == 1) + if (numThreads == 1) { break; + } wgDim++; } return numThreads; @@ -1555,8 +1616,9 @@ static LogicalResult setDefaultOpConfig(IREE::GPU::TargetAttr target, int64_t lossFactor = 32; for (; lossFactor >= 1; lossFactor >>= 1) { - if (distributeToThreads(numThreads, lossFactor) == 1) + if (distributeToThreads(numThreads, lossFactor) == 1) { break; + } } } @@ -1600,19 +1662,26 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { // First try to find a proper CodeGen configuration to tile and vectorize for // the current target architecture. - if (target.isAMD() && succeeded(detail::setAMDCodeGenConfig(target, rootOp))) + if (target.isAMD() && + succeeded(detail::setAMDCodeGenConfig(target, rootOp))) { return success(); + } if (target.isApple() && - succeeded(detail::setAppleCodeGenConfig(target, rootOp))) + succeeded(detail::setAppleCodeGenConfig(target, rootOp))) { return success(); - if (target.isARM() && succeeded(detail::setMaliCodeGenConfig(target, rootOp))) + } + if (target.isARM() && + succeeded(detail::setMaliCodeGenConfig(target, rootOp))) { return success(); + } if (target.isNVIDIA() && - succeeded(detail::setNVIDIACodeGenConfig(target, rootOp))) + succeeded(detail::setNVIDIACodeGenConfig(target, rootOp))) { return success(); + } if (target.isQualcomm() && - succeeded(detail::setAdrenoCodeGenConfig(target, rootOp))) + succeeded(detail::setAdrenoCodeGenConfig(target, rootOp))) { return success(); + } // Otherwise fallback to use a default configuration that tiles and // distributes/vectorizes. @@ -1635,8 +1704,9 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target, const int subgroupSize = 32; auto result = detail::setConvOpConfig(cast(*op), subgroupSize, bestTilingFactor); - if (succeeded(result)) + if (succeeded(result)) { return success(); + } } // If unsuccessful, try to tile and distribute/vectorize. 
return setDefaultOpConfig(target, op); @@ -1692,22 +1762,26 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target, ArrayRef roots(computeOps); while (roots.size() > 1) { auto linalgOp = dyn_cast(roots.front()); - if (!linalgOp) + if (!linalgOp) { break; - if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) + } + if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { break; + } roots = roots.drop_front(); } for (Operation *computeOp : roots) { - if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) + if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) { return success(); + } } Operation *computeOp = roots.back(); // If there are still no root op, check for any linalg.generic op. - if (succeeded(setDefaultOpConfig(target, computeOp))) + if (succeeded(setDefaultOpConfig(target, computeOp))) { return success(); + } // Check if the op configuration was set. return computeOp->emitOpError( @@ -1717,11 +1791,13 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target, LogicalResult initSPIRVLaunchConfig(FunctionOpInterface funcOp) { IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) + if (!target) { return funcOp.emitError("missing GPU target in #hal.executable.target"); + } - if (getTranslationInfo(funcOp)) + if (getTranslationInfo(funcOp)) { return success(); + } if (auto exportOp = getEntryPoint(funcOp)) { // If no translation info set, first check whether we already have workgroup diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp index 37b8ba322161..25643415e539 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp @@ -45,16 +45,18 @@ LogicalResult setMaliCodeGenConfig(IREE::GPU::TargetAttr target, const int subgroupSize = target.getPreferredSubgroupSize(); if (auto linalgOp = dyn_cast(rootOp)) { - if 
(isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setMaliMatmulConfig(linalgOp, target); + } } if (auto convOp = dyn_cast(rootOp)) { // Use the result type in case of larger bitwidth for accumulators. auto type = cast(convOp->getResult(0).getType()); const int bitwidth = type.getElementTypeBitWidth(); - if (bitwidth > 32) + if (bitwidth > 32) { return failure(); + } const int multipler = 32 / bitwidth; bool hasPaddedInput = convOp.image().getDefiningOp(); const int bestTilingFactor = (hasPaddedInput ? 8 : 16) * multipler; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp index 5f4505a02d62..aefeef7ec608 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/NVIDIAConfig.cpp @@ -30,8 +30,9 @@ static LogicalResult setNVIDIAMatmulConfig(linalg::LinalgOp op, // First try to see if we can use tensor cores. if (succeeded(setCooperativeMatrixConfig(target, op, NVIDIANumSubgroupsPerWorkgroup, - NVIDIANumMNTilesPerSubgroup))) + NVIDIANumMNTilesPerSubgroup))) { return success(); + } const int subgroupSize = target.getPreferredSubgroupSize(); const std::array workgroupXY = {subgroupSize, 8}; @@ -79,8 +80,9 @@ static LogicalResult setNVIDIAMatmulConfig(linalg::LinalgOp op, LogicalResult setNVIDIACodeGenConfig(IREE::GPU::TargetAttr target, Operation *rootOp) { if (auto linalgOp = dyn_cast(rootOp)) { - if (isMatmulOrBatchMatmul(linalgOp)) + if (isMatmulOrBatchMatmul(linalgOp)) { return setNVIDIAMatmulConfig(linalgOp, target); + } } return failure(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index 07737c4b9d4e..df7987e986e4 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -94,8 +94,9 @@ static LogicalResult gpuCopyFn(OpBuilder &builder, Location loc, Value 
from, bool needsBarrier = hasSharedMemoryAddressSpace(fromType) || hasSharedMemoryAddressSpace(toType); - if (needsBarrier) + if (needsBarrier) { gpu::BarrierOp::create(builder, loc); + } Operation *copy = memref::CopyOp::create(builder, loc, from, to); if (needsBarrier) { setMarker(copy, getCopyToWorkgroupMemoryMarker()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp index dbe01cc10a12..a42d102e6d36 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp @@ -25,16 +25,18 @@ class SPIRVAnnotateWinogradLoopsPass final mlir::FunctionOpInterface funcOp = getOperation(); SmallVector forOps; funcOp.walk([&](scf::ForOp forOp) { - if (!isTiledAndDistributedLoop(forOp)) + if (!isTiledAndDistributedLoop(forOp)) { forOps.push_back(forOp); + } }); MLIRContext *context = &getContext(); OpBuilder builder(context); const char *attrName = getGPUDistributeAttrName(); for (auto [index, forOp] : llvm::enumerate(forOps)) { - if (index > kNumGPUDims) + if (index > kNumGPUDims) { break; + } forOp->setAttr(attrName, builder.getIndexAttr(index)); } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp index c9ef3ffe0fd6..c64f7c3158f3 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp @@ -67,26 +67,31 @@ struct BreakDownCastExtractExtend final : OpRewritePattern { PatternRewriter &rewriter) const override { auto extractOp = extOp.getIn().getDefiningOp(); - if (!extractOp) + if (!extractOp) { return failure(); + } auto bitCastOp = extractOp.getSource().getDefiningOp(); - if (!bitCastOp) + if (!bitCastOp) { return failure(); + } VectorType extractSrcType = 
extractOp.getSourceVectorType(); VectorType extractDstType = extractOp.getType(); // We expect high-D vectors are broken down into 1-D ones so here we only // handle 1-D vectors. - if (extractSrcType.getRank() != 1 || extractDstType.getRank() != 1) + if (extractSrcType.getRank() != 1 || extractDstType.getRank() != 1) { return failure(); + } // We only have power-of-two bitwidth cases for now. if (!llvm::isPowerOf2_64(extractSrcType.getNumElements()) || - !llvm::isPowerOf2_64(extractDstType.getNumElements())) + !llvm::isPowerOf2_64(extractDstType.getNumElements())) { return failure(); + } // We only handle not directly supported vector sizes. - if (extractSrcType.getNumElements() <= 4) + if (extractSrcType.getNumElements() <= 4) { return failure(); + } int64_t srcElemBitwidth = bitCastOp.getSourceVectorType().getElementTypeBitWidth(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp index 3ef39a32875c..c7209f2075a8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVConvertGPUTarget.cpp @@ -65,8 +65,9 @@ std::optional processCapabilities(ArrayRef features, SetVector &caps) { for (StringRef feature : features) { if (feature.consume_front("cap:")) { - if (std::optional cap = spirv::symbolizeCapability(feature)) + if (std::optional cap = spirv::symbolizeCapability(feature)) { caps.insert(*cap); + } } } return std::nullopt; @@ -78,8 +79,9 @@ std::optional processExtensions(ArrayRef features, SetVector &exts) { for (StringRef feature : features) { if (feature.consume_front("ext:")) { - if (std::optional ext = spirv::symbolizeExtension(feature)) + if (std::optional ext = spirv::symbolizeExtension(feature)) { exts.insert(*ext); + } } } return std::nullopt; @@ -99,16 +101,21 @@ ClientAPI deduceClientAPI(StringRef backend) { } Vendor deduceVendor(IREE::GPU::TargetAttr target) { - if 
(target.isAMD()) + if (target.isAMD()) { return Vendor::AMD; - if (target.isApple()) + } + if (target.isApple()) { return Vendor::Apple; - if (target.isARM()) + } + if (target.isARM()) { return Vendor::ARM; - if (target.isNVIDIA()) + } + if (target.isNVIDIA()) { return Vendor::NVIDIA; - if (target.isQualcomm()) + } + if (target.isQualcomm()) { return Vendor::Qualcomm; + } return Vendor::Unknown; } @@ -118,19 +125,24 @@ Vendor deduceVendor(IREE::GPU::TargetAttr target) { void addComputeFeatures(ComputeBitwidths compute, SetVector &caps, SetVector &exts) { - if (bitEnumContainsAny(compute, ComputeBitwidths::FP64)) + if (bitEnumContainsAny(compute, ComputeBitwidths::FP64)) { caps.insert(Capability::Float64); + } // FP32 does not need special capabilities or extensions. - if (bitEnumContainsAny(compute, ComputeBitwidths::FP16)) + if (bitEnumContainsAny(compute, ComputeBitwidths::FP16)) { caps.insert(Capability::Float16); + } - if (bitEnumContainsAny(compute, ComputeBitwidths::Int64)) + if (bitEnumContainsAny(compute, ComputeBitwidths::Int64)) { caps.insert(Capability::Int64); + } // Int32 does not need special capabilities or extensions. 
- if (bitEnumContainsAny(compute, ComputeBitwidths::Int16)) + if (bitEnumContainsAny(compute, ComputeBitwidths::Int16)) { caps.insert(Capability::Int16); - if (bitEnumContainsAny(compute, ComputeBitwidths::Int8)) + } + if (bitEnumContainsAny(compute, ComputeBitwidths::Int8)) { caps.insert(Capability::Int8); + } } void addStorageFeatures(StorageBitwidths storage, SetVector &caps, @@ -280,8 +292,9 @@ struct SPIRVConvertGPUTargetPass final FailureOr spirvTarget = convertGPUTarget(context, variant); - if (failed(spirvTarget)) + if (failed(spirvTarget)) { return signalPassFailure(); + } moduleOp->setAttr(spirv::getTargetEnvAttrName(), *spirvTarget); } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp index 8291c1cf7384..e7f06146b9a8 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp @@ -55,10 +55,11 @@ struct ConvertHalInterfaceBindingSubspan final matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResultTy = getTypeConverter()->convertType(op.getType()); - if (!newResultTy) + if (!newResultTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to legalize memref type: {}", op.getType())); + } auto newOp = rewriter.replaceOpWithNewOp( @@ -111,8 +112,9 @@ struct ConvertUtilAssumeIntOp final unsigned replacementLoc = 0; for (auto result : newOp.getResults()) { - while (replacements[replacementLoc] != nullptr) + while (replacements[replacementLoc] != nullptr) { replacementLoc++; + } Value replacement = result; Type newType = getTypeConverter()->convertType( op.getResult(replacementLoc).getType()); @@ -138,11 +140,13 @@ struct ConvertUtilAssumeIntOp final // Tries to flatten `type` to a 1-D vector type. Returns `nullptr` on failure. 
static VectorType flattenVectorType(Type type) { auto vecTy = dyn_cast(type); - if (!vecTy) + if (!vecTy) { return nullptr; + } - if (vecTy.isScalable() || vecTy.getRank() <= 1) + if (vecTy.isScalable() || vecTy.getRank() <= 1) { return nullptr; + } int64_t totalElements = vecTy.getNumElements(); return VectorType::get(llvm::ArrayRef(totalElements), vecTy.getElementType()); @@ -167,13 +171,15 @@ struct FlattenElementwisePattern final : RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op)) + if (!OpTrait::hasElementwiseMappableTraits(op)) { return failure(); + } auto newResultTypes = llvm::to_vector_of( llvm::map_range(op->getResultTypes(), flattenVectorType)); - if (llvm::any_of(newResultTypes, [](Type type) { return !type; })) + if (llvm::any_of(newResultTypes, [](Type type) { return !type; })) { return failure(); + } Location loc = op->getLoc(); @@ -181,8 +187,9 @@ struct FlattenElementwisePattern final : RewritePattern { auto operands = llvm::to_vector_of(op->getOperands()); for (Value &operand : operands) { VectorType newOperandTy = flattenVectorType(operand.getType()); - if (!newOperandTy) + if (!newOperandTy) { return failure(); + } operand = rewriter.createOrFold(loc, newOperandTy, operand); @@ -233,8 +240,9 @@ struct SPIRVEmulateI64Pass final void runOnOperation() override { mlir::FunctionOpInterface op = getOperation(); - if (supportsI64(op)) + if (supportsI64(op)) { return; + } arith::WideIntEmulationConverter typeConverter(32); memref::populateMemRefWideIntEmulationConversions(typeConverter); @@ -263,8 +271,9 @@ struct SPIRVEmulateI64Pass final memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); populateIreeI64EmulationPatterns(typeConverter, patterns); - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { signalPassFailure(); + } } // Clean up 
any new 2-D vectors. We need to do it here because later passed @@ -279,8 +288,9 @@ struct SPIRVEmulateI64Pass final vector::InsertStridedSliceOp::getCanonicalizationPatterns(patterns, ctx); vector::ShapeCastOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp index df98ee11741d..1eb346982475 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp @@ -36,12 +36,14 @@ class EraseStorageBufferStaticShapePass final bool is1DStaticShapedStorageBuffer( IREE::HAL::InterfaceBindingSubspanOp subspanOp) { auto type = dyn_cast(subspanOp.getType()); - if (!type) + if (!type) { return false; + } auto attr = dyn_cast_if_present(type.getMemorySpace()); - if (!attr) + if (!attr) { return false; + } return type.hasStaticShape() && type.getRank() == 1 && attr.getValue() == IREE::HAL::DescriptorType::StorageBuffer; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp index a6aa5ac39397..adec7266b5f2 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp @@ -51,29 +51,34 @@ void debugPrint(Operation *op, const char *message) { int getComputeVectorSize(int64_t size) { for (int i : {4, 3, 2}) { - if (size % i == 0) + if (size % i == 0) { return i; + } } return 1; } int getMemoryVectorSize(Value source, Type scalarType, int64_t size) { int bitwidth = scalarType.getIntOrFloatBitWidth(); - while (auto sliceOp = 
source.getDefiningOp()) + while (auto sliceOp = source.getDefiningOp()) { source = sliceOp.getSource(); + } if (!matchPattern(source, m_Constant())) { // If we are not reading from a constant array that is embedded in the // kernel, try to use a large vector size matching the bitwidth to read in // 128-bit chunks. This helps with memory access performance. Such vector // sizes are not native in SPIR-V though; this relies on following passes to // bitcast them to 32-bit 4-element vectors to be valid. - if (bitwidth <= 8 && size % 16 == 0) + if (bitwidth <= 8 && size % 16 == 0) { return 16; - if (bitwidth <= 16 && size % 8 == 0) + } + if (bitwidth <= 16 && size % 8 == 0) { return 8; + } } - if (bitwidth <= 32 && size % 4 == 0) + if (bitwidth <= 32 && size % 4 == 0) { return 4; + } return size % 2 == 0 ? 2 : 1; } @@ -108,8 +113,9 @@ Operation *stripElementBitPatternPreservingParents(Value op) { }) .Default([](Operation *) { return nullptr; }); - if (!source) + if (!source) { break; + } op = source; } @@ -119,8 +125,9 @@ Operation *stripElementBitPatternPreservingParents(Value op) { /// Returns true when |op| has the i32 element type that is likely to be result /// of a zero/sign extension from i8. bool mayExtI8ToI32(Value op) { - if (!getElementTypeOrSelf(op.getType()).isInteger(32)) + if (!getElementTypeOrSelf(op.getType()).isInteger(32)) { return false; + } // Look through vector operations created by vector unrolling patterns, // hoping to find a zero/sign extension op. Note that we do not need to find @@ -146,15 +153,18 @@ bool mayExtI8ToI32(Value op) { /// Succeeds when |contract| is a i32 matmul whose LHS and RHS operands may be /// result of zero/sign extension of i8 inputs. 
LogicalResult detectI8ToI32Matmul(vector::ContractionOp contract) { - if (contract.getKind() != vector::CombiningKind::ADD) + if (contract.getKind() != vector::CombiningKind::ADD) { return failure(); + } - if (!mayExtI8ToI32(contract.getLhs()) || !mayExtI8ToI32(contract.getRhs())) + if (!mayExtI8ToI32(contract.getLhs()) || !mayExtI8ToI32(contract.getRhs())) { return failure(); + } ArrayRef iteratorTypes = contract.getIteratorTypes().getValue(); - if (iteratorTypes.size() != 3) + if (iteratorTypes.size() != 3) { return failure(); + } return success(vector::isParallelIterator(iteratorTypes[0]) && vector::isParallelIterator(iteratorTypes[1]) && @@ -265,12 +275,14 @@ bool supportsIntegerDotProductOps(mlir::FunctionOpInterface fn) { // First check if the function op itself has a target env attribute. This may // be preferred in tests. auto targetEnvAttr = getGPUTargetAttr(fn); - if (!targetEnvAttr) + if (!targetEnvAttr) { return false; + } if (!IREE::GPU::bitEnumContainsAll(targetEnvAttr.getWgp().getDot().getValue(), - IREE::GPU::DotProductOps::DP4xI8ToI32)) + IREE::GPU::DotProductOps::DP4xI8ToI32)) { return false; + } return true; } @@ -332,8 +344,9 @@ class SPIRVInitialLoweringPass final // batch dimension. Try to drop that to map to matmul dimensions better. 
SmallVector contractOps; funcOp.walk([&](vector::ContractionOp op) { - if (op.getIteratorTypes().size() > 3) + if (op.getIteratorTypes().size() > 3) { contractOps.push_back(op); + } }); for (vector::ContractionOp op : contractOps) { OpBuilder builder(op); @@ -373,8 +386,9 @@ class SPIRVInitialLoweringPass final funcOp.walk([&](vector::MultiDimReductionOp reductionOp) { if (llvm::any_of(reductionOp->getOperands(), [](Value operand) { return operand.getDefiningOp(); - })) + })) { reductionOps.push_back(reductionOp); + } return WalkResult::advance(); }); RewritePatternSet patterns(context); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp index 4c3f84e2211a..1d70a29637f0 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp @@ -70,8 +70,9 @@ struct SPIRVLinkExecutablesPass final // Collect all source executable ops. auto sourceExecutableOps = gatherExecutablesForSPIRVCodegen(moduleOp); - if (sourceExecutableOps.size() <= 1) + if (sourceExecutableOps.size() <= 1) { return; + } // Note that at runtime, for a particular executable, only one variant of it // will be loaded. 
So, all variants of an executable are expected to provide @@ -154,8 +155,9 @@ struct SPIRVLinkExecutablesPass final } }); - if (failed(linkOneExecutableBucket(moduleOp, moduleName, key, bucket))) + if (failed(linkOneExecutableBucket(moduleOp, moduleName, key, bucket))) { return signalPassFailure(); + } } } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp index 220bb9df8071..6c1af71fc727 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMapMemRefStorageClass.cpp @@ -75,16 +75,18 @@ mapHALDescriptorTypeForOpenCL(Attribute attr) { bool allowsShaderCapability(ArrayRef features) { for (StringRef feature : features) { - if (feature.consume_front("cap:") && feature == "Shader") + if (feature.consume_front("cap:") && feature == "Shader") { return true; + } } return false; } bool allowsKernelCapability(ArrayRef features) { for (StringRef feature : features) { - if (feature.consume_front("cap:") && feature == "Kernel") + if (feature.consume_front("cap:") && feature == "Kernel") { return true; + } } return false; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp index d01eab942010..40750858d891 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp @@ -157,8 +157,9 @@ LogicalResult mapToDeviceQuery(IREE::HAL::ExecutableExportOp entryPoint, entryPoint->getAttrOfType("iree.spirv.coopmatrix.type"); auto coopmatShape = entryPoint->getAttrOfType( "iree.spirv.coopmatrix.shape"); - if (!coopmatType || !coopmatShape) + if (!coopmatType || !coopmatShape) { return failure(); + } Type inputType = 
cast(coopmatType.getValue().front()).getValue(); Type outputType = cast(coopmatType.getValue().back()).getValue(); @@ -277,8 +278,9 @@ struct SPIRVMaterializeExecutableConditionsPass final SPIRVMaterializeExecutableConditionsPass> { void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (!usesSPIRVCodeGen(variantOp)) + if (!usesSPIRVCodeGen(variantOp)) { return; + } IREE::HAL::ExecutableTargetAttr executableTarget = variantOp.getTarget(); DictionaryAttr configuration = executableTarget.getConfiguration(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp index 822e4f8c1ba0..9028aa7ab93c 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp @@ -50,8 +50,9 @@ verifyLoweringConfiguration(FunctionOpInterface funcOp, auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult { auto loweringConfig = getLoweringConfig(op); - if (!loweringConfig) + if (!loweringConfig) { return WalkResult::advance(); + } return verificationFn(op, loweringConfig, translationInfo, workgroupSize); }); return failure(walkResult.wasInterrupted()); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp index 1b68acef1783..8665bb9619d2 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp @@ -127,15 +127,18 @@ class SPIRVTileAndDistributePass final void SPIRVTileAndDistributePass::runOnOperation() { MLIRContext *context = &getContext(); mlir::FunctionOpInterface funcOp = getOperation(); - if (!isEntryPoint(funcOp)) + if (!isEntryPoint(funcOp)) { return; + } auto threadTileComputeFn = getSPIRVTileSizeComputeFn(funcOp, 1); - if 
(failed(threadTileComputeFn)) + if (failed(threadTileComputeFn)) { return signalPassFailure(); + } auto reductionTileComputeFn = getSPIRVScfTileSizeComputeFn(funcOp, 2); - if (failed(reductionTileComputeFn)) + if (failed(reductionTileComputeFn)) { return signalPassFailure(); + } { // Tile and distribute to invocations. if (failed(tileToInvocation(funcOp, *threadTileComputeFn))) { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp index 9824035c8186..6a41bc1cc7bd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp @@ -132,16 +132,19 @@ void SPIRVTileAndPromotePass::runOnOperation() { mlir::FunctionOpInterface funcOp = getOperation(); auto threadTileComputeFn = getSPIRVTileSizeComputeFn(funcOp, 1); - if (failed(threadTileComputeFn)) + if (failed(threadTileComputeFn)) { return signalPassFailure(); + } auto reductionTileComputeFn = getSPIRVScfTileSizeComputeFn(funcOp, 2); - if (failed(reductionTileComputeFn)) + if (failed(reductionTileComputeFn)) { return signalPassFailure(); + } // Promote C matrix and propagate the potential fill producer into the // allocation. This needs to be done before reduction tiling. - if (failed(doPromoteCMatrix(funcOp))) + if (failed(doPromoteCMatrix(funcOp))) { return signalPassFailure(); + } StringLiteral markerAttrName = LinalgTransforms::kLinalgTransformMarker; auto workgroupMarker = StringAttr::get(context, getWorkgroupMemoryMarker()); @@ -219,10 +222,12 @@ void SPIRVTileAndPromotePass::runOnOperation() { // that there are no subview ops), clear markers to enable following steps. 
funcOp.walk([&](linalg::LinalgOp linalgOp) { auto marker = linalgOp->getAttrOfType(markerAttrName); - if (!marker) + if (!marker) { return WalkResult::advance(); - if (marker.getValue() == promoteBothMarker) + } + if (marker.getValue() == promoteBothMarker) { linalgOp->removeAttr(markerAttrName); + } return WalkResult::advance(); }); } @@ -271,14 +276,16 @@ void SPIRVTileAndPromotePass::runOnOperation() { LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( mlir::FunctionOpInterface funcOp) const { MLIRContext *context = funcOp.getContext(); - if (!promoteCMatrix) + if (!promoteCMatrix) { return success(); + } SmallVector computeOps = getComputeOps(funcOp); SmallVector linalgOps; for (Operation *op : computeOps) { - if (isa(op)) + if (isa(op)) { continue; // Don't care + } if (auto linalgOp = dyn_cast(op)) { linalgOps.push_back(linalgOp); } else { @@ -291,8 +298,9 @@ LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( } // If there are no fused elementwise ops, we can avoid promoting C matrix. - if (linalgOps.size() <= 1) + if (linalgOps.size() <= 1) { return success(); + } auto matmulOp = cast(linalgOps.front()); auto genericOp = cast(*linalgOps.back()); @@ -311,8 +319,9 @@ LogicalResult SPIRVTileAndPromotePass::doPromoteCMatrix( // If the fused elementwise ops are allowed to use cooperative types, we can // also avoid promoting C matrix. - if (isCooperativeMatrixFusable(genericOp)) + if (isCooperativeMatrixFusable(genericOp)) { return success(); + } // Finally do promote C matrix. 
RewritePatternSet patterns(context); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp index 3d754ef10cfb..6e3b1a0bdf83 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp @@ -90,8 +90,9 @@ void setSPIRVCooperativeMatrixInfo(mlir::FunctionOpInterface funcOp, ArrayRef getSPIRVCooperativeMatrixShape(mlir::FunctionOpInterface funcOp) { auto attr = funcOp->getAttrOfType(coopMatShapeAttrName); - if (!attr) + if (!attr) { return {}; + } return attr.asArrayRef(); } @@ -110,10 +111,12 @@ static SmallVector deduceSubgroupCounts(linalg::LinalgOp op) { SmallVector subgroupCounts; for (int i = 0, e = workgroupTileSizes.size(); i < e; ++i) { - if (subgroupTileSizes[i] == 0) + if (subgroupTileSizes[i] == 0) { continue; - if (linalg::isReductionIterator(op.getIteratorTypesArray()[i])) + } + if (linalg::isReductionIterator(op.getIteratorTypesArray()[i])) { continue; + } assert(workgroupTileSizes[i] % subgroupTileSizes[i] == 0); subgroupCounts.push_back(workgroupTileSizes[i] / subgroupTileSizes[i]); } @@ -174,17 +177,20 @@ std::optional> getExtOpVectorShape(ExtOpTy op, ArrayRef nativeShape) { auto insert = op.getOperand().template getDefiningOp(); - if (!insert) + if (!insert) { return std::nullopt; + } VectorType sliceType = insert.getSourceVectorType(); for (Operation *users : op->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (!llvm::equal(sliceType.getShape(), vecType.getShape())) + if (!llvm::equal(sliceType.getShape(), vecType.getShape())) { return std::nullopt; + } } return llvm::to_vector(sliceType.getShape()); @@ -201,8 +207,9 @@ getCooperativeOpVectorShape(Operation *op, ArrayRef nativeShape) { // 
Unroll elementwise ops according to native cooperative matrix size. if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) { - if (auto vecType = dyn_cast(op->getResultTypes()[0])) + if (auto vecType = dyn_cast(op->getResultTypes()[0])) { return llvm::to_vector(nativeShape.drop_back()); // Drop K dim size + } } // Unrolling vector.contract generates vector.{insert|extract}_strided_slice @@ -231,27 +238,32 @@ getCooperativeOpVectorShape(Operation *op, ArrayRef nativeShape) { auto sourceOp = op; if (op->hasOneUse()) { auto user = *op->user_begin(); - if (isa(user) || isa(user)) + if (isa(user) || isa(user)) { sourceOp = user; + } } VectorType sliceType; for (Operation *users : sourceOp->getUsers()) { auto extract = dyn_cast(users); - if (!extract) + if (!extract) { return std::nullopt; + } auto vecType = cast(extract.getResult().getType()); - if (sliceType && sliceType != vecType) + if (sliceType && sliceType != vecType) { return std::nullopt; + } sliceType = vecType; } return llvm::to_vector(sliceType.getShape()); } - if (auto extOp = dyn_cast(op)) + if (auto extOp = dyn_cast(op)) { return getExtOpVectorShape(extOp, nativeShape); - if (auto extOp = dyn_cast(op)) + } + if (auto extOp = dyn_cast(op)) { return getExtOpVectorShape(extOp, nativeShape); + } return std::nullopt; } @@ -309,8 +321,9 @@ class CombineContractTranspose final newSources.push_back(transposeOp.getVector()); foundTranspose = true; } - if (!foundTranspose) + if (!foundTranspose) { return failure(); + } Value res = vector::ContractionOp::create( rewriter, loc, newSources[0], newSources[1], newSources[2], diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp index f9e163d75303..926ec2b04f53 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp @@ -62,8 +62,9 @@ static bool 
getUsesIfAllTransferOp(Value value, } continue; } - if (isa(userOp)) + if (isa(userOp)) { continue; + } if (!isa(userOp)) { @@ -109,15 +110,18 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { continue; } auto transferOp = dyn_cast(op); - if (!transferOp) + if (!transferOp) { return 0; + } // Masked transfers must be scalarized. - if (transferOp.getMask()) + if (transferOp.getMask()) { return 0; + } std::optional transferSize = getBitWidth(transferOp.getVectorType()); - if (!transferSize) + if (!transferSize) { return 0; + } minBits = std::min(minBits, *transferSize); } @@ -131,8 +135,9 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { memrefVal = storeOp.getDstMemref(); stride = storeOp.getLeadDimension().getSExtValue(); } - if (!memrefVal) + if (!memrefVal) { continue; + } // GPU subgroup MMA ops do not care about the memref element type. But we // still need to make sure we can load/store with good strides. @@ -141,12 +146,14 @@ calculateMemRefVectorNumBits(SmallVectorImpl &uses) { auto memrefType = cast(memrefVal.getType()); std::optional elementBits = getBitWidth(memrefType.getElementType()); - if (!elementBits) + if (!elementBits) { return 0; + } int64_t strideBits = stride * *elementBits; // Make sure the stride is aligned with the planned vector bitwidth. - if (strideBits % minBits != 0) + if (strideBits % minBits != 0) { return 0; + } } return minBits; @@ -197,8 +204,9 @@ static unsigned isMemRefVectorizable(Value value, if (getUsesIfAllTransferOp(value, uses)) { unsigned vectorBits = calculateMemRefVectorNumBits(uses); LLVM_DEBUG(llvm::dbgs() << "vectorBits=" << vectorBits << "\n"); - if (!vectorBits) + if (!vectorBits) { return 0; + } // TODO: Fix sub-byte type support in vector.bitcast lowering. 
if (vectorBits % 32 != 0) { @@ -377,8 +385,9 @@ class ProcessTransferRead final FailureOr> indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return rewriter.notifyMatchFailure(read, "failed to adjust indices"); + } // If the transfer_read can be replaced by a load after vectorization use // LoadOp and cast back to the original type. @@ -480,8 +489,9 @@ class ProcessTransferWrite final FailureOr> indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return rewriter.notifyMatchFailure(write, "failed to adjust indices"); + } // If the transfer_write can be replaced by a store after vectorization cast // the original value and use StoreOp. @@ -572,8 +582,9 @@ MemRefConversionPattern::getVectorizedMemRefType( Type vectorType = VectorType::get(vectorNumElements, scalarType); auto newShape = llvm::to_vector<2>(type.getShape()); unsigned ratio = vectorNumBits / type.getElementTypeBitWidth(); - if (newShape.back() % ratio != 0) + if (newShape.back() % ratio != 0) { return {}; + } newShape.back() = newShape.back() / ratio; MemRefLayoutAttrInterface layout = {}; @@ -605,8 +616,9 @@ FailureOr> MemRefConversionPattern::adjustIndices( getBitWidth(vectorMemrefType.getElementType()); std::optional scalarMemrefElemSize = getBitWidth(scalarMemrefType.getElementType()); - if (!vectorMemrefElemSize || !scalarMemrefElemSize) + if (!vectorMemrefElemSize || !scalarMemrefElemSize) { return failure(); + } MLIRContext *context = rewriter.getContext(); AffineExpr sym0, sym1; @@ -629,8 +641,9 @@ class ProcessAlloc final : public MemRefConversionPattern { matchAndRewrite(memref::AllocOp alloc, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = getVectorizedMemRefType(rewriter, alloc.getResult()); - if (!memrefType) + if (!memrefType) { return failure(); + } 
rewriter.replaceOpWithNewOp(alloc, *memrefType, alloc.getDynamicSizes()); return success(); @@ -647,8 +660,9 @@ class ProcessInterfaceBindingSubspan final OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = dyn_cast(subspanOp.getType()); - if (!memrefType) + if (!memrefType) { return failure(); + } // This should be guaranteed by the analysis step. But just double check. assert(memrefType.getRank() > 0 && @@ -696,8 +710,9 @@ struct ProcessSubgroupMMALoad final Location loc = loadOp.getLoc(); auto indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return failure(); + } // Compute how many bits the mma op stride corresponds to for the scalar // memref, and rescale it to vector memref. @@ -730,8 +745,9 @@ struct ProcessSubgroupMMAStore final Location loc = storeOp.getLoc(); auto indices = adjustIndices(scalarMemrefType, vectorMemrefType, adaptor.getIndices(), rewriter, loc); - if (failed(indices)) + if (failed(indices)) { return failure(); + } // Compute how many bits the mma op stride corresponds to for the scalar // memref, and rescale it to vector memref. 
@@ -804,8 +820,9 @@ struct ScalarizeVectorTransferRead final PatternRewriter &rewriter) const override { VectorType vectorType = readOp.getType(); auto map = readOp.getPermutationMap(); - if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) + if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) { return failure(); + } Location loc = readOp.getLoc(); Value maybeMask = readOp.getMask(); @@ -883,8 +900,9 @@ struct ScalarizeVectorLoad final : public OpRewritePattern { LogicalResult matchAndRewrite(vector::LoadOp loadOp, PatternRewriter &rewriter) const override { VectorType vectorType = loadOp.getType(); - if (vectorType.getRank() > 1) + if (vectorType.getRank() > 1) { return failure(); + } Location loc = loadOp.getLoc(); if (vectorType.getRank() == 0) { @@ -929,8 +947,9 @@ struct ScalarizeVectorTransferWrite final PatternRewriter &rewriter) const override { VectorType vectorType = writeOp.getVectorType(); auto map = writeOp.getPermutationMap(); - if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) + if (vectorType.getRank() > 1 || !map.isProjectedPermutation()) { return failure(); + } Location loc = writeOp.getLoc(); Value maybeMask = writeOp.getMask(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp index 2685c41b5e5e..9dfb15d68388 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp @@ -36,8 +36,9 @@ const char *getSPIRVDistributeAttrName() { return "iree.spirv.distribute_dim"; } DictionaryAttr getTargetConfigAttr(Operation *op) { auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); - if (!targetAttr) + if (!targetAttr) { return nullptr; + } return targetAttr.getConfiguration(); } @@ -62,8 +63,9 @@ getSPIRVTileSize(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr getSPIRVTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { auto tileSizes = 
getSPIRVTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } linalg::TileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) { auto range = llvm::map_range(*tileSizes, [&](int64_t size) -> Value { @@ -79,8 +81,9 @@ getSPIRVScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { FailureOr> tileSizes = getSPIRVTileSize(funcOp, tilingLevel); - if (failed(tileSizes)) + if (failed(tileSizes)) { return failure(); + } scf::SCFTileSizeComputationFunction computeFn = [tileSizes](OpBuilder &builder, Operation *op) -> SmallVector { diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp index fe00bd1e6be4..2a1bafa26e01 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp @@ -34,8 +34,9 @@ LogicalResult verifySPIRVMatmulPromoteVectorizePassPipeline( << stringifyEnum(CodeGenPipeline::SPIRVMatmulPromoteVectorize); } - if (!isa(op)) + if (!isa(op)) { return success(); + } LLVM_DEBUG(llvm::dbgs() << "verifying op: " << *op << "\n" << "chosen workgroup size: " @@ -55,8 +56,9 @@ LogicalResult verifySPIRVMatmulPromoteVectorizePassPipeline( auto funcOp = op->getParentOfType(); std::optional subgroupSize = getGPUSubgroupSize(funcOp); - if (!subgroupSize) + if (!subgroupSize) { return funcOp->emitError("failed to query subgroup size"); + } const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup(); const auto maxWorkGroupSize = target.getWgp().getMaxWorkgroupSizes().asArrayRef(); @@ -164,8 +166,9 @@ LogicalResult verifySPIRVCooperativeMatrixVectorizePassPipeline( auto funcOp = op->getParentOfType(); std::optional subgroupSize = getGPUSubgroupSize(funcOp); - if (!subgroupSize) + if (!subgroupSize) { return funcOp->emitError("failed to query subgroup size"); + } const int maxThreads = target.getWgp().getMaxThreadCountPerWorkgroup(); 
const auto maxWorkGroupSize = target.getWgp().getMaxWorkgroupSizes().asArrayRef(); diff --git a/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp index a786176ea66a..5a8da4ac6778 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/KernelDispatch.cpp @@ -28,8 +28,9 @@ getDefaultDistributionTileSizes(TilingInterface op) { llvm::DenseSet partitionedLoopsSet(partitionedLoops.begin(), partitionedLoops.end()); for (auto dim : llvm::seq(0, distTileSizes.size())) { - if (!partitionedLoopsSet.count(dim)) + if (!partitionedLoopsSet.count(dim)) { distTileSizes[dim] = 0; + } } return distTileSizes; diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp index 89d64882c227..f36bb9ea059f 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXAssignConstantOrdinals.cpp @@ -23,13 +23,15 @@ struct VMVXAssignConstantOrdinalsPass // Ignore non-VMVX variants. // TODO(benvanik): a way to nest this in the pipeline via dynamic passes. - if (variantOp.getTarget().getBackend().getValue() != "vmvx") + if (variantOp.getTarget().getBackend().getValue() != "vmvx") { return; + } // Get a constant key -> ordinal mapping. auto keyOrdinals = variantOp.gatherConstantOrdinals(); - if (keyOrdinals.empty()) + if (keyOrdinals.empty()) { return; + } // Update placeholders to hold the concrete ordinal values. // Eventually the VM global folding passes will inline them. 
@@ -39,8 +41,9 @@ struct VMVXAssignConstantOrdinalsPass moduleOp.getOps())) { auto keyAttr = globalOp->getAttr( IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); - if (!keyAttr) + if (!keyAttr) { continue; + } auto it = keyOrdinals.find(keyAttr); if (it == keyOrdinals.end()) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp index ae8e5a3ebe9a..695af393f023 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp @@ -54,8 +54,9 @@ void VMVXLowerExecutableTargetPass::runOnOperation() { mlir::FunctionOpInterface funcOp = getOperation(); auto translationInfo = getTranslationInfo(funcOp); - if (!translationInfo) + if (!translationInfo) { return; + } std::optional maybePipeline = getFunctionOpInterfacePassManager(funcOp); diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp index fd40ecda75df..650e870d73bd 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp +++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp @@ -179,8 +179,9 @@ class StridedBufferAnalysis { StridedBufferDescriptor &getDesc(OpBuilder &builder) { assert(isValid() && "invalid StridedBufferAnalysis"); - if (desc) + if (desc) { return *desc; + } OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointAfterValue(buffer); @@ -257,10 +258,12 @@ struct BinaryEmitter { } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } if 
(!operands.first.bufferAnal.isValid() || !operands.second.bufferAnal.isValid() || !result.bufferAnal.isValid()) { return rewriter.notifyMatchFailure(loc, @@ -370,10 +373,12 @@ struct UnaryEmitter { unsigned maxRank() { return std::max(operand.getRank(), result.getRank()); } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } if (!operand.bufferAnal.isValid() || !result.bufferAnal.isValid()) { return rewriter.notifyMatchFailure(loc, "could not compute buffer descriptor"); @@ -463,10 +468,12 @@ struct CopyEmitter { } LogicalResult initialize(Location loc, PatternRewriter &rewriter) { - if (!isProjectedPermutation()) + if (!isProjectedPermutation()) { return rewriter.notifyMatchFailure(loc, "not projected permutation"); - if (maxRank() > 2) + } + if (maxRank() > 2) { return rewriter.notifyMatchFailure(loc, "rank > 2"); + } // Initialize buffer descriptors. for (auto © : copies) { @@ -529,11 +536,13 @@ struct LinalgBinaryGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match two children (op + yield). - if (children.size() != 2) + if (children.size() != 2) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Match: // %0 = someop %arg2, %arg3 @@ -548,8 +557,9 @@ struct LinalgBinaryGenericConversion dyn_cast(binaryOp->getOperands()[0]); BlockArgument operandScalar1 = dyn_cast(binaryOp->getOperands()[1]); - if (!operandScalar0 || !operandScalar1) + if (!operandScalar0 || !operandScalar1) { return failure(); + } // Construct the emitter and start lowering. 
// Note that the operands may map to an out if the aliasing is safe, @@ -597,8 +607,9 @@ struct LinalgBinaryGenericConversion // Select the op to lower to and configure the emitter. // Emit from the iree_ukernel_x32b_opcode_t table. Type resultType = binaryOp->getResult(0).getType(); - if (!resultType.isIntOrFloat()) + if (!resultType.isIntOrFloat()) { return failure(); + } std::optional emitter = TypeSwitch>(binaryOp) .Case([&](arith::AddFOp op) -> std::optional { @@ -691,8 +702,9 @@ struct LinalgBinaryGenericConversion if (!emitter) { return rewriter.notifyMatchFailure(op, "unrecognized binary op"); } - if (failed(emitter->initialize(op.getLoc(), rewriter))) + if (failed(emitter->initialize(op.getLoc(), rewriter))) { return failure(); + } emitter->emit(op.getLoc(), rewriter); rewriter.eraseOp(op); @@ -709,11 +721,13 @@ struct LinalgUnaryGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match two children (op + yield). - if (children.size() != 2) + if (children.size() != 2) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Match: // %0 = someop %arg2 @@ -726,8 +740,9 @@ struct LinalgUnaryGenericConversion } BlockArgument operandScalar0 = dyn_cast(unaryOp->getOperands()[0]); - if (!operandScalar0) + if (!operandScalar0) { return failure(); + } // Construct the emitter and start lowering. // Note that the operands may map to an out if the aliasing is safe, @@ -755,8 +770,9 @@ struct LinalgUnaryGenericConversion // Select the op to lower to and configure the emitter. // Emit from the iree_ukernel_x32b_opcode_t table. 
Type resultType = unaryOp->getResult(0).getType(); - if (!resultType.isIntOrFloat()) + if (!resultType.isIntOrFloat()) { return failure(); + } std::optional emitter = TypeSwitch>(unaryOp) .Case([&](math::AbsFOp op) -> std::optional { @@ -814,8 +830,9 @@ struct LinalgUnaryGenericConversion if (!emitter) { return rewriter.notifyMatchFailure(op, "unrecognized unary op"); } - if (failed(emitter->initialize(op.getLoc(), rewriter))) + if (failed(emitter->initialize(op.getLoc(), rewriter))) { return failure(); + } emitter->emit(op.getLoc(), rewriter); rewriter.eraseOp(op); @@ -832,11 +849,13 @@ struct LinalgTrivialGenericConversion PatternRewriter &rewriter) const override { auto &children = op.getBlock()->getOperations(); // Only match one child (yield). - if (children.size() != 1) + if (children.size() != 1) { return failure(); + } // Only match parallel loops. - if (op.getNumParallelLoops() != op.getNumLoops()) + if (op.getNumParallelLoops() != op.getNumLoops()) { return failure(); + } // Presumed to be a yield terminator: configure the emitter. CopyEmitter emitter; @@ -857,8 +876,9 @@ struct LinalgTrivialGenericConversion } } - if (failed(emitter.initialize(op.getLoc(), rewriter))) + if (failed(emitter.initialize(op.getLoc(), rewriter))) { return failure(); + } emitter.emit(op.getLoc(), rewriter); rewriter.eraseOp(op); return success(); From bf6b22bef170d302826d5a4fb2ac08ef15cd2282 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:01:21 -0500 Subject: [PATCH 48/71] Add braces in Flow, HAL, and Stream dialects. NFC. 
4/n (#23146) --- .../Flow/Conversion/TensorToFlow/Patterns.cpp | 3 +- .../Flow/Conversion/TensorToFlow/Utils.cpp | 6 +- .../Dialect/Flow/IR/FlowOpFolders.cpp | 45 ++-- .../iree/compiler/Dialect/Flow/IR/FlowOps.cpp | 105 +++++--- .../compiler/Dialect/Flow/IR/FlowTypes.cpp | 3 +- .../TransformExtensions/FlowExtensions.cpp | 69 ++++-- .../Flow/Transforms/AnnotateDispatches.cpp | 42 ++-- .../Dialect/Flow/Transforms/Canonicalize.cpp | 9 +- .../Flow/Transforms/CaptureDynamicDims.cpp | 45 ++-- .../Flow/Transforms/CleanupTensorShapes.cpp | 3 +- .../Transforms/ConvertRegionToWorkgroups.cpp | 15 +- .../Flow/Transforms/ConvertShardToFlow.cpp | 11 +- .../Transforms/DeduplicateExecutables.cpp | 15 +- .../Flow/Transforms/DumpDispatchGraph.cpp | 63 +++-- .../Flow/Transforms/ExportBenchmarkFuncs.cpp | 3 +- .../Flow/Transforms/FormDispatchRegions.cpp | 21 +- .../Flow/Transforms/InjectTensorTracing.cpp | 8 +- .../Transforms/InsertDispatchDebugTargets.cpp | 36 ++- .../Flow/Transforms/OutlineConstants.cpp | 18 +- .../Transforms/OutlineDispatchExterns.cpp | 6 +- .../Transforms/OutlineDispatchRegions.cpp | 6 +- .../Dialect/Flow/Transforms/RegionOpUtils.cpp | 48 ++-- .../Flow/Transforms/TopLevelSCFToCFG.cpp | 5 +- .../Dialect/HAL/Analysis/BindingLayout.cpp | 6 +- .../HALToVM/ConvertBufferViewOps.cpp | 6 +- .../HALToVM/ConvertCommandBufferOps.cpp | 6 +- .../Conversion/HALToVM/ConvertDeviceOps.cpp | 12 +- .../HALToVM/ConvertExecutableOps.cpp | 3 +- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 15 +- .../HAL/Conversion/StreamToHAL/Utils.cpp | 15 +- .../HAL/Conversion/UtilToHAL/Patterns.cpp | 3 +- .../iree/compiler/Dialect/HAL/IR/HALAttrs.cpp | 21 +- .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 72 ++++-- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 96 +++++--- .../iree/compiler/Dialect/HAL/IR/HALTypes.cpp | 9 +- .../Dialect/HAL/Target/TargetOptions.cpp | 15 +- .../Dialect/HAL/Target/TargetRegistry.cpp | 3 +- .../HAL/Transforms/AnnotateTargetDevices.cpp | 6 +- 
.../Transforms/CaptureExecutableSources.cpp | 9 +- .../HAL/Transforms/ConfigureExecutables.cpp | 3 +- .../Transforms/DumpExecutableBenchmarks.cpp | 9 +- .../HAL/Transforms/ElideRedundantCommands.cpp | 3 +- .../MaterializeDispatchInstrumentation.cpp | 15 +- .../HAL/Transforms/MaterializeInterfaces.cpp | 3 +- .../Transforms/MaterializeResourceCaches.cpp | 3 +- .../HAL/Transforms/MemoizeDeviceQueries.cpp | 3 +- .../HAL/Transforms/PreprocessExecutables.cpp | 12 +- .../HAL/Transforms/PruneExecutables.cpp | 12 +- .../HAL/Transforms/SerializeExecutables.cpp | 3 +- .../HAL/Transforms/SubstituteExecutables.cpp | 12 +- .../HAL/Transforms/TranslateExecutables.cpp | 6 +- .../Dialect/HAL/Utils/LLVMLinkerUtils.cpp | 3 +- .../Dialect/Stream/Analysis/Partitioning.cpp | 9 +- .../Partitioning/ReferencePartitioning.cpp | 6 +- .../Dialect/Stream/Analysis/ResourceUsage.cpp | 48 ++-- .../Conversion/FlowToStream/Patterns.cpp | 21 +- .../Conversion/HALToStream/Patterns.cpp | 3 +- .../Stream/Conversion/PatternUtils.cpp | 3 +- .../Conversion/StandardToStream/Patterns.cpp | 6 +- .../Conversion/UtilToStream/Patterns.cpp | 30 ++- .../Dialect/Stream/IR/StreamDialect.cpp | 3 +- .../Dialect/Stream/IR/StreamOpFolders.cpp | 228 ++++++++++++------ .../compiler/Dialect/Stream/IR/StreamOps.cpp | 191 ++++++++++----- .../Dialect/Stream/IR/StreamTypes.cpp | 82 ++++--- .../Stream/Transforms/AnnotateAffinities.cpp | 3 +- .../Transforms/AnnotateDispatchArguments.cpp | 24 +- .../AnnotateDispatchAssumptions.cpp | 21 +- .../Stream/Transforms/ConvertToStream.cpp | 3 +- .../Stream/Transforms/DumpStatistics.cpp | 18 +- .../Stream/Transforms/ElideAsyncCopies.cpp | 27 ++- .../Stream/Transforms/ElideTimepoints.cpp | 54 +++-- .../Stream/Transforms/EncodeTensors.cpp | 18 +- .../Stream/Transforms/FoldUniformOperands.cpp | 21 +- .../Transforms/FuseDispatchBindings.cpp | 9 +- .../Stream/Transforms/LayoutSlices.cpp | 3 +- .../Stream/Transforms/MaterializeBuiltins.cpp | 3 +- .../Transforms/MaterializeCopyOnWrite.cpp | 26 
+- .../Stream/Transforms/PackConstants.cpp | 9 +- .../Transforms/PackDispatchOperands.cpp | 9 +- .../Stream/Transforms/PropagateTimepoints.cpp | 61 +++-- .../Dialect/Stream/Transforms/RefineUsage.cpp | 27 ++- .../Stream/Transforms/ScheduleAllocation.cpp | 60 +++-- .../Stream/Transforms/ScheduleConcurrency.cpp | 36 ++- .../Stream/Transforms/ScheduleExecution.cpp | 30 ++- .../Transforms/SpecializeDispatches.cpp | 12 +- .../Transforms/UnifyEncodingForGlobals.cpp | 6 +- .../Dialect/Stream/Transforms/Utils.h | 3 +- .../Stream/Transforms/VerifyAffinities.cpp | 3 +- .../Transforms/VerifyAsyncAccessRanges.cpp | 6 +- .../Stream/Transforms/VerifyLowerings.cpp | 9 +- 90 files changed, 1397 insertions(+), 703 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp index 4a36049c9151..c4600d8dc2fd 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp @@ -402,8 +402,9 @@ struct ConvertTensorReshapePattern : public OpRewritePattern { SmallVector outputDynamicShapes; for (auto [resultShape, outputShp] : llvm::zip_equal( reshapeOp.getResultType().getShape(), outputShape[0])) { - if (ShapedType::isStatic(resultShape)) + if (ShapedType::isStatic(resultShape)) { continue; + } outputDynamicShapes.push_back(getValueOrCreateConstantIndexOp( rewriter, reshapeOp.getLoc(), outputShp)); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp index e09f44a6ead5..aef20cd47dc7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.cpp @@ -104,12 +104,14 @@ bool isOffsetSizeAndStrideMappableToFlow(ArrayRef offsets, int64_t staticSize = 
getVal(size, ShapedType::kDynamic); int64_t staticStride = getVal(stride, ShapedType::kDynamic); - if (staticStride != 1) + if (staticStride != 1) { return false; + } if (fullSlices == false) { - if (staticSize != 1) + if (staticSize != 1) { return false; + } } else { // TODO: Use ValueBoundsAnalysis to check whether two dynamic values // are equal. diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp index c11185012d6f..e67988ddd0d3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp @@ -49,8 +49,9 @@ struct ElideUnusedOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -59,13 +60,15 @@ struct ElideUnusedOp : public OpRewritePattern { // Returns true if |value| is definitely empty at runtime. static bool isTensorZeroElements(Value value) { auto type = dyn_cast(value.getType()); - if (!type) + if (!type) { return false; + } // Any static dimension being zero is definitely empty. 
for (int64_t i = 0; i < type.getRank(); ++i) { int64_t dim = type.getDimSize(i); - if (dim == 0) + if (dim == 0) { return true; + } } return false; // may still be dynamically empty } @@ -90,8 +93,9 @@ struct ReplaceOpIfTensorOperandZeroElements : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { auto operand = op->getOperand(OperandIdx); - if (!isTensorOperandZeroElements(operand)) + if (!isTensorOperandZeroElements(operand)) { return failure(); + } auto result = op->getResult(ResultIdx); auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); rewriter.replaceOpWithNewOp(op, result.getType(), @@ -106,8 +110,9 @@ struct ReplaceOpIfTensorResultZeroElements : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { auto result = op->getResult(ResultIdx); - if (!isTensorResultZeroElements(result)) + if (!isTensorResultZeroElements(result)) { return failure(); + } auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); rewriter.replaceOpWithNewOp(op, result.getType(), dynamicDims); @@ -122,8 +127,9 @@ struct ReplaceOpIfTensorOperandEmpty : public OpRewritePattern { PatternRewriter &rewriter) const override { auto operand = op->getOperand(OperandIdx); auto emptyOp = dyn_cast_if_present(operand.getDefiningOp()); - if (!emptyOp) + if (!emptyOp) { return failure(); + } auto result = op->getResult(ResultIdx); auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); rewriter.replaceOpWithNewOp(op, result.getType(), @@ -139,8 +145,9 @@ static SmallVector refreshDimsOnTypeChange(Operation *op, Type oldType, Type newType, ValueRange oldDims, PatternRewriter &rewriter) { - if (oldType == newType) + if (oldType == newType) { return llvm::to_vector(oldDims); + } // Build an expanded list of all the dims - constants will be nullptr. 
// This lets us map back the new types without worrying about whether some @@ -212,8 +219,9 @@ struct ReplaceDispatchResultIfZeroElements // will drop it. bool didReplaceAny = false; for (auto result : op.getResults()) { - if (result.use_empty()) + if (result.use_empty()) { continue; + } if (isTensorResultZeroElements(result)) { auto dynamicDims = op.getResultDynamicDims(result.getResultNumber()); auto emptyOp = IREE::Flow::TensorEmptyOp::create( @@ -392,8 +400,9 @@ struct DeduplicateDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -598,8 +607,9 @@ struct ResolveShapedDim : public OpRewritePattern { if (dynamicDims.has_value()) { unsigned dimOffset = 0; for (unsigned i = 0; i < idx; ++i) { - if (shapedType.isDynamicDim(i)) + if (shapedType.isDynamicDim(i)) { ++dimOffset; + } } rewriter.replaceOp(op, dynamicDims.value()[dimOffset]); return success(); @@ -679,8 +689,9 @@ struct FoldSplatLoadIntoPrimitive : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceOp = dyn_cast_if_present(loadOp.getSource().getDefiningOp()); - if (!sourceOp) + if (!sourceOp) { return failure(); + } rewriter.replaceOp(loadOp, sourceOp.getValue()); return success(); } @@ -699,8 +710,9 @@ void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult TensorStoreOp::fold(FoldAdaptor operands) { auto value = operands.getValue(); - if (!value) + if (!value) { return {}; + } if (auto target = dyn_cast_if_present(operands.getTarget())) { // Store into the constant target tensor. 
auto targetType = cast(target.getType()); @@ -751,8 +763,9 @@ struct FoldSplatReshapeIntoSplat : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = dyn_cast_if_present( reshapeOp.getSource().getDefiningOp()); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(), reshapeOp.getResultDims()); @@ -1067,8 +1080,9 @@ struct FoldTensorUpdateOpWithCasts : public OpRewritePattern { PatternRewriter &rewriter) const override { auto targetCastOp = updateOp.getTarget().getDefiningOp(); auto updateCastOp = updateOp.getUpdate().getDefiningOp(); - if (!targetCastOp && !updateCastOp) + if (!targetCastOp && !updateCastOp) { return failure(); + } Value target = (targetCastOp ? cast(targetCastOp.getSource()) : cast(updateOp.getTarget())); Value update = (updateCastOp ? cast(updateCastOp.getSource()) @@ -1094,8 +1108,9 @@ struct ReplaceOpIfTensorUpdateOperandZeroElements LogicalResult matchAndRewrite(TensorUpdateOp op, PatternRewriter &rewriter) const override { auto operand = op.getUpdate(); - if (!isTensorOperandZeroElements(operand)) + if (!isTensorOperandZeroElements(operand)) { return failure(); + } rewriter.replaceOp(op, op.getTarget()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index e5d40af3533c..4a63c88f382f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -172,13 +172,15 @@ static ParseResult parseShapedOperandList( valueTypes.emplace_back(); if (failed(parser.parseOperand(values.back())) || failed(parser.parseColon()) || - failed(parser.parseType(valueTypes.back()))) + failed(parser.parseType(valueTypes.back()))) { return failure(); + } if (int64_t dynamicDimCount = cast(valueTypes.back()).getNumDynamicDims()) { if (failed(parser.parseOperandList(valueDims, 
dynamicDimCount, - AsmParser::Delimiter::Braces))) + AsmParser::Delimiter::Braces))) { return failure(); + } } } while (succeeded(parser.parseOptionalComma())); return success(); @@ -248,13 +250,15 @@ static ParseResult parseWorkgroupCountRegionWithoutKeyword(OpAsmParser &parser, static void printWorkgroupCountRegionWithoutKeyword(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; auto args = body.getArguments(); for (unsigned i = 0; i < args.size(); ++i) { - if (i > 0) + if (i > 0) { p << ", "; + } p.printRegionArgument(args[i]); } p << ")"; @@ -277,8 +281,9 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "workgroups"; printWorkgroupCountRegionWithoutKeyword(p, op, body); } @@ -293,8 +298,9 @@ static ParseResult parseDispatchWorkgroupsCountRegion(OpAsmParser &parser, static void printDispatchWorkgroupsCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << " count"; printWorkgroupCountRegionWithoutKeyword(p, op, body); } @@ -309,16 +315,19 @@ static ParseResult parseDispatchEntryPoints(OpAsmParser &parser, if (succeeded(parser.parseOptionalLBrace())) { do { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRBrace())) + if (failed(parser.parseRBrace())) { return failure(); + } } else { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } entryPointAttrsArray = 
parser.getBuilder().getArrayAttr(entryPointAttrs); @@ -388,11 +397,12 @@ LogicalResult DispatchRegionOp::verify() { << returnOp.getNumOperands() << ")"; } for (const auto [resultType, returnType] : - llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes())) + llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes())) { if (resultType != returnType) { return returnOp->emitOpError() << "operand types do not match with parent results"; } + } } // Make sure that all returned values are ranked tensors. @@ -423,36 +433,45 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser, (void)workloadOperandsLoc; if (succeeded(parser.parseOptionalLSquare())) { workloadOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(workloadOperands)) + if (parser.parseOperandList(workloadOperands)) { return failure(); - if (parser.parseRSquare()) + } + if (parser.parseRSquare()) { return failure(); + } } if (succeeded(parser.parseOptionalArrow())) { ParseResult typeListResult = parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { - if (parser.parseType(resultTypes.emplace_back())) + if (parser.parseType(resultTypes.emplace_back())) { return failure(); + } auto shapedType = dyn_cast(resultTypes.back()); - if (!shapedType) + if (!shapedType) { return success(); - if (shapedType.hasStaticShape()) + } + if (shapedType.hasStaticShape()) { return success(); + } SmallVector dynamicDims; if (parser.parseOperandList(dynamicDims, shapedType.getNumDynamicDims(), - OpAsmParser::Delimiter::Braces)) + OpAsmParser::Delimiter::Braces)) { return failure(); + } allOperands.append(dynamicDims.begin(), dynamicDims.end()); return success(); }); - if (typeListResult) + if (typeListResult) { return failure(); + } } - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { return failure(); - if (parser.parseRegion(*bodyRegion)) + } + if (parser.parseRegion(*bodyRegion)) { return failure(); 
+ } if (parseDispatchWorkgroupsCountRegion(parser, *workloadCountRegion)) { return failure(); @@ -466,8 +485,9 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser, static_cast(workloadOperands.size())})); if (parser.resolveOperands(allOperands, parser.getBuilder().getIndexType(), - result.operands)) + result.operands)) { return failure(); + } if (parser.resolveOperands(workloadOperands, parser.getBuilder().getIndexType(), workloadOperandsLoc, result.operands)) { @@ -498,8 +518,9 @@ void DispatchRegionOp::print(OpAsmPrinter &p) { resultDimCounter += shapedType.getNumDynamicDims(); } } - if (it.index() < getNumResults() - 1) + if (it.index() < getNumResults() - 1) { p << ", "; + } } p << ")"; p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); @@ -519,9 +540,11 @@ void DispatchRegionOp::print(OpAsmPrinter &p) { ValueRange DispatchRegionOp::getResultDynamicDims(unsigned idx) { unsigned counter = 0; - for (unsigned i = 0; i < idx; ++i) - if (auto shapedType = dyn_cast(getResultTypes()[i])) + for (unsigned i = 0; i < idx; ++i) { + if (auto shapedType = dyn_cast(getResultTypes()[i])) { counter += shapedType.getNumDynamicDims(); + } + } auto shapedType = dyn_cast(getResultTypes()[idx]); return getResultDims().slice(counter, shapedType ? shapedType.getNumDynamicDims() : 0); @@ -590,8 +613,9 @@ bool dropUnusedAndRedundantDispatchRegionResults( "expected that all dynamic dims were processed"); // Nothing to do if all results are used. - if (droppedResultValues.empty()) + if (droppedResultValues.empty()) { return false; + } // Create new region and move over the body. 
auto newRegionOp = @@ -850,12 +874,14 @@ LogicalResult DispatchWorkgroupsOp::verify() { return success(); }; for (auto type : getOperandTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } for (auto type : getResultTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } // Workgroup count region is optional. @@ -879,22 +905,26 @@ BlockArgument DispatchWorkgroupsOp::getOutputBlockArgument(unsigned idx) { // Some outputs are tied to inputs and share their block arguments. int64_t tiedOperand = cast((*tiedOperands)[idx]).getValue().getSExtValue(); - if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) + if (tiedOperand != IREE::Util::TiedOpInterface::kUntiedIndex) { // This output is tied to an input. return getInputBlockArgument(tiedOperand); + } unsigned nextOutArgIdx = getArguments().size(); - for (unsigned i = 0; i < idx; ++i) + for (unsigned i = 0; i < idx; ++i) { if (cast((*tiedOperands)[i]).getValue().getSExtValue() == - IREE::Util::TiedOpInterface::kUntiedIndex) + IREE::Util::TiedOpInterface::kUntiedIndex) { nextOutArgIdx++; + } + } return getWorkgroupBody().getArguments()[nextOutArgIdx]; } SmallVector DispatchWorkgroupsOp::getOutputBlockArguments() { SmallVector result; - for (unsigned i = 0; i < getNumResults(); ++i) + for (unsigned i = 0; i < getNumResults(); ++i) { result.push_back(getOutputBlockArgument(i)); + } return result; } @@ -954,10 +984,12 @@ refineTensorAccess(Value value, IREE::TensorExt::DispatchTensorType type) { hasWrites = true; }); } - if (hasReads && !hasWrites) + if (hasReads && !hasWrites) { tensorAccess = IREE::TensorExt::TensorAccess::ReadOnly; - if (!hasReads && hasWrites) + } + if (!hasReads && hasWrites) { tensorAccess = IREE::TensorExt::TensorAccess::WriteOnly; + } } return tensorAccess; } @@ -1071,16 +1103,18 @@ DispatchWorkgroupsOp::cloneReplacementExcludingOperandsAndResults( auto erasedArguments = 
llvm::to_vector(excludedOperandIndices); for (unsigned i = baseResultIndex, e = newBody.getNumArguments(); i != e; ++i) { - if (!is_contained(excludedResultIndices, i - baseResultIndex)) + if (!is_contained(excludedResultIndices, i - baseResultIndex)) { continue; + } auto arg = newBody.front().getArgument(i); eraseArgUseTree(arg, rewriter); erasedArguments.push_back(i); } auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : erasedArguments) + for (auto i : erasedArguments) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; @@ -1093,8 +1127,9 @@ DispatchWorkgroupsOp::getTiedOperandsIndexAndLength() { SmallVector DispatchWorkgroupsOp::getTiedOperandsAsIntegerList() { ArrayAttr attr = getTiedOperandsAttr(); - if (!attr) + if (!attr) { return {}; + } return llvm::map_to_vector(attr, [](Attribute intAttr) { return cast(intAttr).getInt(); }); diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp index ec6abe700972..63971aa4cd2f 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp @@ -53,8 +53,9 @@ Type FlowDialect::parseType(DialectAsmParser &parser) const { Type type; OptionalParseResult parseResult = generatedTypeParser(parser, &mnemonic, type); - if (parseResult.has_value()) + if (parseResult.has_value()) { return type; + } parser.emitError(parser.getCurrentLocation()) << "unknown Flow type: " << mnemonic; return {}; diff --git a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp index 1639cf7172a9..69ba78396530 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.cpp @@ -34,8 +34,9 @@ void 
registerTransformDialectFlowExtension(DialectRegistry ®istry) { static SmallVector getIndicesOfDynamicDims(ShapedType t) { int64_t numDynamicDims = t.getNumDynamicDims(); SmallVector res(numDynamicDims); - for (int64_t dim = 0; dim != numDynamicDims; ++dim) + for (int64_t dim = 0; dim != numDynamicDims; ++dim) { res[dim] = t.getDynamicDimIndex(dim); + } return res; } @@ -61,8 +62,9 @@ static LogicalResult populateWorkgroupCountComputingRegion( // TODO: Iteratively pull operations that are only consuming IndexType. for (Value v : forallOp.getUpperBound(rewriter)) { auto op = dyn_cast_if_present(v.getDefiningOp()); - if (!op) + if (!op) { return failure(); + } results.push_back( cast(rewriter.clone(*op)).getResult()); } @@ -124,17 +126,21 @@ static void rewriteExtractSlices(RewriterBase &rewriter, scf::ForallOp forallOp, IRMapping tensorToFlowBvm) { dispatchOp->walk([&](tensor::ExtractSliceOp extractSliceOp) { Value source = extractSliceOp.getSource(); - if (auto sourceBbArg = dyn_cast(source)) - if (sourceBbArg.getOwner()->getParentOp() == forallOp.getOperation()) + if (auto sourceBbArg = dyn_cast(source)) { + if (sourceBbArg.getOwner()->getParentOp() == forallOp.getOperation()) { source = forallOp.getTiedOpOperand(sourceBbArg)->get(); + } + } auto it = llvm::find(tensorOperands, source); - if (it == tensorOperands.end()) + if (it == tensorOperands.end()) { return; + } int64_t index = std::distance(tensorOperands.begin(), it); Value sourceFlow = tensorToFlowBvm.lookupOrNull(source); - if (!sourceFlow) + if (!sourceFlow) { return; + } Location loc = extractSliceOp.getLoc(); OpBuilder::InsertionGuard g(rewriter); @@ -162,22 +168,26 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, // Add all ops who's results are used inside the ForallOp to the // worklist. 
llvm::SetVector worklist; - for (Value v : valuesDefinedAbove) - if (Operation *op = v.getDefiningOp()) + for (Value v : valuesDefinedAbove) { + if (Operation *op = v.getDefiningOp()) { worklist.insert(op); + } + } llvm::SmallVector opsToClone; llvm::DenseSet visited; // Process all ops in the worklist. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); - if (visited.contains(op)) + if (visited.contains(op)) { continue; + } visited.insert(op); // Do not clone ops that are not clonable. - if (!IREE::Flow::isClonableIntoDispatchOp(op)) + if (!IREE::Flow::isClonableIntoDispatchOp(op)) { continue; + } // Do not clone ParallelInsertSliceOp destinations. bool isDestination = any_of( @@ -186,16 +196,18 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, .getDest() .getDefiningOp() == op; }); - if (isDestination) + if (isDestination) { continue; + } opsToClone.push_back(op); // Add all operands to the worklist. for (Value operand : op->getOperands()) { Operation *operandOp = operand.getDefiningOp(); - if (!operandOp) + if (!operandOp) { continue; + } worklist.insert(operandOp); } } @@ -206,9 +218,11 @@ static void cloneOpsIntoForallOp(RewriterBase &rewriter, for (Operation *op : llvm::reverse(opsToClone)) { Operation *cloned = rewriter.clone(*op); SmallVector uses; - for (OpOperand &use : op->getUses()) - if (forallOp->isProperAncestor(use.getOwner())) + for (OpOperand &use : op->getUses()) { + if (forallOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); + } + } for (OpOperand *use : uses) { unsigned resultNum = cast(use->get()).getResultNumber(); rewriter.modifyOpInPlace( @@ -264,13 +278,15 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, BlockArgument destBbArg = cast(parallelInsertOp.getDest()); Value dest = forallOp.getTiedOpOperand(destBbArg)->get(); bool inserted = resultTensorOperands.insert(dest); - if (!inserted) + if (!inserted) { continue; + } auto dynamicDims = 
getIndicesOfDynamicDims(cast(dest.getType())); - for (int64_t dim : dynamicDims) + for (int64_t dim : dynamicDims) { resultTensorsDynamicDims.insert( tensor::DimOp::create(rewriter, loc, dest, dim)); + } } assert(resultTensorOperands.size() == forallOp.getNumResults() && "Expected as many resultTensorOperands as results of forallOp"); @@ -289,21 +305,25 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, nonTensorOperands.push_back(v); continue; } - if (resultTensorOperands.contains(v)) + if (resultTensorOperands.contains(v)) { continue; + } tensorOperands.push_back(v); - for (int64_t dim : getIndicesOfDynamicDims(tensorType)) + for (int64_t dim : getIndicesOfDynamicDims(tensorType)) { tensorDynamicDims.push_back(tensor::DimOp::create(rewriter, loc, v, dim)); + } } // Also add shared outputs. (These are usually already added as result // tensor operands.) for (Value v : forallOp.getOutputs()) { auto tensorType = cast(v.getType()); - if (resultTensorOperands.contains(v)) + if (resultTensorOperands.contains(v)) { continue; + } tensorOperands.push_back(v); - for (int64_t dim : getIndicesOfDynamicDims(tensorType)) + for (int64_t dim : getIndicesOfDynamicDims(tensorType)) { tensorDynamicDims.push_back(tensor::DimOp::create(rewriter, loc, v, dim)); + } } // Step 3. Create ordered vectors of operands to pass to the builder and @@ -340,10 +360,11 @@ rewriteForeachThreadToFlowDispatchWorkgroups(scf::ForallOp forallOp, // Step 4. Outline the compute workload region and set up the workload // operands. if (failed(populateWorkgroupCountComputingRegion(rewriter, forallOp, - dispatchOp))) + dispatchOp))) { return forallOp->emitOpError( "failed to populate workload region for dispatchOp: ") << dispatchOp; + } // Step 5. Fixup dispatchOp bbArgs and terminator. 
// TODO: Ideally the builder would have created the proper bbArgs and the @@ -465,8 +486,9 @@ IREE::transform_dialect::ForeachThreadToFlowDispatchWorkgroupsOp::applyToOne( IRRewriter patternRewriter(target->getContext()); FailureOr result = rewriteForeachThreadToFlowDispatchWorkgroups(target, patternRewriter); - if (failed(result)) + if (failed(result)) { return emitDefaultDefiniteFailure(target); + } results.push_back(*result); return DiagnosedSilenceableFailure::success(); } @@ -484,8 +506,9 @@ IREE::transform_dialect::RegionToWorkgroupsOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &) { FailureOr result = rewriteFlowDispatchRegionToFlowDispatchWorkgroups(target, rewriter); - if (failed(result)) + if (failed(result)) { return emitDefaultDefiniteFailure(target); + } results.push_back(*result); return DiagnosedSilenceableFailure::success(); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp index fdb0b3e0a4fd..05a0b7727ad6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/AnnotateDispatches.cpp @@ -80,8 +80,9 @@ static TensorType getMainTensorForLinalgExtOp(Operation *op) { auto resultTypes = llvm::to_vector(op->getResultTypes()); for (Type t : llvm::concat(operandTypes, resultTypes)) { auto tensorType = dyn_cast(t); - if (!tensorType) + if (!tensorType) { continue; + } if (!main) { main = tensorType; } else if (costOfDomain(tensorType.getShape()) > @@ -182,19 +183,22 @@ static std::string getLinalgDataTypes(linalg::LinalgOp op) { static std::string getOpNameWithoutDialectName(Operation *op) { auto opName = op->getName().getStringRef().drop_until([](char c) { return c == '.'; }); - if (opName.starts_with(".")) + if (opName.starts_with(".")) { opName = opName.drop_front(); + } return opName.str(); } static bool 
isMatvecLike(linalg::LinalgOp linalgOp) { - if (!linalg::isaContractionOpInterface(linalgOp)) + if (!linalg::isaContractionOpInterface(linalgOp)) { return false; + } FailureOr dims = linalg::inferContractionDims(linalgOp); - if (failed(dims)) + if (failed(dims)) { return false; + } // One of the input should have all the parallel dimensions with size one. SmallVector bounds = linalgOp.getStaticLoopRanges(); @@ -207,8 +211,9 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) { unsigned pos = cast(result).getPosition(); // For a parallel dim, the bounds can be non-one if it's batch dim. if (iterators[pos] == utils::IteratorType::parallel && bounds[pos] != 1 && - !llvm::is_contained(dims->batch, pos)) + !llvm::is_contained(dims->batch, pos)) { return false; + } } return true; }; @@ -316,8 +321,9 @@ static std::string summarizeLinalgOp(linalg::LinalgOp op) { if (prefix.empty()) { // By default, use the op name as prefix. auto opName = op->getName().getStringRef(); - if (!opName.consume_front("linalg.")) + if (!opName.consume_front("linalg.")) { return ""; + } prefix = opName.str(); } @@ -331,8 +337,9 @@ static std::string summarizeLinalgExtOp(Operation *op) { auto opName = op->getName().getStringRef(); // Currently, this utility is also invoked by Linalg::SoftmaxOp. 
if (!(opName.consume_front("iree_linalg_ext.") || - opName.consume_front("linalg."))) + opName.consume_front("linalg."))) { return ""; + } std::string suffix = ""; if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) { llvm::raw_string_ostream sstream(suffix); @@ -382,8 +389,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { TypeSwitch(op) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgSoftmaxOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -391,8 +399,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { }) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -403,8 +412,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { // SetEncoding/UnsetEncoding/PackOp/UnPackOp is the bestOp only if // there are no other operations. 
int64_t estimatedCost = kMinEstimatedCost + 1; - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -412,8 +422,9 @@ static std::string summarizeDispatchRegion(Region ®ion) { }) .Case([&](auto op) { int64_t estimatedCost = estimateLinalgExtOpCost(op); - if (estimatedCost < bestEstimatedCost) + if (estimatedCost < bestEstimatedCost) { return; + } bestEstimatedCost = estimatedCost; bestOp = op; LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName() @@ -507,8 +518,9 @@ struct AnnotateDispatchesPass for (auto executableOp : getOperation().getBody()->getOps()) { auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { continue; + } for (auto exportOp : executableOp.getBlock().getOps()) { auto oldSymbolRefAttr = SymbolRefAttr::get( @@ -517,11 +529,13 @@ struct AnnotateDispatchesPass auto funcOp = innerModuleOp.lookupSymbol( exportOp.getFunctionRef()); - if (!funcOp) + if (!funcOp) { continue; // extern module, maybe + } std::string summary = summarizeDispatchRegion(funcOp.getFunctionBody()); - if (summary.empty()) + if (summary.empty()) { continue; // unable to tell + } std::string newName = funcOp.getName().str() + "_" + summary; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp index 109f00efd2a7..9450fd50cb9e 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Canonicalize.cpp @@ -94,8 +94,9 @@ class AffineApplyLowering : public OpRewritePattern { auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); - if (!maybeExpandedMap) + if (!maybeExpandedMap) { return failure(); + } rewriter.replaceOp(op, 
*maybeExpandedMap); return success(); } @@ -113,10 +114,12 @@ struct CanonicalizePass : public impl::CanonicalizePassBase { mlir::GreedySimplifyRegionLevel::Normal); RewritePatternSet owningPatterns(context); - for (auto *dialect : context->getLoadedDialects()) + for (auto *dialect : context->getLoadedDialects()) { dialect->getCanonicalizationPatterns(owningPatterns); - for (RegisteredOperationName op : context->getRegisteredOperations()) + } + for (RegisteredOperationName op : context->getRegisteredOperations()) { op.getCanonicalizationPatterns(owningPatterns, context); + } // Pull in some borderline/downstream canonicalizations for the Flow // compilation phase. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp index 57b1d1ac093b..b53fe9e783a1 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CaptureDynamicDims.cpp @@ -47,8 +47,9 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { outerToInnerMap[operand] = arg; } for (auto result : dispatchOp.getResults()) { - if (dispatchOp.getTiedResultOperand(result)) + if (dispatchOp.getTiedResultOperand(result)) { continue; // ignored tied + } auto arg = entryBlock->getArgument(argIdx++); outerToInnerMap[result] = arg; } @@ -59,16 +60,19 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { auto captureTensorDims = [&](Value externalValue, Value internalValue) { auto tensorType = dyn_cast(internalValue.getType()); - if (!tensorType) + if (!tensorType) { return; - if (tensorType.hasStaticShape()) + } + if (tensorType.hasStaticShape()) { return; + } // Find the dimensions in the parent. 
auto maybeDynamicDims = IREE::Util::findDynamicDims( externalValue, dispatchOp->getBlock(), Block::iterator(dispatchOp)); - if (!maybeDynamicDims.has_value()) + if (!maybeDynamicDims.has_value()) { return; + } // Convert to a vector -- we cannot use the ValueRange directly because // it might point into the operand list of this op, which we might mutate // in-place. @@ -116,8 +120,9 @@ static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) { captureTensorDims(operand, outerToInnerMap[operand]); } for (auto result : dispatchOp.getResults()) { - if (dispatchOp.getTiedResultOperand(result)) + if (dispatchOp.getTiedResultOperand(result)) { continue; // ignore tied + } captureTensorDims(result, outerToInnerMap[result]); } } @@ -141,19 +146,22 @@ static void captureDims(scf::ForOp forOp) { llvm::zip_equal(forOp.getInitArgs(), forOp.getYieldedValues(), forOp.getRegionIterArgs(), forOp.getResults())) { auto tensorType = dyn_cast(init.getType()); - if (!tensorType || tensorType.hasStaticShape()) + if (!tensorType || tensorType.hasStaticShape()) { continue; + } // Make the transform idempotent by not caring about tensors only used // within 'flow.tensor.tie_shape' operations. - if (llvm::all_of(bbArg.getUsers(), llvm::IsaPred)) + if (llvm::all_of(bbArg.getUsers(), llvm::IsaPred)) { continue; + } dynamicTensorIterables.push_back({init, iter, bbArg, result}); } - if (dynamicTensorIterables.empty()) + if (dynamicTensorIterables.empty()) { return; + } // Create the new dimension loop variables. 
Since the dynamic tensors may be // of different types with varying number of dynamic dimensions, 'dimBounds' @@ -169,26 +177,31 @@ static void captureDims(scf::ForOp forOp) { dimBounds.push_back(newIterables.size()); std::optional initDynamicDims = IREE::Util::findDynamicDims( init, forOp->getBlock(), Block::iterator(forOp)); - if (!initDynamicDims) + if (!initDynamicDims) { continue; + } std::optional iterDynamicDims = IREE::Util::findDynamicDims( iter, forOp.getBody(), Block::iterator(forOp.getBody()->getTerminator())); - if (!iterDynamicDims) + if (!iterDynamicDims) { continue; + } - if (iterDynamicDims->size() != initDynamicDims->size()) + if (iterDynamicDims->size() != initDynamicDims->size()) { continue; + } for (auto [initDim, iterDim] : - llvm::zip_equal(*initDynamicDims, *iterDynamicDims)) + llvm::zip_equal(*initDynamicDims, *iterDynamicDims)) { newIterables.push_back({initDim, iterDim}); + } } dimBounds.push_back(newIterables.size()); - if (newIterables.empty()) + if (newIterables.empty()) { return; + } // A new 'scf.for' has to be created to replace the old one as new results // are being added. 
@@ -223,8 +236,9 @@ static void captureDims(scf::ForOp forOp) { auto dims = ArrayRef(newIterables) .slice(dimBounds[index], dimBounds[index + 1] - dimBounds[index]); - if (dims.empty()) + if (dims.empty()) { continue; + } Value tied = Flow::TensorTieShapeOp::create( builder, forOp.getLoc(), tensor.bbArg, @@ -242,8 +256,9 @@ static void captureDims(scf::ForOp forOp) { auto dims = ArrayRef(newIterables) .slice(dimBounds[index], dimBounds[index + 1] - dimBounds[index]); - if (dims.empty()) + if (dims.empty()) { continue; + } Value &replacement = results[tensor.result.getResultNumber()]; replacement = Flow::TensorTieShapeOp::create( diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp index db3957c96f2f..6b4d127ff8c7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CleanupTensorShapes.cpp @@ -34,8 +34,9 @@ struct CleanupTensorShapesPass foundBadOps = true; } }); - if (foundBadOps) + if (foundBadOps) { return signalPassFailure(); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp index 3e3a46a678d2..133ed3bea247 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp @@ -36,8 +36,9 @@ static void appendDynamicDims(OpBuilder &b, Location loc, } for (auto dim : llvm::enumerate(tensorType.getShape())) { - if (ShapedType::isStatic(dim.value())) + if (ShapedType::isStatic(dim.value())) { continue; + } argumentDims.push_back( b.createOrFold(loc, tensor, dim.index())); } @@ -50,8 +51,9 @@ findFirstTiedValueOutsideOfRegionOp(IREE::Flow::DispatchRegionOp regionOp, Value value) { // Check if `v` is defined outside of 
`regionOp`. auto isOutside = [&](Value v) { - if (isa(v)) + if (isa(v)) { return !regionOp->isAncestor(v.getDefiningOp()); + } assert(isa(v) && "expected bbArg"); // DispatchRegionOp does not have block arguments. return true; @@ -107,8 +109,9 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups( SmallVector argumentDims; for (Value tensor : argumentsSet) { auto tensorType = dyn_cast(tensor.getType()); - if (!tensorType) + if (!tensorType) { continue; + } appendDynamicDims(rewriter, loc, argumentDims, tensor); } @@ -129,13 +132,15 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups( llvm::enumerate(origTerminators.front()->getOperands())) { auto tiedArgument = findFirstTiedValueOutsideOfRegionOp(regionOp, it.value()); - if (!tiedArgument.has_value()) + if (!tiedArgument.has_value()) { continue; + } assert(argumentsSet.contains(*tiedArgument) && "expected that tiedArgument is already an argument"); // Do not tie an argument to multiple results. - if (tiedArgumentsSet.contains(*tiedArgument)) + if (tiedArgumentsSet.contains(*tiedArgument)) { continue; + } tiedArgumentsSet.insert(*tiedArgument); tiedArguments[it.index()] = std::distance( argumentsSet.begin(), llvm::find(argumentsSet, *tiedArgument)); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp index 5e4b314b4ff0..19e294a1dc80 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertShardToFlow.cpp @@ -139,11 +139,12 @@ static bool isDefaultChannel(shard::GridOp grid, static Value getDefaultChannel(Location loc, shard::GridOp grid, bool useNamedDefaultChannels, OpBuilder &builder) { - if (useNamedDefaultChannels) + if (useNamedDefaultChannels) { return IREE::Flow::ChannelDefaultOp::create(builder, loc, grid.getSymName()); - else + } else { return IREE::Flow::ChannelDefaultOp::create(builder, loc); + } } 
static Value buildCachedChannelLoading(Location loc, shard::GridOp grid, @@ -254,8 +255,9 @@ static void createChannels(ModuleOp moduleOp, llvm::sort(gridAndAxesSetSorted, [](auto &a, auto &b) { int nameCompareRes = std::get<0>(a).getSymName().compare(std::get<0>(b).getSymName()); - if (nameCompareRes == 0) + if (nameCompareRes == 0) { return std::get<1>(a) < std::get<1>(b); + } return nameCompareRes < 0; }); for (auto &[shard, shardAxes] : llvm::make_range( @@ -292,8 +294,9 @@ static void removeShardOps(GridAndAxesSet &gridAndAxesSet) { DenseSet gridOpsSet(std::begin(gridRange), std::end(gridRange)); for (shard::GridOp op : gridOpsSet) { - if (op) + if (op) { op.erase(); + } } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp index 31769d71e67c..e76352da21cf 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp @@ -19,8 +19,9 @@ namespace { // Utilities to make SymbolRefAttr easier to construct. 
static SymbolRefAttr nestSymbolRef(SymbolRefAttr baseRefAttr, FlatSymbolRefAttr leafRefAttr) { - if (!baseRefAttr) + if (!baseRefAttr) { return leafRefAttr; + } SmallVector nestedRefAttrs; llvm::append_range(nestedRefAttrs, baseRefAttr.getNestedReferences()); nestedRefAttrs.push_back(leafRefAttr); @@ -43,8 +44,9 @@ static void gatherReplacements( for (auto [oldNestedSymbolOp, newNestedSymbolOp] : llvm::zip_equal(nestedOldRegion.getOps(), nestedNewRegion.getOps())) { - if (!oldNestedSymbolOp.isPublic()) + if (!oldNestedSymbolOp.isPublic()) { continue; // ignore private symbols + } auto oldNestedSymbolRefAttr = nestSymbolRef(oldSymbolRefAttr, oldNestedSymbolOp); auto newNestedSymbolRefAttr = @@ -140,8 +142,9 @@ static int deduplicateObjects(Operation *scopeOp, // We could rely on SymbolDCE for this but that makes looking at IR dumps // harder as after this pass runs and until SymbolDCE runs there are lots of // dead objects in the output. - for (auto *op : deadOps) + for (auto *op : deadOps) { op->erase(); + } return deadOps.size(); } @@ -156,11 +159,13 @@ class DeduplicateExecutablesPass mlir::ModuleOp moduleOp = getOperation(); SmallVector allObjects; for (auto &op : moduleOp.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { allObjects.push_back(&op); + } } - if (allObjects.empty()) + if (allObjects.empty()) { return; + } (void)deduplicateObjects(moduleOp, allObjects); // totalObjects = allObjects.size(); // objectsDeduplicated = deduplicateObjects(moduleOp, allObjects); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp index e2082748d7dd..1f05d356b691 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp @@ -47,8 +47,9 @@ static const StringRef kShapeNone = "plain"; static const StringRef kShapeEllipse = "ellipse"; static StringRef 
getShape(Operation *op) { - if (isa(op)) + if (isa(op)) { return kShapeBox; + } return kShapeEllipse; } @@ -57,8 +58,9 @@ static StringRef getShape(Operation *op) { static int64_t getLargeAttributeSizeLimit() { // Use the default from the printer flags if possible. if (std::optional limit = - OpPrintingFlags().getLargeElementsAttrLimit()) + OpPrintingFlags().getLargeElementsAttrLimit()) { return *limit; + } return 16; } @@ -142,8 +144,9 @@ class GraphPrinter { void emitFunctions(ModuleOp module) { auto funcOps = module.getOps(); - if (funcOps.empty()) + if (funcOps.empty()) { return; + } emitGraph([&]() { for (auto funcOp : funcOps) { @@ -167,8 +170,9 @@ class GraphPrinter { /// Emit all edges. This function should be called after all nodes have been /// emitted. void emitAllEdgeStmts() { - for (const std::string &edge : edges) + for (const std::string &edge : edges) { os << edge << ";\n"; + } edges.clear(); } @@ -243,13 +247,16 @@ class GraphPrinter { // Do not label edges that start/end at a cluster boundary. Such edges are // clipped at the boundary, but labels are not. This can lead to labels // floating around without any edge next to them. - if (!n1.clusterId && !n2.clusterId) + if (!n1.clusterId && !n2.clusterId) { attrs["label"] = quoteString(escapeString(std::move(label))); + } // Use `ltail` and `lhead` to draw edges between clusters. 
- if (n1.clusterId) + if (n1.clusterId) { attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); - if (n2.clusterId) + } + if (n2.clusterId) { attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); + } edges.push_back(strFromOs([&](raw_ostream &os) { os << llvm::format("v%i -> v%i ", n1.id, n2.id); @@ -344,12 +351,14 @@ class GraphPrinter { } void annotateOperation(raw_ostream &os, Operation *op, AsmState &state) { - if (isa(op)) + if (isa(op)) { return; + } if (op->hasTrait() && - isa(op->getParentOp())) + isa(op->getParentOp())) { return; + } if (auto load = dyn_cast(op)) { printDispatchTensorLoad(os, load, state); @@ -385,18 +394,21 @@ class GraphPrinter { auto entryPoint = *dispatchOp.getEntryPointRefs().begin(); auto executableOp = cast(SymbolTable::lookupNearestSymbolFrom( dispatchOp, entryPoint.getRootReference())); - if (!executableOp) + if (!executableOp) { return; + } auto calleeNameAttr = entryPoint.getLeafReference(); auto innerModule = executableOp.getInnerModule(); - if (!innerModule) + if (!innerModule) { return; + } auto funcOps = innerModule.getOps(); auto funcIt = llvm::find_if( funcOps, [&](auto op) { return op.getNameAttr() == calleeNameAttr; }); - if (funcIt == funcOps.end()) + if (funcIt == funcOps.end()) { return; + } auto callee = *funcIt; @@ -506,25 +518,29 @@ class GraphPrinter { /// operation inside the cluster. void processBlock(Block &block) { emitClusterStmt([&]() { - for (BlockArgument &blockArg : block.getArguments()) + for (BlockArgument &blockArg : block.getArguments()) { valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); + } // Emit a node for each operation. 
std::optional prevNode; for (Operation &op : block) { Node nextNode = processOperation(&op); - if (printControlFlowEdges && prevNode) + if (printControlFlowEdges && prevNode) { emitEdgeStmt(*prevNode, nextNode, /*label=*/"", kLineStyleControlFlow); + } prevNode = nextNode; } }); } bool isScalarConstantOp(Operation *op) { - if (auto constOp = dyn_cast(op)) - if (constOp.getResult().getType().isIntOrIndexOrFloat()) + if (auto constOp = dyn_cast(op)) { + if (constOp.getResult().getType().isIntOrIndexOrFloat()) { return true; + } + } return false; } @@ -555,8 +571,9 @@ class GraphPrinter { // Emit cluster for op with regions. node = emitClusterStmt( [&]() { - for (Region ®ion : op->getRegions()) + for (Region ®ion : op->getRegions()) { processRegion(region); + } }, getLabel(op)); } else { @@ -578,22 +595,25 @@ class GraphPrinter { } } - for (Value result : op->getResults()) + for (Value result : op->getResults()) { valueToNode[result] = node; + } return node; } /// Process a region. void processRegion(Region ®ion) { - for (Block &block : region.getBlocks()) + for (Block &block : region.getBlocks()) { processBlock(block); + } } /// Truncate long strings. std::string truncateString(std::string str) { - if (str.length() <= maxLabelLen) + if (str.length() <= maxLabelLen) { return str; + } return str.substr(0, maxLabelLen) + "..."; } @@ -629,8 +649,9 @@ class DumpDispatchGraphPass void runOnOperation() override { auto modOp = dyn_cast(getOperation()); - if (!modOp) + if (!modOp) { return; + } // Open the output file we'll be streaming to. // Since we are processing the entire module at once we overwrite the file. 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index b9b8aafc7c56..9c2644adf868 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -225,8 +225,9 @@ createEntryPointBenchmarkFunc(mlir::ModuleOp moduleOp, for (auto arg : entryFuncOp.getArguments()) { auto dummyVar = createDummyInput(funcName, arg, symbolTable, moduleBuilder, explorer); - if (!dummyVar) + if (!dummyVar) { return failure(); + } dummyInputVariableOps.push_back(dummyVar); } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp index 93c48013053d..5f830e2feb07 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp @@ -32,22 +32,25 @@ TensorDimTrackingRewriter::TensorDimTrackingRewriter(Operation *op) } SmallVector TensorDimTrackingRewriter::getTensorDimOps() { SmallVector result; - for (Operation *op : dimOps) + for (Operation *op : dimOps) { result.push_back(cast(op)); + } return result; } void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) { IRRewriter::Listener::notifyOperationErased(op); - if (isa(op)) + if (isa(op)) { dimOps.erase(op); + } } void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op, InsertPoint previous) { IRRewriter::Listener::notifyOperationInserted(op, previous); auto dimOp = dyn_cast(op); - if (dimOp && isa(dimOp.getSource())) + if (dimOp && isa(dimOp.getSource())) { dimOps.insert(op); + } } } // namespace mlir @@ -59,8 +62,9 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, for (tensor::DimOp dimOp : dimOps) { // Only DimOps with static indices are supported. 
std::optional idx = dimOp.getConstantIndex(); - if (!idx.has_value()) + if (!idx.has_value()) { continue; + } if (isa(dimOp.getSource())) { continue; @@ -68,8 +72,9 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, // Only DimOps with ranked tensors are supported. auto tensorType = dyn_cast(dimOp.getSource().getType()); - if (!tensorType) + if (!tensorType) { continue; + } OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(dimOp); @@ -85,9 +90,11 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter, if (succeeded(IREE::Flow::getOptimizedDynamicResultDims( rewriter, dimOp.getSource(), dynamicDims))) { unsigned ctr = 0; - for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i) - if (tensorType.isDynamicDim(i)) + for (int64_t i = 0; i < *dimOp.getConstantIndex(); ++i) { + if (tensorType.isDynamicDim(i)) { ++ctr; + } + } rewriter.replaceOp(dimOp, dynamicDims[ctr]); } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp index 675566d83072..91459acae3e8 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectTensorTracing.cpp @@ -34,8 +34,9 @@ static std::string inferTraceKey(Operation *op) { static SmallVector filterTensorValues(ValueRange &&range) { SmallVector result; for (auto value : range) { - if (isa(value.getType())) + if (isa(value.getType())) { result.push_back(value); + } } return result; } @@ -76,10 +77,11 @@ struct InjectTensorTracingPass funcOp.walk([&](Operation *op) { if (auto attr = op->getAttr(attrName)) { std::string traceKey; - if (auto stringAttr = dyn_cast(attr)) + if (auto stringAttr = dyn_cast(attr)) { traceKey = stringAttr.getValue().str(); - else + } else { traceKey = inferTraceKey(op); + } injectTracingOnOp(op, traceKey); op->removeAttr(attrName); } diff --git 
a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp index a34129fdf49e..701fe0e257aa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InsertDispatchDebugTargets.cpp @@ -29,8 +29,9 @@ namespace mlir::iree_compiler::IREE::Flow { static SmallVector filterNonTensorValues(ValueRange &&range) { SmallVector result; for (auto value : range) { - if (isa(value.getType())) + if (isa(value.getType())) { result.push_back(value); + } } return result; } @@ -39,18 +40,21 @@ static SmallVector filterNonTensorValues(ValueRange &&range) { // a negative ordinal indicating no match. static std::tuple getOrdinalFromDebugTarget(std::string marker) { - if (marker.empty() || marker[0] != '@') + if (marker.empty() || marker[0] != '@') { return std::make_tuple("", -1); + } SmallVector parts; auto cropped = marker.substr(1); llvm::SplitString(llvm::StringRef(cropped), parts, ":"); - if (parts.size() != 2) + if (parts.size() != 2) { return std::make_tuple("", -1); + } int ordinal; - if (parts[1].getAsInteger(10, ordinal)) + if (parts[1].getAsInteger(10, ordinal)) { return std::make_tuple("", -1); + } return std::make_tuple(parts[0].str(), ordinal); } @@ -78,18 +82,21 @@ static void traceOpWithName(IREE::Flow::DispatchOp dispatchOp, static LogicalResult replaceReturnWithOpResults(mlir::ModuleOp moduleOp, IREE::Util::FuncOp funcOp, Operation *op) { - if (!funcOp->isProperAncestor(op)) + if (!funcOp->isProperAncestor(op)) { return failure(); + } // TODO: Handle nested function calls. - if (!SymbolTable::symbolKnownUseEmpty(funcOp, moduleOp)) + if (!SymbolTable::symbolKnownUseEmpty(funcOp, moduleOp)) { return failure(); + } // TODO: Handle (nested) control flow. 
auto funcBlock = op->getBlock(); if (funcBlock->getParentOp() != funcOp || - &funcOp.getBody().front() != funcBlock) + &funcOp.getBody().front() != funcBlock) { return failure(); + } // Collect the op results and create export ops for any tensor results. OpBuilder builder(funcOp); @@ -119,8 +126,9 @@ static LogicalResult replaceReturnWithOpResults(mlir::ModuleOp moduleOp, rewriter.replaceOpWithNewOp(oldTerminator, exports); SmallVector argTypes; - for (const auto &arg : llvm::enumerate(funcOp.getArguments())) + for (const auto &arg : llvm::enumerate(funcOp.getArguments())) { argTypes.push_back(arg.value().getType()); + } funcOp.setType(FunctionType::get(context, /*inputs=*/argTypes, /*results=*/newTypes)); @@ -151,12 +159,14 @@ struct InsertDebugTargetAtOrdinalPass // Only look for dispatches in util func ops. auto funcOp = dyn_cast(operation); - if (!funcOp) + if (!funcOp) { continue; + } std::string fName = funcOp.getName().str(); - if (fName != breakFname && fName != traceFname) + if (fName != breakFname && fName != traceFname) { continue; + } int localBreakOrdinal = -1; if (fName == breakFname) { @@ -188,8 +198,9 @@ struct InsertDebugTargetAtOrdinalPass if (localBreakOrdinal >= 0 && localBreakOrdinal < dispatchOps.size()) { auto breakTarget = dispatchOps[localBreakOrdinal]; if (failed(replaceReturnWithOpResults(getOperation(), funcOp, - breakTarget))) + breakTarget))) { return signalPassFailure(); + } } } @@ -252,8 +263,9 @@ struct InsertDebugTargetAtSymbolPass Operation *operation = funcOp; auto mlirFuncOp = dyn_cast(operation); if (!mlirFuncOp || failed(replaceReturnWithOpResults( - getOperation(), mlirFuncOp, breakTarget))) + getOperation(), mlirFuncOp, breakTarget))) { return signalPassFailure(); + } } } diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp index 5a726890523a..ca984231f5f6 100644 --- 
a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineConstants.cpp @@ -49,8 +49,9 @@ static SmallVector findConstantsInModule(mlir::ModuleOp moduleOp) { SmallVector results; for (auto callableOp : moduleOp.getOps()) { auto *region = callableOp.getCallableRegion(); - if (!region) + if (!region) { continue; + } region->walk([&](Operation *op) { if (auto constantOp = dyn_cast(op)) { if (isOutlinableValue(constantOp.getValue())) { @@ -80,8 +81,9 @@ static Operation *getParentInOp(Operation *childOp, Operation *ancestorOp) { assert(childOp != ancestorOp && "child can't be its own ancestor"); do { auto *parentOp = childOp->getParentOp(); - if (parentOp == ancestorOp) + if (parentOp == ancestorOp) { return childOp; + } childOp = parentOp; } while (childOp); assert(false && "child must be nested under ancestor"); @@ -94,16 +96,18 @@ static std::string getConstantName(ConstantDef &def) { if (auto parameterAttr = dyn_cast(def.value)) { os << "__parameter_"; - if (parameterAttr.getScope() && !parameterAttr.getScope().empty()) + if (parameterAttr.getScope() && !parameterAttr.getScope().empty()) { os << parameterAttr.getScope().getValue() << "_"; + } os << parameterAttr.getKey().getValue() << "_"; } else { os << "__constant_"; } def.type.print(os); str = sanitizeSymbolName(str); - if (str.substr(str.size() - 1) == "_") + if (str.substr(str.size() - 1) == "_") { str = str.substr(0, str.size() - 1); // strip trailing _ + } return str; } @@ -115,8 +119,9 @@ struct OutlineConstantsPass : public IREE::Flow::impl::OutlineConstantsPassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } SymbolTable moduleSymbols(moduleOp); @@ -127,8 +132,9 @@ struct OutlineConstantsPass // contains the constant. 
OpBuilder moduleBuilder(&moduleOp.getBody()->front()); auto parentFuncOp = getParentInOp(def.op, moduleOp); - if (parentFuncOp) + if (parentFuncOp) { moduleBuilder.setInsertionPoint(parentFuncOp); + } // New immutable global takes the constant attribute in its specified // encoding. diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp index 7344bc78b1c7..8ad5c0f1a0df 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchExterns.cpp @@ -152,10 +152,12 @@ struct OutlineDispatchExternsPass }) .Default(WalkResult::advance()); }; - if (funcOp.walk(outlineOps).wasInterrupted()) + if (funcOp.walk(outlineOps).wasInterrupted()) { return signalPassFailure(); - for (auto *deadOp : deadOps) + } + for (auto *deadOp : deadOps) { deadOp->erase(); + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp index bd52c6fb29df..60f06cd6cbc5 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp @@ -194,10 +194,12 @@ struct OutlineDispatchRegionsPass }) .Default(WalkResult::advance()); }; - if (funcOp.walk(outlineOps).wasInterrupted()) + if (funcOp.walk(outlineOps).wasInterrupted()) { return signalPassFailure(); - for (auto *deadOp : deadOps) + } + for (auto *deadOp : deadOps) { deadOp->erase(); + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp index 445560f0cf27..9e461995781a 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp @@ -110,8 +110,9 @@ SmallVector getLoopRanges(Operation *op, Location loc, /// Return `true` if an operation is within a `flow.dispatch.region` or /// `flow.dispatch.workgroups` op. bool isNonNullAndOutsideDispatch(Operation *op) { - if (!op) + if (!op) { return false; + } Operation *parentOp = op->getParentOp(); while (parentOp) { if (isa( @@ -204,8 +205,9 @@ static void createWorkgroupCountFromDagRootRegion( RewriterBase &rewriter, IREE::Flow::DispatchRegionOp ®ionOp, TypeRange workloadTypes, ArrayRef workloadLocs) { Region &countRegion = regionOp.getWorkgroupCount(); - if (!countRegion.empty()) + if (!countRegion.empty()) { return; + } Block *body = rewriter.createBlock(&countRegion, countRegion.begin(), workloadTypes, workloadLocs); auto args = body->getArguments(); @@ -221,8 +223,9 @@ static void createWorkgroupCountFromDagRootRegion( /// dynamic dimension. static bool hasDynamicShape(Type t) { auto shapedType = dyn_cast(t); - if (!shapedType) + if (!shapedType) { return false; + } return !shapedType.hasStaticShape(); } @@ -234,8 +237,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, OpBuilder::InsertionGuard guard(b); // Case 1: No dynamic result dims. - if (!hasDynamicShape(value.getType())) + if (!hasDynamicShape(value.getType())) { return success(); + } // There is at least one dynamic dimension, continue... ShapedType shapedType = cast(value.getType()); @@ -252,8 +256,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, // Case 2: Value is a block argument. 
if (auto bbArg = dyn_cast(value)) { - if (!createTensorDimOps) + if (!createTensorDimOps) { return failure(); + } b.setInsertionPointToStart(bbArg.getOwner()); emitTensorDimOps(); @@ -277,20 +282,24 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, auto tiedOp = dyn_cast(op); if (tiedOp) { Value tiedOperand = tiedOp.getTiedResultOperand(value); - if (tiedOperand && tiedOperand.getType() == value.getType()) + if (tiedOperand && tiedOperand.getType() == value.getType()) { return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims, /*createTensorDimOps=*/true); + } } // Case 5: Query ReifyRankedShapedTypeOpInterface. auto reifyShapeOp = dyn_cast(op); if (reifyShapeOp) { ReifiedRankedShapedTypeDims dims; - if (failed(reifyShapeOp.reifyResultShapes(b, dims))) + if (failed(reifyShapeOp.reifyResultShapes(b, dims))) { return failure(); - for (int64_t i = 0; i < shapedType.getRank(); ++i) - if (shapedType.isDynamicDim(i)) + } + for (int64_t i = 0; i < shapedType.getRank(); ++i) { + if (shapedType.isDynamicDim(i)) { dynamicDims.push_back(cast(dims[opResult.getResultNumber()][i])); + } + } return success(); } @@ -303,8 +312,9 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value, /*createTensorDimOps=*/true); } - if (!createTensorDimOps) + if (!createTensorDimOps) { return failure(); + } // None of the above. Insert tensor.dim ops. 
b.setInsertionPointAfter(op); @@ -416,8 +426,9 @@ clonePrecedingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target, Region *parentRegion = parentOperation->getParentRegion(); while ((parentOperation = parentOperation->getParentOp())) { - if (regionOp.getOperation() == parentOperation) + if (regionOp.getOperation() == parentOperation) { break; + } parentRegion = parentOperation->getParentRegion(); } @@ -876,8 +887,9 @@ bool isClonableIntoDispatchOp(Operation *op, } if (isa(op) || isa(op)) { - if (clInlineConstantByteLength == 0) + if (clInlineConstantByteLength == 0) { return false; + } Attribute constantValueAttr; if (!matchPattern(op->getResult(0), m_Constant(&constantValueAttr))) { return false; @@ -930,13 +942,15 @@ static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) { Operation *owner = ownerWorkgroupsOp ? ownerWorkgroupsOp : ownerRegionOp; // Ignore uses outside of dispatch workgroups op. - if (owner != dispatchOp) + if (owner != dispatchOp) { continue; + } // Cannot fuse producer of `dest` with `tensor.insert_slice`. if (auto insertSliceUser = dyn_cast(user)) { - if (insertSliceUser.getDest() == v) + if (insertSliceUser.getDest() == v) { return true; + } } } return false; @@ -948,8 +962,9 @@ SmallVector getCloneableOps(IREE::Flow::DispatchRegionOp regionOp, // of the dispatch region. llvm::SetVector valuesDefinedAbove; mlir::getUsedValuesDefinedAbove(regionOp.getBody(), valuesDefinedAbove); - if (valuesDefinedAbove.empty()) + if (valuesDefinedAbove.empty()) { return {}; + } // Traverse the defining ops of these values (and the ops on their reverse // SSA use-def chain). @@ -960,8 +975,9 @@ SmallVector getCloneableOps(IREE::Flow::DispatchRegionOp regionOp, while (!worklist.empty()) { Value outsideValue = worklist.pop_back_val(); // Skip values that were already visited. 
- if (visited.count(outsideValue)) + if (visited.count(outsideValue)) { continue; + } visited.insert(outsideValue); Operation *definingOp = outsideValue.getDefiningOp(); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp index 4090026d6305..ebb45e4be18a 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/TopLevelSCFToCFG.cpp @@ -43,9 +43,10 @@ void TopLevelSCFToCFGPass::runOnOperation() { target.addLegalOp(); target.markOpRecursivelyLegal(); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler::IREE::Flow diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp index 14790bf2fb90..cb37aae40366 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp @@ -191,8 +191,9 @@ BindingLayoutAnalysis::BindingLayoutAnalysis(Operation *rootOp, // before we derive the layouts. 
auto getExportInfo = [&](Operation *exportOp) -> ExportInfo & { auto &exportInfo = exportInfos[exportOp]; - if (!exportInfo) + if (!exportInfo) { exportInfo = std::make_unique(); + } return *exportInfo; }; rootOp->walk([&](Operation *op) { @@ -238,8 +239,9 @@ bool BindingLayoutAnalysis::hasDispatches() const { ArrayRef BindingLayoutAnalysis::getExportDispatches(Operation *exportOp) const { auto it = exportInfos.find(exportOp); - if (it == exportInfos.end()) + if (it == exportInfos.end()) { return {}; // not analyzed + } return it->second.get()->dispatchOps; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp index 931c825d98f2..35bf462c11bf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferViewOps.cpp @@ -18,9 +18,10 @@ struct ElementTypeOpConversion ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::ElementTypeOp::getTypeValue(op.getTypeAttr().getValue()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported element type"); + } rewriter.replaceOpWithNewOp(op, value.value()); return success(); } @@ -33,9 +34,10 @@ struct EncodingTypeOpConversion matchAndRewrite(IREE::HAL::EncodingTypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::EncodingTypeOp::getTypeValue(op.getEncodingAttr()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported encoding type"); + } rewriter.replaceOpWithNewOp(op, value.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 
42f92ae80e93..8c0917a718ab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -68,14 +68,16 @@ class CommandBufferCreateOpConversion }; auto modesValue = detail::rewriteAttrToOperands( op.getLoc(), adaptor.getModesAttr(), rewriter.getI32Type(), rewriter); - if (!modesValue.has_value()) + if (!modesValue.has_value()) { return failure(); + } callOperands.append(modesValue.value()); auto categoriesValue = detail::rewriteAttrToOperands( op.getLoc(), adaptor.getCommandCategoriesAttr(), rewriter.getI32Type(), rewriter); - if (!categoriesValue.has_value()) + if (!categoriesValue.has_value()) { return failure(); + } callOperands.append(categoriesValue.value()); callOperands.push_back(adaptor.getQueueAffinity()); if (adaptor.getBindingCapacity()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index 414cd14a46c0..82cd2872c849 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp @@ -25,10 +25,12 @@ class DeviceQueryCastOpConversion matchAndRewrite(IREE::HAL::DeviceQueryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto targetType = op.getValue().getType(); - if (targetType.isInteger(64)) + if (targetType.isInteger(64)) { return failure(); // handled natively - if (!targetType.isIntOrIndex()) + } + if (!targetType.isIntOrIndex()) { return rewriter.notifyMatchFailure(op, "unsupported result type"); + } // Query as i64. 
// Note that due to type conversion we need to handle the default logic @@ -94,12 +96,14 @@ class DeviceQueryI64OpConversion LogicalResult matchAndRewrite(IREE::HAL::DeviceQueryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.getValue().getType().isInteger(64)) + if (!op.getValue().getType().isInteger(64)) { return failure(); + } auto results = rewriteToCall(op, adaptor, importOp, *getTypeConverter(), rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } auto ok = results->front(); auto value = results->back(); if (op.getDefaultValue().has_value()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index 5b843f23bbab..771fac901c7a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp @@ -56,8 +56,9 @@ Value createPackedConstantBuffer(Location loc, ValueRange constantValues, // extra IR for the indices. We should batch them up and append in one go. for (auto constantValue : llvm::enumerate(constantValues)) { // Buffer is zero-initialized so we can skip zero values. 
- if (mlir::matchPattern(constantValue.value(), m_Zero())) + if (mlir::matchPattern(constantValue.value(), m_Zero())) { continue; + } auto constantLoc = constantValue.value().getLoc(); IREE::VM::BufferStoreI32Op::create( builder, constantLoc, constantBuffer, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 32442e52c18b..66c4f7490af9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -703,8 +703,9 @@ convertCollectiveAttr(IREE::Stream::CollectiveAttr sourceAttr) { auto convertReductionOp = [](std::optional op) -> std::optional { - if (!op.has_value()) + if (!op.has_value()) { return std::nullopt; + } return static_cast(op.value()); }; return IREE::HAL::CollectiveAttr::get( @@ -1201,11 +1202,13 @@ static void insertSerializationBarriers(Location loc, Block &block, // Note that we can't mutate the block while iterating it so we first grab // all the original ops. 
SmallVector serialOps; - for (auto &op : block) + for (auto &op : block) { serialOps.push_back(&op); + } for (auto *op : serialOps) { - if (op->hasTrait()) + if (op->hasTrait()) { continue; + } builder.setInsertionPointAfter(op); IREE::HAL::CommandBufferExecutionBarrierOp::create( builder, loc, commandBuffer, sourceStage, targetStage, flags); @@ -1711,10 +1714,12 @@ struct GlobalTimepointConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto initialValue = op.getInitialValue(); - if (!initialValue.has_value()) + if (!initialValue.has_value()) { return failure(); - if (!isa(*initialValue)) + } + if (!isa(*initialValue)) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.removeInitialValueAttr(); }); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp index 79763bc7cc80..a4a629358ce1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp @@ -104,8 +104,9 @@ lookupAllocatorAndQueueAffinityFor(Operation *op, Value memoryTypes, Value getOrCreateWaitFence(Location loc, Value timepointFence, PatternRewriter &rewriter) { - if (timepointFence) + if (timepointFence) { return timepointFence; + } return IREE::Util::NullOp::create(rewriter, loc, rewriter.getType()); } @@ -115,18 +116,21 @@ Value getOrCreateWaitFence(Location loc, Value timepointFence, // it is the only consumer of the timepoint and it is removed upon return. static Value consumeBoundFence(Value timepoint, PatternRewriter &rewriter) { // Must only have one use. We can't consume a fence multiple times. - if (!timepoint.hasOneUse()) + if (!timepoint.hasOneUse()) { return nullptr; // >1 use + } // The use must be an export to a fence. 
auto chainOp = dyn_cast( *timepoint.getUsers().begin()); - if (!chainOp) + if (!chainOp) { return nullptr; // non-export use + } assert(!chainOp.getExternalValues().empty()); auto fence = chainOp.getExternalValues().front(); - if (!fence || !isa(fence.getType())) + if (!fence || !isa(fence.getType())) { return nullptr; + } // Try really hard to figure out if the fence can be used. A larger analysis // pass running prior to conversion that did some code motion could help @@ -157,8 +161,9 @@ Value getOrCreateSignalFence(Location loc, Value device, Value timepoint, // Check to see if the timepoint is associated with a fence. In common cases // when along ABI boundaries we can usually find an association. auto fence = consumeBoundFence(timepoint, rewriter); - if (fence) + if (fence) { return fence; + } // Create a new fence. return IREE::HAL::FenceCreateOp::create( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp index 0abf92970a0c..b5f633c5f4ca 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/UtilToHAL/Patterns.cpp @@ -24,8 +24,9 @@ struct GlobalConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newType = getTypeConverter()->convertType(op.getType()); - if (newType == op.getType()) + if (newType == op.getType()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { // NOTE: the initial value may be invalid here! We rely on // dialect-specific conversions to handle it. 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index afbed1294831..f1cbee7207e5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp @@ -833,8 +833,9 @@ Value IREE::HAL::DeviceSelectAttr::buildDeviceEnumeration( auto deviceAttr = deviceAttrs.front(); Value tryDevice = deviceAttr.buildDeviceEnumeration( loc, buildDeviceTargetMatch, tryBuilder); - if (deviceAttrs.size() == 1) + if (deviceAttrs.size() == 1) { return tryDevice; // termination case + } Value isNull = IREE::Util::CmpEQOp::create(tryBuilder, loc, tryDevice, nullDevice); auto ifOp = @@ -868,8 +869,9 @@ Attribute DeviceAffinityAttr::parse(AsmParser &p, Type type) { queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; - if (failed(p.parseInteger(i))) + if (failed(p.parseInteger(i))) { return failure(); + } queueMask |= 1ll << i; return success(); }))) { @@ -991,8 +993,9 @@ Attribute DevicePromiseAttr::parse(AsmParser &p, Type type) { queueMask = 0; if (failed(p.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() { int64_t i = 0; - if (failed(p.parseInteger(i))) + if (failed(p.parseInteger(i))) { return failure(); + } queueMask |= 1ll << i; return success(); }))) { @@ -1119,10 +1122,12 @@ bool DeviceTopologyAttr::hasTransparentAccess( Attribute sourceDevice = getAffinityDevice(source); Attribute targetDevice = getAffinityDevice(target); - if (!sourceDevice || !targetDevice) + if (!sourceDevice || !targetDevice) { return false; - if (sourceDevice == targetDevice) + } + if (sourceDevice == targetDevice) { return true; // Same device has transparent access. + } // Search for a matching link and check if it has transparent access. 
for (DeviceLinkAttr link : getLinks()) { @@ -1140,10 +1145,12 @@ bool DeviceTopologyAttr::hasUnifiedMemory( Attribute sourceDevice = getAffinityDevice(source); Attribute targetDevice = getAffinityDevice(target); - if (!sourceDevice || !targetDevice) + if (!sourceDevice || !targetDevice) { return false; - if (sourceDevice == targetDevice) + } + if (sourceDevice == targetDevice) { return true; // Same device has unified memory. + } // Search for a matching link and check if it has unified memory. for (DeviceLinkAttr link : getLinks()) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index b984b4a74ceb..a03d8ca44067 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -34,8 +34,9 @@ struct ElideUnusedOp : public OpRewritePattern { : OpRewritePattern(context, /*benefit=*/1000) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -230,8 +231,9 @@ struct FoldBufferViewCreateSubspan needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); @@ -318,8 +320,9 @@ struct FoldCommandBufferFillBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getTargetBufferMutable().assign(newTargetBuffer); op.getTargetOffsetMutable().assign(newTargetOffset); @@ -358,8 +361,9 @@ struct FoldCommandBufferUpdateBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } 
rewriter.modifyOpInPlace(op, [&]() { op.getTargetBufferMutable().assign(newTargetBuffer); op.getTargetOffsetMutable().assign(newTargetOffset); @@ -408,8 +412,9 @@ struct FoldCommandBufferCopyBufferSubspans needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); @@ -444,8 +449,9 @@ struct FoldCommandBufferDispatchBufferSubspan : public OpRewritePattern { auto bindingOffsets = llvm::to_vector(op.getBindingOffsets()); for (size_t i = 0; i < bindingBuffers.size(); ++i) { auto *definingOp = bindingBuffers[i].getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } if (auto subspanOp = dyn_cast(definingOp)) { needsUpdate = true; bindingBuffers[i] = subspanOp.getSourceBuffer(); @@ -454,8 +460,9 @@ struct FoldCommandBufferDispatchBufferSubspan : public OpRewritePattern { } } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { auto mutableBindingBuffers = op.getBindingBuffersMutable(); mutableBindingBuffers.clear(); @@ -489,8 +496,9 @@ struct FoldCommandBufferDispatchIndirectBufferSubspan PatternRewriter &rewriter) const override { Value workgroupsBuffer = op.getWorkgroupsBuffer(); auto *definingOp = workgroupsBuffer.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return failure(); + } Value workgroupsOffset = op.getWorkgroupsOffset(); if (auto subspanOp = dyn_cast(definingOp)) { workgroupsBuffer = subspanOp.getSourceBuffer(); @@ -526,18 +534,21 @@ void CommandBufferDispatchIndirectOp::getCanonicalizationPatterns( // same basic block. We need an abstract interpreter to do much more as we'd // need to track conditionals/branching logic. 
static bool isOpAlwaysExecutedWith(Operation *before, Operation *after) { - if (before == after) + if (before == after) { return true; - if (before->getBlock() != after->getBlock()) + } + if (before->getBlock() != after->getBlock()) { return false; + } return before->isBeforeInBlock(after); } // Returns true if |op| was hoisted before |insertBefore| without breaking // SSA invariants. Returns false if no IR modifications were made. static bool tryHoistOpBeforeUser(Operation *op, Operation *insertBefore) { - if (op == insertBefore) + if (op == insertBefore) { return false; + } // Currently conservative - should be doing a domination check. if (op->getBlock() != insertBefore->getBlock()) { @@ -811,11 +822,13 @@ static void rewriteToOneReturn(int numResults, Region ®ion, // Get all of the return ops - if there's only one then the requirement is // already satisfied and we can exit early. auto returnOps = llvm::to_vector(region.getOps()); - if (returnOps.size() <= 1) + if (returnOps.size() <= 1) { return; // no-op + } SmallVector returnLocs; - for (auto returnOp : returnOps) + for (auto returnOp : returnOps) { returnLocs.push_back(returnOp.getLoc()); + } // Create the new exit block with arguments matching 1:1 with results. 
auto anyReturnOp = returnOps.front(); @@ -860,8 +873,9 @@ struct MergeExecutableConstantBlocks SmallVector resultLocs; for (auto blockOp : blockOps) { blockLocs.push_back(blockOp.getLoc()); - if (blockOp.getNumArguments() > 0) + if (blockOp.getNumArguments() > 0) { anyRequireDevice = true; + } llvm::append_range(resultTypes, blockOp.getResultTypes()); llvm::append_range(resultKeys, blockOp.getKeys().getValue()); llvm::append_range( @@ -967,8 +981,9 @@ static void filterReturnOperands(ExecutableConstantBlockOp blockOp, llvm::make_early_inc_range(blockOp.getOps())) { SmallVector operands; for (auto [i, operand] : llvm::enumerate(returnOp.getOperands())) { - if (preservedIndices.test(i)) + if (preservedIndices.test(i)) { operands.push_back(operand); + } } returnOp.getOperandsMutable().assign(operands); } @@ -980,11 +995,13 @@ struct DropUnusedExecutableConstantBlockDeviceArg using Base::Base; LogicalResult matchAndRewrite(ExecutableConstantBlockOp blockOp, PatternRewriter &rewriter) const override { - if (blockOp.getNumArguments() == 0) + if (blockOp.getNumArguments() == 0) { return failure(); + } auto deviceArg = blockOp.getArgument(0); - if (!deviceArg.use_empty()) + if (!deviceArg.use_empty()) { return failure(); + } rewriter.modifyOpInPlace(blockOp, [&]() { // Type conversion here shouldn't fail. 
(void)blockOp.eraseArgument(0); @@ -1057,8 +1074,9 @@ void FenceCreateOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// OpFoldResult FenceJoinOp::fold(FoldAdaptor operands) { - if (getFences().size() == 1) + if (getFences().size() == 1) { return getFences().front(); + } return {}; } @@ -1069,8 +1087,9 @@ struct ElideEmptyFenceJoin : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(FenceJoinOp op, PatternRewriter &rewriter) const override { - if (op.getNumOperands() != 0) + if (op.getNumOperands() != 0) { return failure(); + } rewriter.replaceOpWithNewOp(op, op.getResult().getType()); return success(); @@ -1091,8 +1110,9 @@ deduplicateFenceOperands(ValueRange operands) { newOperands.insert(operand); } - if (newOperands.size() == operands.size()) + if (newOperands.size() == operands.size()) { return std::nullopt; + } return newOperands.takeVector(); } @@ -1102,8 +1122,9 @@ struct DeduplicateFenceJoinFences : public OpRewritePattern { LogicalResult matchAndRewrite(FenceJoinOp op, PatternRewriter &rewriter) const override { auto newOperands = deduplicateFenceOperands(op.getFences()); - if (!newOperands) + if (!newOperands) { return failure(); + } rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getFlagsAttr(), newOperands.value()); return success(); @@ -1143,8 +1164,9 @@ struct ElideSignaledFence : public OpRewritePattern { auto fence = signalOp.getFence(); auto createOp = dyn_cast_if_present(fence.getDefiningOp()); - if (!createOp) + if (!createOp) { return failure(); + } // TODO(benvanik): broader analysis - likely in a dedicated fence elision // pass so we can do IPO. For now block-only. 
@@ -1194,8 +1216,9 @@ struct ElideEmptyFenceAwait : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(FenceAwaitOp op, PatternRewriter &rewriter) const override { - if (!op.getFences().empty()) + if (!op.getFences().empty()) { return failure(); + } rewriter.replaceOpWithNewOp(op, /*ok=*/0, 32); return success(); } @@ -1207,8 +1230,9 @@ struct DeduplicateFenceAwaitFences : public OpRewritePattern { LogicalResult matchAndRewrite(FenceAwaitOp op, PatternRewriter &rewriter) const override { auto newOperands = deduplicateFenceOperands(op.getFences()); - if (newOperands == std::nullopt) + if (newOperands == std::nullopt) { return failure(); + } // TODO(benvanik): resolve flag sets. rewriter.replaceOpWithNewOp( op, op.getStatus().getType(), op.getTimeoutMillis(), op.getFlagsAttr(), diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 157e82cf0651..445ba69e2870 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -120,12 +120,14 @@ static void printDeviceQueueAffinityList(OpAsmPrinter &p, Operation *, static ParseResult parseDescriptorType(OpAsmParser &parser, DescriptorTypeAttr &dtAttr) { StringRef enumKeyword; - if (failed(parser.parseKeyword(&enumKeyword))) + if (failed(parser.parseKeyword(&enumKeyword))) { return failure(); + } std::optional maybeEnum = symbolizeDescriptorType(enumKeyword); - if (!maybeEnum) + if (!maybeEnum) { return failure(); + } dtAttr = DescriptorTypeAttr::get(parser.getContext(), *maybeEnum); return success(); } @@ -381,8 +383,9 @@ static ParseResult parseTargetConditionRegion(OpAsmParser &parser, static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -403,15 +406,17 @@ static 
ParseResult parseTargetConditionObjects( do { // #hal.executable.target<...> Attribute targetAttr; - if (failed(parser.parseAttribute(targetAttr))) + if (failed(parser.parseAttribute(targetAttr))) { return failure(); + } targetsAttrs.push_back(targetAttr); // if(...) -> i1 { ... } auto region = std::make_unique(); if (succeeded(parser.parseOptionalKeyword("if"))) { - if (failed(parseTargetConditionRegion(parser, *region))) + if (failed(parseTargetConditionRegion(parser, *region))) { return failure(); + } } targetRegions.push_back(std::move(region)); @@ -421,15 +426,17 @@ static ParseResult parseTargetConditionObjects( failed(parser.parseLParen()) || failed(parser.parseAttribute(targetOrdinalAttr, IndexType::get(parser.getContext()))) || - failed(parser.parseRParen())) + failed(parser.parseRParen())) { return failure(); + } targetOrdinalsAttrs.push_back(targetOrdinalAttr); // = [#hal.executable.object<...>, ...] ArrayAttr targetObjectsAttr; if (failed(parser.parseEqual()) || - failed(parser.parseAttribute(targetObjectsAttr))) + failed(parser.parseAttribute(targetObjectsAttr))) { return failure(); + } targetObjectsAttrs.push_back(targetObjectsAttr); } while (succeeded(parser.parseOptionalComma())); targetsAttr = ArrayAttr::get(parser.getContext(), targetsAttrs); @@ -506,8 +513,9 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -550,8 +558,9 @@ static ParseResult parseExportConditionRegion(OpAsmParser &parser, static void printExportConditionRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "("; llvm::interleaveComma(body.getArguments(), p, [&](BlockArgument arg) { p.printRegionArgument(arg); }); @@ -627,8 +636,9 @@ void 
TensorImportOp::build(OpBuilder &builder, OperationState &result, "information is required"); SmallVector dynamicDims; for (int64_t i = 0; i < shapedType.getRank(); ++i) { - if (!shapedType.isDynamicDim(i)) + if (!shapedType.isDynamicDim(i)) { continue; + } dynamicDims.push_back(builder.createOrFold( result.location, builder.getIndexType(), source, builder.getIndexAttr(i))); @@ -641,12 +651,14 @@ void TensorImportOp::build(OpBuilder &builder, OperationState &result, static LogicalResult verifyTypeStorageCompatibility(Operation *op, Type encodingType, Type storageType) { - if (encodingType == storageType) + if (encodingType == storageType) { return success(); + } auto encodingShapedType = dyn_cast(encodingType); auto storageShapedType = dyn_cast(storageType); - if (!encodingShapedType || !storageShapedType) + if (!encodingShapedType || !storageShapedType) { return success(); + } if (IREE::Util::getRoundedElementByteWidth( encodingShapedType.getElementType()) != @@ -832,8 +844,9 @@ void DispatchExternOp::build(OpBuilder &builder, OperationState &state, state.addRegion(); // Add one empty region per target. - for (size_t i = 0; i < targetObjects.getTargets().size(); ++i) + for (size_t i = 0; i < targetObjects.getTargets().size(); ++i) { state.addRegion(); + } } // Verifies that |dynamicDims| contains the appropriate number of dims for all @@ -885,8 +898,9 @@ static LogicalResult verifyWorkgroupCountWorkload(Operation *op, // Verifies that the workgroup count region matches the expected // signature. Returns success if the region is empty. static LogicalResult verifyWorkgroupCountRegion(Operation *op, Region ®ion) { - if (region.empty()) + if (region.empty()) { return success(); + } // Verify one of the supported signatures. 
bool validArguments = true; @@ -946,12 +960,14 @@ LogicalResult DispatchExternOp::verify() { return success(); }; for (auto type : getOperandTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } for (auto type : getResultTypes()) { - if (failed(verifyIOType(type))) + if (failed(verifyIOType(type))) { return failure(); + } } if (failed(verifyWorkgroupCountRegion(op, getWorkgroupCount()))) { @@ -1219,16 +1235,18 @@ LogicalResult ElementTypeOp::verify() { // static std::optional EncodingTypeOp::getTypeValue(Attribute attr) { // TODO(#6762): encoding attribute handling/mapping to enums. - if (attr) + if (attr) { return std::nullopt; + } // Default to IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR for now. return 1; } void EncodingTypeOp::getAsmResultNames( function_ref setNameFn) { - if (!getEncodingAttr()) + if (!getEncodingAttr()) { setNameFn(getResult(), "dense_row_major"); + } } LogicalResult EncodingTypeOp::verify() { @@ -1637,9 +1655,10 @@ LogicalResult ExecutableSourceOp::verify() { ExecutableSourceOp op = *this; auto conditionOps = getOps(); - if (llvm::range_size(conditionOps) > 1) + if (llvm::range_size(conditionOps) > 1) { return op.emitOpError() << "only one condition op is allowed in an executable"; + } return success(); } @@ -1668,8 +1687,9 @@ LogicalResult ExecutableOp::verify() { // signature. Returns success if the region is empty. static LogicalResult verifyExportConditionRegion(Operation *op, Region ®ion) { - if (region.empty()) + if (region.empty()) { return success(); + } // Verify one of the supported signatures. 
bool validArguments = true; @@ -1937,8 +1957,9 @@ LogicalResult ExecutableVariantOp::verify() { ExecutableVariantOp op = *this; auto conditionOps = getOps(); - if (llvm::range_size(conditionOps) > 1) + if (llvm::range_size(conditionOps) > 1) { return op.emitOpError() << "only one condition op is allowed in a variant"; + } return success(); } @@ -2016,13 +2037,15 @@ void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result, ParseResult ExecutableConditionOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseTargetConditionRegion(parser, *result.addRegion())) + if (parseTargetConditionRegion(parser, *result.addRegion())) { return failure(); + } result.addAttribute( "function_type", TypeAttr::get(getTargetConditionRegionType(parser.getContext()))); - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) { return failure(); + } return success(); } @@ -2068,8 +2091,9 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, return failure(); } SmallVector argTypes; - for (auto &arg : entryArgs) + for (auto &arg : entryArgs) { argTypes.push_back(arg.type); + } auto fnType = builder.getFunctionType(argTypes, resultTypes); result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); @@ -2078,20 +2102,23 @@ ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser, // There must be one key per result. Note that we support omitted parens when // only one result is present. 
SmallVector keyAttrs; - if (failed(parser.parseKeyword("as"))) + if (failed(parser.parseKeyword("as"))) { return failure(); + } if (resultTypes.size() == 1) { std::string key; - if (failed(parser.parseString(&key))) + if (failed(parser.parseString(&key))) { return failure(); + } keyAttrs.push_back(builder.getStringAttr(key)); } else { if (failed(parser.parseCommaSeparatedList( AsmParser::Delimiter::OptionalParen, [&]() { std::string key; - if (failed(parser.parseString(&key))) + if (failed(parser.parseString(&key))) { return failure(); + } keyAttrs.push_back(builder.getStringAttr(key)); return success(); }, @@ -2138,12 +2165,14 @@ void ExecutableConstantBlockOp::print(OpAsmPrinter &p) { p, cast(op), argTypes, /*isVariadic=*/false, resultTypes); p << " as "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << '('; + } llvm::interleaveComma(getKeys().getValue(), p, [&](Attribute attr) { p << attr; }); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ')'; + } mlir::function_interface_impl::printFunctionAttributes( p, op, {getFunctionTypeAttrName(), getKeysAttrName()}); p << " "; @@ -2316,20 +2345,23 @@ llvm::Align InterfaceBindingSubspanOp::calculateAlignment() { // If the binding has no assigned alignment we fall back to natural alignment. auto baseAlignment = getBaseAlignment(); - if (!baseAlignment) + if (!baseAlignment) { return naturalAlignment; + } // If there's no offset specified then we can use the binding alignment // directly. - if (!getByteOffset()) + if (!getByteOffset()) { return baseAlignment.value(); + } // Try to get the alignment of the byte offset. If it's a constant then we can // find a common alignment between it and the base and otherwise we need to // try to infer the alignment from the IR - otherwise we fall back. 
auto offsetOrAlignment = lookupOffsetOrAlignment(getByteOffset()); - if (!offsetOrAlignment.has_value()) + if (!offsetOrAlignment.has_value()) { return naturalAlignment; + } // Compute the common alignment between that of the binding base and that of // the byte offset. diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index c2ecfa3d94bb..340677b79dc5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -21,8 +21,9 @@ namespace mlir::iree_compiler::IREE::HAL { //===----------------------------------------------------------------------===// llvm::MaybeAlign commonAlignment(llvm::MaybeAlign lhs, llvm::MaybeAlign rhs) { - if (!lhs.has_value() || !rhs.has_value()) + if (!lhs.has_value() || !rhs.has_value()) { return std::nullopt; + } return llvm::MaybeAlign( llvm::MinAlign(lhs.value().value(), rhs.value().value())); } @@ -37,8 +38,9 @@ std::optional lookupOffsetOrAlignment(Value value) { } auto op = value.getDefiningOp(); - if (!op) + if (!op) { return std::nullopt; + } if (auto alignmentAttr = op->getAttrOfType("stream.alignment")) { // The op has an alignment tagged on it we can use directly. 
return alignmentAttr.getValue().getZExtValue(); @@ -107,8 +109,9 @@ void HALDialect::registerTypes() { Type HALDialect::parseType(DialectAsmParser &parser) const { StringRef typeKind; - if (parser.parseKeyword(&typeKind)) + if (parser.parseKeyword(&typeKind)) { return {}; + } auto type = llvm::StringSwitch(typeKind) .Case("allocator", AllocatorType::get(getContext())) .Case("buffer", BufferType::get(getContext())) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp index e341232dc131..9d5b1d0c3869 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp @@ -49,16 +49,21 @@ void TargetOptions::bindOptions(OptionsBinder &binder) { "executable files (sources, benchmarks, intermediates, binaries) " "to."), llvm::cl::callback([&](const std::string &path) { - if (executableSourcesPath.empty()) + if (executableSourcesPath.empty()) { executableSourcesPath = path; - if (executableConfigurationsPath.empty()) + } + if (executableConfigurationsPath.empty()) { executableConfigurationsPath = path; - if (executableBenchmarksPath.empty()) + } + if (executableBenchmarksPath.empty()) { executableBenchmarksPath = path; - if (executableIntermediatesPath.empty()) + } + if (executableIntermediatesPath.empty()) { executableIntermediatesPath = path; - if (executableBinariesPath.empty()) + } + if (executableBinariesPath.empty()) { executableBinariesPath = path; + } }), llvm::cl::cat(halTargetOptionsCategory)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp index 34687e438195..23eae255b440 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp @@ -155,8 +155,9 @@ bool llvm::cl::parser::parse(Option &O, StringRef ArgName, // 
We ignore Arg here and just use the global registry. We could parse a list // of target backends and create a new registry with just that subset but // ownership gets tricky. - if (Arg != "global") + if (Arg != "global") { return true; + } Val.value = &mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal(); return false; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp index 337cfc0950a9..bbc81d624e2e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/AnnotateTargetDevices.cpp @@ -91,8 +91,9 @@ static void annotateOperandsAndResults(Operation *op, static void annotateFuncOp(FunctionOpInterface funcOp, DeviceAnalysis &deviceAnalysis) { - if (funcOp.empty()) + if (funcOp.empty()) { return; + } for (auto arg : funcOp.front().getArguments()) { if (isa(arg.getType())) { funcOp.setArgAttr( @@ -117,8 +118,9 @@ struct AnnotateTargetDevicesPass // Annotate all ops with derived affinities. 
for (auto &op : getOperation().getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } if (auto globalOp = dyn_cast(op)) { annotateGlobalOp(globalOp, deviceAnalysis); } else if (auto funcOp = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp index 31424147a4f7..cd3c5bfd86d5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CaptureExecutableSources.cpp @@ -35,8 +35,9 @@ static void insertDictionaryAttrEntry(Operation *op, StringRef dictionaryName, StringRef key, Attribute value) { NamedAttrList attrs; auto dictionaryAttr = op->getAttrOfType(dictionaryName); - if (dictionaryAttr) + if (dictionaryAttr) { attrs.assign(dictionaryAttr.getValue()); + } attrs.set(key, value); op->setAttr(dictionaryName, DictionaryAttr::get(op->getContext(), attrs)); } @@ -67,15 +68,17 @@ struct CaptureExecutableSourcesPass for (auto variantOp : executableOp.getOps()) { // Skip externally defined variants as there's no source to capture. - if (variantOp.isExternal()) + if (variantOp.isExternal()) { continue; + } // Ignore if there is already source assigned. auto fileName = (moduleName + "_" + executableOp.getName() + "_" + variantOp.getName() + "." + stage + ".mlir") .str(); - if (hasDictionaryAttrEntry(variantOp, "sources", fileName)) + if (hasDictionaryAttrEntry(variantOp, "sources", fileName)) { continue; + } // Create a standalone executable with just the variant being captured. 
// This allows the source to be passed to iree-compile in the diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp index 5e2e466fa7fb..13effa02246d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ConfigureExecutables.cpp @@ -49,8 +49,9 @@ class ConfigureTargetExecutableVariantsPass void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { return; + } auto targetBackend = targetRegistry->getTargetBackend(target); if (!targetBackend) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index e2a500172300..0ec0cde692f1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -445,8 +445,9 @@ buildBenchmarkModule(IREE::HAL::ExecutableOp sourceExecutableOp, } // Skip the file when we could not generate any benchmarks. 
- if (!hasAnyBenchmarks) + if (!hasAnyBenchmarks) { return {}; + } IRRewriter rewriter(moduleOp->getContext()); DominanceInfo domInfo; @@ -478,8 +479,9 @@ struct DumpExecutableBenchmarksPass SymbolTable symbolTable(moduleOp); DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } if (deviceAnalysis.getDeviceGlobals().empty()) { mlir::emitRemark(moduleOp.getLoc()) << "Executable benchmarks were requested but no devices were " @@ -516,8 +518,9 @@ struct DumpExecutableBenchmarksPass executableOp.getOps()) { auto benchmarkModuleOp = buildBenchmarkModule( executableOp, variantOp, dispatchParamsMap, deviceAnalysis); - if (!benchmarkModuleOp) + if (!benchmarkModuleOp) { continue; + } auto fileName = (moduleName + "_" + executableOp.getName() + "_" + variantOp.getName() + "_benchmark.mlir") .str(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp index f496b19cc7a9..d2d284ed72c6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/ElideRedundantCommands.cpp @@ -91,8 +91,9 @@ struct ElideRedundantCommandsPass stateMap[commandBuffer].previousFullBarrier = {}; }; for (auto &op : llvm::make_early_inc_range(block.getOperations())) { - if (!op.getDialect()) + if (!op.getDialect()) { continue; + } TypeSwitch(&op) .Case([&](IREE::HAL::CommandBufferFinalizeOp op) { invalidateState(op.getCommandBuffer()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp index 96ff5613f0ee..7a75fc5b992d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp +++ 
b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp @@ -28,8 +28,9 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { static std::string getAttrStr(Attribute attr) { - if (!attr) + if (!attr) { return ""; + } std::string result; llvm::raw_string_ostream os(result); attr.print(os, /*elideType=*/true); @@ -69,8 +70,9 @@ static Value createChunkHeader(Location loc, iree_idbts_chunk_type_t type, static Value createPadding(Location loc, uint64_t unalignedLength, OpBuilder &builder) { uint64_t padding = llvm::alignTo(unalignedLength, 16) - unalignedLength; - if (!padding) + if (!padding) { return nullptr; + } auto i8Type = builder.getI8Type(); auto zeroAttr = IntegerAttr::get(i8Type, 0); auto dataAttr = DenseElementsAttr::get( @@ -107,8 +109,9 @@ struct MaterializeDispatchInstrumentationPass MaterializeDispatchInstrumentationPassBase; void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } auto moduleBuilder = OpBuilder(&moduleOp.getBody()->front()); auto i8Type = moduleBuilder.getI8Type(); @@ -170,8 +173,9 @@ struct MaterializeDispatchInstrumentationPass for (auto exportOp : executableOp.getOps()) { auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp) + if (!funcOp) { continue; + } // Capture the source before we mess with it. auto originalSource = getOpStr(funcOp); @@ -256,8 +260,9 @@ struct MaterializeDispatchInstrumentationPass break; } } - if (!functionId) + if (!functionId) { return; // not instrumented + } // Append dispatch site ID to correlate this op with where it lives in // the program and what is being dispatched. 
Note that multiple diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index ac0ead563b77..4d0022eff77f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -295,8 +295,9 @@ convertBindingUsage(mlir::FunctionOpInterface sourceFuncOp, BlockArgument arg, IREE::HAL::PipelineLayoutAttr pipelineLayoutAttr, int64_t bindingOrdinal, IREE::HAL::PipelineBindingAttr bindingAttr) { - if (arg.use_empty()) + if (arg.use_empty()) { return; // no-op + } for (auto &use : llvm::make_early_inc_range(arg.getUses())) { auto oldOp = dyn_cast(use.getOwner()); assert(oldOp && "bindings are only usable by stream.binding.subspan"); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index e23253927427..85ab0fe7780b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -335,8 +335,9 @@ getDeviceFallbackGlobals(IREE::Util::GlobalOpInterface deviceGlobal, SymbolTable &symbolTable) { SetVector resultSet; auto processAttr = [&](Attribute attr) { - if (!attr) + if (!attr) { return true; // ignore uninitialized devices + } return TypeSwitch(attr) .Case([](auto attr) { return true; }) .Case([](auto attr) { return true; }) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp index 48a7b97a1baa..78a353eeb86c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp @@ -85,8 +85,9 @@ struct 
MemoizeDeviceQueriesPass // we can't memoize the query today. auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(queryOp.getDevice()); - if (!deviceGlobals || deviceGlobals->size() != 1) + if (!deviceGlobals || deviceGlobals->size() != 1) { return WalkResult::advance(); + } IREE::Util::GlobalOpInterface deviceGlobalOp = deviceGlobals->front(); // Construct key used to dedupe/lookup the query. diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp index 63900f1d03fd..7fb2dde0442f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PreprocessExecutables.cpp @@ -144,8 +144,9 @@ static LogicalResult preprocessWithCommand(IREE::HAL::ExecutableOp executableOp, #endif // _WIN32 Tokenize(command, stringSaver, rawArgs, /*MarkEOLs=*/false); SmallVector args; - for (auto rawArg : rawArgs) + for (auto rawArg : rawArgs) { args.push_back(StringRef(rawArg)); + } // Try to find the tool either by absolute path or by looking it up in env. 
auto tool = findTool(args[0].str()); @@ -156,8 +157,9 @@ static LogicalResult preprocessWithCommand(IREE::HAL::ExecutableOp executableOp, LLVM_DEBUG({ llvm::dbgs() << "Launching hal.executable preprocessor: "; - for (auto arg : args) + for (auto arg : args) { llvm::dbgs() << arg << " "; + } llvm::dbgs() << " 1> " << stdoutFile.str() << " 2> " << stderrFile.str() << "\n"; }); @@ -242,8 +244,9 @@ struct PreprocessExecutablesWithPipelinePass } void runOnOperation() override { - if (!pipeline.hasValue()) + if (!pipeline.hasValue()) { return; + } IREE::HAL::ExecutableOp executableOp = getOperation(); OpPassManager passManager(executableOp.getOperationName()); if (failed(buildPassPipeline(pipeline, passManager))) { @@ -270,8 +273,9 @@ struct PreprocessExecutablesWithToolPass using IREE::HAL::impl::PreprocessExecutablesWithToolPassBase< PreprocessExecutablesWithToolPass>::PreprocessExecutablesWithToolPassBase; void runOnOperation() override { - if (!command.hasValue()) + if (!command.hasValue()) { return; + } IREE::HAL::ExecutableOp executableOp = getOperation(); if (failed(preprocessWithCommand(executableOp, command))) { llvm::errs() << "ERROR: failed to preprocess executable `" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp index 6768ae50c5c3..19fd69e6a205 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/PruneExecutables.cpp @@ -35,8 +35,9 @@ static void markReferenced(SymbolRefAttr symbolRefAttr, ? 
SymbolRefAttr::get(rootRefAttr) : SymbolRefAttr::get(rootRefAttr, nestedRefAttrs); auto it = referenceMap.find(nestedRefAttr); - if (it != referenceMap.end()) + if (it != referenceMap.end()) { ++it->second.count; + } }; auto rootRefAttr = symbolRefAttr.getRootReference(); auto nestedRefAttrs = symbolRefAttr.getNestedReferences(); @@ -47,8 +48,9 @@ static void markReferenced(SymbolRefAttr symbolRefAttr, static void processOp(Operation *op, SymbolReferenceMap &referenceMap) { SmallVector worklist; - for (auto namedAttr : op->getAttrs()) + for (auto namedAttr : op->getAttrs()) { worklist.push_back(namedAttr.getValue()); + } while (!worklist.empty()) { auto attr = worklist.pop_back_val(); if (auto symbolRefAttr = dyn_cast(attr)) { @@ -107,8 +109,9 @@ struct PruneExecutablesPass SetVector exportRefAttrs; for (auto executableOp : moduleOp.getOps()) { ignoredOps.insert(executableOp); - if (!executableOp.isPrivate()) + if (!executableOp.isPrivate()) { continue; + } auto executableRefAttr = FlatSymbolRefAttr::get(executableOp.getSymNameAttr()); referenceMap[executableRefAttr].symbolOp = executableOp; @@ -156,8 +159,9 @@ struct PruneExecutablesPass // accumulate the usage counts. SymbolTable symbolTable(moduleOp); moduleOp.walk([&](Operation *op) -> WalkResult { - if (ignoredOps.contains(op)) + if (ignoredOps.contains(op)) { return WalkResult::skip(); + } processOp(op, referenceMap); return op->hasTrait() ? 
WalkResult::skip() diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp index 9cbaec880ab2..2a805a30e6f1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SerializeExecutables.cpp @@ -78,8 +78,9 @@ struct SerializeTargetExecutablesPass auto variantOps = llvm::to_vector( executableOp.getBlock().getOps()); for (auto variantOp : variantOps) { - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { continue; + } OpBuilder executableBuilder(variantOp); // Ask the target backend to serialize the executable. Note that it // may create one or more hal.executable.binary ops in the case of diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp index 35bf99d46e46..80dc320476ab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp @@ -42,8 +42,9 @@ scanSearchPath(std::string prefix, StringRef searchPath, dir != dir_end && !ec; dir.increment(ec)) { auto childPath = dir->path(); llvm::sys::fs::file_status status; - if (llvm::sys::fs::status(childPath, status)) + if (llvm::sys::fs::status(childPath, status)) { continue; + } switch (status.type()) { case llvm::sys::fs::file_type::regular_file: case llvm::sys::fs::file_type::symlink_file: @@ -104,8 +105,9 @@ replaceExecutableOpWithMLIR(IREE::HAL::ExecutableOp executableOp, // Load the replacement IR. It may have any mix of stuff in it including // multiple other executables. 
auto rootOpRef = loadModuleObject(executableOp.getContext(), filePath); - if (!rootOpRef) + if (!rootOpRef) { return failure(); + } IREE::HAL::ExecutableOp replacementOp; if (auto moduleOp = dyn_cast(rootOpRef.get())) { // We expect a `hal.executable` with the same name as the one we are @@ -165,8 +167,9 @@ externalizeExecutableOp(IREE::HAL::ExecutableOp executableOp, auto fileObjectAttr = builder.getAttr( builder.getStringAttr(filePath), nullptr); auto fileContents = fileObjectAttr.loadData(); - if (!fileContents) + if (!fileContents) { return failure(); + } // Link the referenced object file contents. We fully replace the existing // objects in case there were any as this does entire executable replacement - @@ -243,8 +246,9 @@ struct SubstituteExecutablesPass uniqueSubstitutions[std::string(key)] = value; } - if (uniqueSubstitutions.empty()) + if (uniqueSubstitutions.empty()) { return; // no-op + } // Walk each substitution and process the matching executable if found. for (auto &[executableName, filePath] : uniqueSubstitutions) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp index f1f39fa64918..391ff79be794 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp @@ -51,10 +51,12 @@ struct TranslateTargetExecutableVariantsPass void runOnOperation() override { IREE::HAL::ExecutableVariantOp variantOp = getOperation(); - if (variantOp.getTarget().getBackend().getValue() != target) + if (variantOp.getTarget().getBackend().getValue() != target) { return; - if (variantOp.isExternal()) + } + if (variantOp.isExternal()) { return; + } auto targetBackend = targetRegistry->getTargetBackend(target); if (!targetBackend) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp 
b/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp index f860eaaa0269..5588cecbfa1a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.cpp @@ -89,8 +89,9 @@ loadBitcodeObject(IREE::HAL::ExecutableObjectAttr objectAttr, llvm::MemoryBufferRef bitcodeBufferRef(objectData.value(), objectAttr.getPath()); auto bitcodeModuleValue = llvm::parseBitcodeFile(bitcodeBufferRef, context); - if (!bitcodeModuleValue) + if (!bitcodeModuleValue) { return bitcodeModuleValue; + } // NOTE: at this point the bitcode may not have the expected data layout! return std::move(bitcodeModuleValue.get()); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp index 367162905ba8..31c5664d00be 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning.cpp @@ -156,8 +156,9 @@ LogicalResult Partition::verify(Location loc) { for (auto in : ins) { // Only check ops, not bare values. auto definingOp = in.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } // Collect all values used by this input op (including nested regions). 
SetVector inputConsumedValues; @@ -216,8 +217,9 @@ LogicalResult PartitionSet::verify(Location loc) { } void PartitionSet::topologicalSort() { - if (partitions.empty()) + if (partitions.empty()) { return; + } SetVector unsortedSet; DenseMap> consumers; @@ -246,8 +248,9 @@ void PartitionSet::topologicalSort() { } } }; - for (auto *partition : unsortedSet) + for (auto *partition : unsortedSet) { postorderWalk(partition); + } SmallVector sortedSet; sortedSet.reserve(partitions.size()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index 0ae3d0d0997c..9f4c0e5ce7f9 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -64,8 +64,9 @@ struct PartitionBuilder { : affinityOp.getAffinityAttr(); } opInfo.membership.set(ordinal); - if (opInfo.hazards.size() > ordinal) + if (opInfo.hazards.size() > ordinal) { opInfo.hazards.reset(ordinal); + } ops.insert(op); hazards |= opInfo.hazards; hazards |= opInfo.nestedRegionHazards; @@ -497,8 +498,9 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, // First see which partitions are consuming this that we can also safely // move in to. 
consumers &= candidates; - if (consumers.any()) + if (consumers.any()) { candidates = consumers; + } opInfo.membership.reserve(builders.size() + 1); opInfo.membership.resize(builders.size(), /*t=*/false); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index cf0fe20d7339..268531a62615 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -132,37 +132,50 @@ class AbstractResourceUsage const std::string getAsStr(AsmState &asmState) const override { std::string str; - if (!isValidState()) + if (!isValidState()) { return "*"; + } auto append = [&](const char *part) { - if (!str.empty()) + if (!str.empty()) { str += '|'; + } str += part; }; - if (!this->isAssumed(NOT_INDIRECT)) + if (!this->isAssumed(NOT_INDIRECT)) { append("indirect"); + } append(this->isAssumed(NOT_EXTERNAL) ? "internal" : "external"); append(this->isAssumed(NOT_MUTATED) ? 
"immutable" : "mutable"); - if (!this->isAssumed(NOT_CONSTANT)) + if (!this->isAssumed(NOT_CONSTANT)) { append("constant"); - if (!this->isAssumed(NOT_TRANSFER_READ)) + } + if (!this->isAssumed(NOT_TRANSFER_READ)) { append("transfer_read"); - if (!this->isAssumed(NOT_TRANSFER_WRITE)) + } + if (!this->isAssumed(NOT_TRANSFER_WRITE)) { append("transfer_write"); - if (!this->isAssumed(NOT_STAGING_READ)) + } + if (!this->isAssumed(NOT_STAGING_READ)) { append("staging_read"); - if (!this->isAssumed(NOT_STAGING_WRITE)) + } + if (!this->isAssumed(NOT_STAGING_WRITE)) { append("staging_write"); - if (!this->isAssumed(NOT_DISPATCH_READ)) + } + if (!this->isAssumed(NOT_DISPATCH_READ)) { append("dispatch_read"); - if (!this->isAssumed(NOT_DISPATCH_WRITE)) + } + if (!this->isAssumed(NOT_DISPATCH_WRITE)) { append("dispatch_write"); - if (!this->isAssumed(NOT_GLOBAL_READ)) + } + if (!this->isAssumed(NOT_GLOBAL_READ)) { append("global_read"); - if (!this->isAssumed(NOT_GLOBAL_WRITE)) + } + if (!this->isAssumed(NOT_GLOBAL_WRITE)) { append("global_write"); - if (!this->isAssumed(NOT_GLOBAL_STORAGE)) + } + if (!this->isAssumed(NOT_GLOBAL_STORAGE)) { append("global_storage"); + } return str.empty() ? "*" : str; } @@ -250,8 +263,9 @@ class ValueResourceUsage : public AbstractResourceUsage { // itself is under analysis. void updateFromDefiningOp(Value value, OpResult result, DFX::Solver &solver) { // Some tied uses route through ops that change types - ignore those. - if (!isa(result.getType())) + if (!isa(result.getType())) { return; + } TypeSwitch(result.getOwner()) .Case([&](mlir::arith::SelectOp op) { @@ -552,8 +566,9 @@ class ValueResourceUsage : public AbstractResourceUsage { // This walks through tied uses as well. void updateFromUse(Value value, OpOperand &operand, DFX::Solver &solver) { // Some tied uses route through ops that change types - ignore those. 
- if (!isa(operand.get().getType())) + if (!isa(operand.get().getType())) { return; + } auto *userOp = operand.getOwner(); unsigned operandIdx = operand.getOperandNumber(); @@ -977,8 +992,9 @@ std::optional ResourceUsageAnalysis::tryLookupResourceUsage(Value value) { auto resourceUsage = solver.lookupElementFor(Position::forValue(value)); - if (!resourceUsage) + if (!resourceUsage) { return std::nullopt; + } return resourceUsage->getAssumedUsage(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index ae07021b2b49..f316560f9570 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -85,8 +85,9 @@ struct ConvertTensorDynamicConstantOp IREE::Stream::AffinityAttr executionAffinityAttr, ConversionPatternRewriter &rewriter) const override { auto attrType = dyn_cast(constantOp.getValue().getType()); - if (!attrType) + if (!attrType) { return failure(); + } auto resultType = constantOp.getType(); // If the op is acting as a dynamic value then preserve that behavior by @@ -355,13 +356,16 @@ struct ConvertTensorUpdateOp }; static bool isScalarTensor(RankedTensorType type) { - if (type.getRank() == 0) + if (type.getRank() == 0) { return true; // tensor - if (!type.hasStaticShape()) + } + if (!type.hasStaticShape()) { return false; // tensor<...?...xi32> + } int64_t elementCount = 1; - for (int64_t dim : type.getShape()) + for (int64_t dim : type.getShape()) { elementCount *= dim; + } return elementCount == 1; // tensor<1xi32> or tensor<1x1x1xi32> } @@ -1002,8 +1006,9 @@ static bool insertBindingOp(BlockArgument arg, IREE::TensorExt::DispatchTensorType tensorType, Value zero, OpBuilder &builder) { // No uses: don't need a binding op. 
- if (arg.use_empty()) + if (arg.use_empty()) { return true; + } // Find the dynamic dimension SSA values of the argument within the region. // If the flow dialect properly modeled dimension associations we wouldn't @@ -1018,8 +1023,9 @@ static bool insertBindingOp(BlockArgument arg, IREE::Flow::DispatchTieShapeOp tieShapeOp; for (auto user : arg.getUsers()) { tieShapeOp = dyn_cast(user); - if (tieShapeOp) + if (tieShapeOp) { break; + } } if (tieShapeOp) { // Found a tie shape op - we'll insert ourselves there. @@ -1125,8 +1131,9 @@ struct ConvertExecutableOp // Dispatch tensor arguments become bindings and all others are preserved // as adaptor. Note that we only touch public (exported) functions. for (auto funcOp : moduleOp.getOps()) { - if (!funcOp.isPublic()) + if (!funcOp.isPublic()) { continue; + } SmallVector newTypes; newTypes.reserve(funcOp.getNumArguments()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index 18f8f4cbdd07..1bb8c5ecc6bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -20,8 +20,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 2dd7777dd523..a87455290614 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp @@ -14,8 +14,9 @@ namespace mlir::iree_compiler { TypedAttr convertAttributeToStream(TypedAttr attr) { - if (!attr) + if (!attr) { return {}; + } if (auto parameterAttr = dyn_cast(attr)) { return IREE::Stream::NamedParameterAttr::get( attr.getContext(), parameterAttr.getType(), parameterAttr.getScope(), diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp index 8d90a92090ca..253c3b3d338d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/StandardToStream/Patterns.cpp @@ -32,8 +32,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } @@ -130,8 +131,9 @@ struct SelectOpConversion matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only handle selects where the operands are tensors (resources). 
- if (!isa(op.getTrueValue().getType())) + if (!isa(op.getTrueValue().getType())) { return failure(); + } auto trueOperand = resolveTensorOperands(op.getLoc(), op.getTrueValue(), adaptor.getTrueValue(), rewriter); auto falseOperand = resolveTensorOperands( diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp index de0497e3f439..58bc0c85e097 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp @@ -22,8 +22,9 @@ namespace { /// Flatten the given value ranges into a single vector of values. static SmallVector flattenValues(ArrayRef values) { SmallVector result; - for (const auto &vals : values) + for (const auto &vals : values) { llvm::append_range(result, vals); + } return result; } @@ -99,8 +100,9 @@ struct CallOpConversion }, [&](unsigned i, Type type, SmallVectorImpl &newTypes) { size_t newIndex = newTypes.size(); - if (failed(getTypeConverter()->convertType(type, newTypes))) + if (failed(getTypeConverter()->convertType(type, newTypes))) { anyFailed = true; + } resultMap.push_back(Result{i, newIndex, newTypes[newIndex]}); }, rewriter); @@ -158,8 +160,9 @@ struct GlobalExpansionState { }; static bool isExpandedType(Type type) { - if (isa(type)) + if (isa(type)) { return true; + } if (auto ptrType = dyn_cast(type)) { return isExpandedType(ptrType); } @@ -190,8 +193,9 @@ struct GlobalOpExpansion matchAndRewrite(IREE::Util::GlobalOp globalOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). 
- if (!isExpandedType(globalOp.getType())) + if (!isExpandedType(globalOp.getType())) { return failure(); + } SmallVector newTypes; if (failed(getTypeConverter()->convertType(globalOp.getType(), newTypes))) { @@ -297,13 +301,15 @@ struct GlobalLoadOpExpansion matchAndRewrite(IREE::Util::GlobalLoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). - if (!isExpandedType(loadOp.getType())) + if (!isExpandedType(loadOp.getType())) { return failure(); + } auto expandedGlobalIt = this->expansionState->globalMap.find(adaptor.getGlobal()); - if (expandedGlobalIt == this->expansionState->globalMap.end()) + if (expandedGlobalIt == this->expansionState->globalMap.end()) { return rewriter.notifyMatchFailure(loadOp, "expanded global not found"); + } auto &expandedGlobal = expandedGlobalIt->getSecond(); @@ -336,13 +342,15 @@ struct GlobalStoreOpExpansion matchAndRewrite(IREE::Util::GlobalStoreOp storeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only apply to expanded types (tensors/etc). - if (!isExpandedType(storeOp.getValue().getType())) + if (!isExpandedType(storeOp.getValue().getType())) { return failure(); + } auto expandedGlobalIt = this->expansionState->globalMap.find(adaptor.getGlobal()); - if (expandedGlobalIt == this->expansionState->globalMap.end()) + if (expandedGlobalIt == this->expansionState->globalMap.end()) { return rewriter.notifyMatchFailure(storeOp, "expanded global not found"); + } auto &expandedGlobal = expandedGlobalIt->getSecond(); @@ -430,8 +438,9 @@ void populateUtilToStreamConversionPatterns( typeConverter.addConversion([=](IREE::Util::PtrType type, SmallVectorImpl &resultTypes) { // Expand pointers to tensors to [resource, sizeof resource] pointers. 
- if (!isExpandedType(type)) + if (!isExpandedType(type)) { return failure(); + } resultTypes.push_back( IREE::Util::PtrType::get(IREE::Stream::ResourceType::get(context))); resultTypes.push_back(IREE::Util::PtrType::get(IndexType::get(context))); @@ -441,8 +450,9 @@ void populateUtilToStreamConversionPatterns( typeConverter.addConversion( [=](IREE::Util::PtrType type, SmallVectorImpl &resultTypes) { // Expand pointers to tensors to [ptr, ptr]. - if (!isExpandedType(type.getTargetType())) + if (!isExpandedType(type.getTargetType())) { return failure(); + } resultTypes.push_back(IREE::Stream::ResourceType::get(context)); resultTypes.push_back(IndexType::get(context)); return success(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp index 32f357808ef9..6424b8fa9ceb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp @@ -69,8 +69,9 @@ struct StripResourceConversionCastPattern LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, PatternRewriter &rewriter) const override { auto result = castOp.getResult(0); - if (!isa(result.getType())) + if (!isa(result.getType())) { return failure(); + } assert(castOp.getNumOperands() == 2 && "expect resource, index -> resource"); auto resourceValue = castOp.getOperand(0); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 3dddcd632de7..148877ddeba3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -44,10 +44,12 @@ namespace { // 0xCDCDCDCD : i32 -> 0xCD : i8 static APInt computeRequiredPatternBits(APInt pattern) { // Special case for well-known constant values. 
- if (pattern.isZero()) + if (pattern.isZero()) { return APInt(8, 0u); - if (pattern.isAllOnes()) + } + if (pattern.isAllOnes()) { return APInt(8, 0xFF); + } // Extend up to a power of two bit width. This makes the value easier to work // with as we'll be dealing with one of 4 sizes (1/2/4/8b). @@ -142,8 +144,9 @@ static TypedAttr tryNarrowPatternBits(TypedAttr patternAttr) { // Try narrowing the pattern. auto newPattern = computeRequiredPatternBits(oldPattern); - if (newPattern.getBitWidth() == oldPattern.getBitWidth()) + if (newPattern.getBitWidth() == oldPattern.getBitWidth()) { return patternAttr; + } // Wrap the result in an attribute - note that it is always an integer. return IntegerAttr::get( @@ -163,8 +166,9 @@ struct NarrowFillPattern : public OpRewritePattern { return failure(); } auto newPatternAttr = tryNarrowPatternBits(oldPatternAttr); - if (newPatternAttr == oldPatternAttr) + if (newPatternAttr == oldPatternAttr) { return failure(); + } // Replace the pattern on the op with the new one. auto narrowValue = @@ -182,13 +186,16 @@ struct NarrowFillPattern : public OpRewritePattern { // stream.yield // } static std::optional getYieldIfOnlyOp(Block &block) { - if (block.empty()) + if (block.empty()) { return std::nullopt; - if (&block.front() != &block.back()) + } + if (&block.front() != &block.back()) { return std::nullopt; + } auto yieldOp = dyn_cast(block.back()); - if (yieldOp) + if (yieldOp) { return yieldOp; + } return std::nullopt; } @@ -250,14 +257,16 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { // If the sinking operation would be a no-op, then we need to prevent // the sinking operation, to avoid infinite pattern applications. 
- if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp))) + if (Block::iterator(targetOp) == std::next(Block::iterator(toBeSunkOp))) { return false; + } // If the sinking is to a different block, then it okay, since for any later // sinkings, this reduces the problem to stable sinking within a single // block (handled below). - if (toBeSunkOp->getBlock() != targetOp->getBlock()) + if (toBeSunkOp->getBlock() != targetOp->getBlock()) { return true; + } SmallPtrSet producerOps; if (allowUseDefPruning) { @@ -274,11 +283,13 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { Block::iterator(targetOp))) { // If the intervening op that is not even a sink candidate itself, // then it cannot fight. - if (!isSinkCandidate(&op)) + if (!isSinkCandidate(&op)) { return true; + } // If the op is pruned by use-def chains, then it won't fight. - if (allowUseDefPruning && !producerOps.contains(&op)) + if (allowUseDefPruning && !producerOps.contains(&op)) { return true; + } } return false; } @@ -286,8 +297,9 @@ static bool canStablySinkTo(Operation *toBeSunkOp, Operation *targetOp) { // Sinks |op| down to |targetOp|, ensuring that we don't oscillate. // Returns success if the op was sunk and failure if sinking was not needed. 
static LogicalResult sinkOp(Operation *op, Operation *targetOp) { - if (!canStablySinkTo(op, targetOp)) + if (!canStablySinkTo(op, targetOp)) { return failure(); + } op->moveBefore(targetOp); return success(); } @@ -319,8 +331,9 @@ struct ElideUnusedOp : public OpRewritePattern { : OpRewritePattern(context, /*benefit=*/1000) {} LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - if (!op.use_empty()) + if (!op.use_empty()) { return failure(); + } rewriter.eraseOp(op); return success(); } @@ -447,8 +460,9 @@ struct ElideImmediateTimepointWait : public OpRewritePattern { bool isImmediate = op.getAwaitTimepoint() && isa_and_nonnull( op.getAwaitTimepoint().getDefiningOp()); - if (!isImmediate) + if (!isImmediate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointMutable().clear(); }); return success(); @@ -482,8 +496,9 @@ struct ChainDependentAwaits : public OpRewritePattern { } } } - if (replacements.empty()) + if (replacements.empty()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.setAwaitTimepoints(newTimepoints, rewriter); for (auto replacement : replacements) { @@ -712,8 +727,9 @@ struct SelectResourceSizeOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceSizeOp op, PatternRewriter &rewriter) const override { auto selectOp = op.getOperand().getDefiningOp(); - if (!selectOp) + if (!selectOp) { return failure(); + } auto trueSize = rewriter.createOrFold( op.getLoc(), selectOp.getTrueValue(), op.getAffinityAttr()); auto falseSize = rewriter.createOrFold( @@ -761,8 +777,9 @@ struct FoldSubviewIntoLoadOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceLoadOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getSource()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, 
subviewOp.getSourceOffset(), op.getSourceOffset()); @@ -806,8 +823,9 @@ struct FoldSubviewIntoStoreOp : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceStoreOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), op.getTargetOffset()); @@ -873,8 +891,9 @@ struct PropagateResourcePackBaseOffset PatternRewriter &rewriter) const override { // Offset is optional. auto baseOffset = op.getOffset(); - if (!baseOffset) + if (!baseOffset) { return failure(); + } // We always strip the offset here. rewriter.modifyOpInPlace(op, [&]() { op.getOffsetMutable().clear(); }); @@ -932,8 +951,9 @@ struct CanonicalizeResourcePackIntervals break; } } - if (!orderChanged) + if (!orderChanged) { return failure(); + } // TODO(benvanik): compact the slice ranges. 
@@ -993,8 +1013,9 @@ struct FoldResourceSubviewOps : public OpRewritePattern { LogicalResult matchAndRewrite(ResourceSubviewOp op, PatternRewriter &rewriter) const override { auto parentOp = ResourceSubviewOp::findSubviewOp(op.getSource()); - if (!parentOp) + if (!parentOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, parentOp.getSourceOffset(), op.getSourceOffset()); @@ -1021,14 +1042,16 @@ struct SinkSubviewAcrossSelectOps using Base::Base; LogicalResult matchAndRewrite(mlir::arith::SelectOp op, PatternRewriter &rewriter) const override { - if (!isa(op.getType())) + if (!isa(op.getType())) { return failure(); + } auto trueSubview = dyn_cast_if_present( op.getTrueValue().getDefiningOp()); auto falseSubview = dyn_cast_if_present( op.getFalseValue().getDefiningOp()); - if (!trueSubview || !falseSubview) + if (!trueSubview || !falseSubview) { return failure(); + } if (trueSubview.getSource() != falseSubview.getSource() || trueSubview.getResultSize() != falseSubview.getResultSize()) { return failure(); @@ -1134,8 +1157,9 @@ struct TensorConstantToEmpty : public OpRewritePattern { LogicalResult matchAndRewrite(TensorConstantOp constantOp, PatternRewriter &rewriter) const override { auto shapedType = dyn_cast(constantOp.getResultEncoding()); - if (!shapedType) + if (!shapedType) { return failure(); + } // See if any dim (including dynamic ones) is known zero. // It's still possible for empty tensors to slip through if their dynamic @@ -1155,8 +1179,9 @@ struct TensorConstantToEmpty : public OpRewritePattern { break; } } - if (!anyZeroDims) + if (!anyZeroDims) { return failure(); + } // Definitely empty if here. 
Value resultSize = IREE::Stream::TensorSizeOfOp::create( @@ -1383,8 +1408,9 @@ struct DeduplicateTensorDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -1414,8 +1440,9 @@ struct SinkAllocaLikeOpToConsumers : public OpRewritePattern { LogicalResult matchAndRewrite(Op producerOp, PatternRewriter &rewriter) const override { auto users = llvm::to_vector(producerOp->getUsers()); - if (users.size() == 0) + if (users.size() == 0) { return failure(); + } // If we have a single user then we can sink right to it. if (users.size() == 1) { @@ -1576,8 +1603,9 @@ struct PropagateSplatsThroughSlices : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = sliceOp.getSource().getDefiningOp(); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( sliceOp, sliceOp.getResult().getType(), splatOp.getValue(), sliceOp.getResultSize(), sliceOp.getAffinityAttr(), @@ -1615,8 +1643,9 @@ struct FlattenFullFillToSplat : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(AsyncFillOp fillOp, PatternRewriter &rewriter) const override { - if (fillOp.getTargetLength() != fillOp.getTargetSize()) + if (fillOp.getTargetLength() != fillOp.getTargetSize()) { return failure(); + } auto targetOp = fillOp.getTarget().getDefiningOp(); if (!targetOp || IREE::Util::TiedOpInterface::findTiedBaseValue( @@ -1647,8 +1676,9 @@ struct ElideRedundantFill : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = dyn_cast_if_present( fillOp.getTarget().getDefiningOp()); - if (!splatOp) + if (!splatOp) { return failure(); + } if (splatOp.getValue() != fillOp.getValue()) { return 
rewriter.notifyMatchFailure(fillOp, "fill patterns are not compatible"); @@ -1678,8 +1708,9 @@ struct CoalesceAdjacentFills : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceOp = dyn_cast_if_present( fillOp.getTarget().getDefiningOp()); - if (!sourceOp) + if (!sourceOp) { return failure(); + } if (!sourceOp.getResult().hasOneUse()) { // Note that hazard analysis could make this work if we can guarantee that // the source result is only ever sliced out to a range that doesn't @@ -1757,20 +1788,23 @@ static bool hasValueSemantics(Value value) { // Can't analyze function arguments (though we could add arg attrs to indicate // value semantics). auto *definingOp = value.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } // If produced by a tied op then see if the particular result is tied. if (auto tiedOp = dyn_cast(definingOp)) { - if (tiedOp.getTiedResultOperand(value)) + if (tiedOp.getTiedResultOperand(value)) { return false; + } } // To be conservative we only allow stream dialect ops that produce the // resource as we know they all indicate value semantics when non-tied - ops // from other dialects may not. 
- if (!definingOp->hasTrait()) + if (!definingOp->hasTrait()) { return false; + } return true; } @@ -1894,8 +1928,9 @@ struct CombineSplatUpdateFromToFill : public OpRewritePattern { PatternRewriter &rewriter) const override { auto splatOp = updateOp.getUpdate().getDefiningOp(); - if (!splatOp) + if (!splatOp) { return failure(); + } rewriter.replaceOpWithNewOp( updateOp, updateOp.getResult().getType(), updateOp.getTarget(), updateOp.getTargetSize(), updateOp.getTargetOffset(), @@ -2078,12 +2113,14 @@ struct IntermediateTransferElision : public OpRewritePattern { auto source = originTransferOp.getSource(); auto previousTransferOp = dyn_cast_if_present(source.getDefiningOp()); - if (!previousTransferOp) + if (!previousTransferOp) { break; + } originTransferOp = previousTransferOp; } - if (originTransferOp == transferOp) + if (originTransferOp == transferOp) { return failure(); + } rewriter.replaceOpWithNewOp( transferOp, transferOp.getResult().getType(), originTransferOp.getSource(), originTransferOp.getSourceSize(), @@ -2116,12 +2153,14 @@ struct FoldAsyncLoadBitcast : public OpRewritePattern { LogicalResult matchAndRewrite(AsyncLoadOp loadOp, PatternRewriter &rewriter) const override { auto loadedValue = loadOp.getResult(); - if (!loadedValue.hasOneUse()) + if (!loadedValue.hasOneUse()) { return failure(); + } auto bitcastOp = dyn_cast(*loadedValue.getUsers().begin()); - if (!bitcastOp) + if (!bitcastOp) { return failure(); + } rewriter.modifyOpInPlace( loadOp, [&]() { loadedValue.setType(bitcastOp.getType()); }); rewriter.replaceOp(bitcastOp, loadedValue); @@ -2187,8 +2226,9 @@ struct DeduplicateAsyncDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); 
@@ -2235,13 +2275,15 @@ struct CloneCapturedAsyncExecuteSubviewOps SmallVector captures; for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value()); - if (!subviewOp) + if (!subviewOp) { continue; + } captures.push_back( SubviewCapture{static_cast(operand.index()), subviewOp}); } - if (captures.empty()) + if (captures.empty()) { return failure(); + } rewriter.startOpModification(op); auto &entryBlock = op.getBody().front(); @@ -2383,8 +2425,9 @@ findConsumerThroughAwait(Value timelineResult) { for (auto [resource, result] : llvm::zip_equal(awaitOp.getResourceOperands(), awaitOp.getResults())) { if (resource == timelineResult) { - if (!result.hasOneUse()) + if (!result.hasOneUse()) { return {nullptr, nullptr}; + } return {*result.getUsers().begin(), result}; } } @@ -2867,8 +2910,9 @@ struct FoldSubviewsIntoCmdFlushOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdFlushOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2909,8 +2953,9 @@ struct FoldSubviewsIntoCmdInvalidateOp LogicalResult matchAndRewrite(CmdInvalidateOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2950,8 +2995,9 @@ struct FoldSubviewsIntoCmdDiscardOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdDiscardOp op, PatternRewriter &rewriter) const override { auto subviewOp = 
ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -2991,8 +3037,9 @@ struct FoldSubviewsIntoCmdFillOp : public OpRewritePattern { LogicalResult matchAndRewrite(CmdFillOp op, PatternRewriter &rewriter) const override { auto subviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!subviewOp) + if (!subviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( @@ -3034,8 +3081,9 @@ struct FoldSubviewsIntoCmdCopyOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto sourceSubviewOp = ResourceSubviewOp::findSubviewOp(op.getSource()); auto targetSubviewOp = ResourceSubviewOp::findSubviewOp(op.getTarget()); - if (!sourceSubviewOp && !targetSubviewOp) + if (!sourceSubviewOp && !targetSubviewOp) { return failure(); + } setInsertionPointToParentExecutionScope(op, rewriter); if (sourceSubviewOp) { auto fusedLoc = @@ -3100,19 +3148,22 @@ struct FoldSubviewsIntoDispatchOp : public OpRewritePattern { bool anySubviewOps = false; for (auto operand : op.getResources()) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand); - if (subviewOp) + if (subviewOp) { anySubviewOps = true; + } resourceSubviewOps.push_back(subviewOp); } - if (!anySubviewOps) + if (!anySubviewOps) { return failure(); + } rewriter.startOpModification(op); setInsertionPointToParentExecutionScope(op, rewriter); for (auto [resourceIndex, subviewOp] : llvm::enumerate(resourceSubviewOps)) { - if (!subviewOp) + if (!subviewOp) { continue; + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), @@ -3151,8 
+3202,9 @@ struct DeduplicateCmdDispatchEntryRefs final PatternRewriter &rewriter) const override { auto originalAttr = dispatchOp.getEntryPointsAttr(); auto newAttr = deduplicateArrayElements(originalAttr); - if (newAttr == originalAttr) + if (newAttr == originalAttr) { return failure(); + } rewriter.modifyOpInPlace(dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); }); return success(); @@ -3187,21 +3239,24 @@ struct FoldSubviewsIntoCmdCallOp : public OpRewritePattern { llvm::enumerate(op.getResourceOperands())) { if (isa(operand.getType())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand); - if (subviewOp) + if (subviewOp) { anySubviewOps = true; + } resourceSubviewOps.push_back({operandIndex, subviewOp}); } } - if (!anySubviewOps) + if (!anySubviewOps) { return failure(); + } rewriter.startOpModification(op); setInsertionPointToParentExecutionScope(op, rewriter); for (auto [resourceIndex, resourceSubviewOp] : llvm::enumerate(resourceSubviewOps)) { auto [operandIndex, subviewOp] = resourceSubviewOp; - if (!subviewOp) + if (!subviewOp) { continue; + } auto fusedLoc = rewriter.getFusedLoc({subviewOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, subviewOp.getSourceOffset(), @@ -3258,13 +3313,15 @@ struct CloneCapturedCmdExecuteSubviewOps SmallVector captures; for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = ResourceSubviewOp::findSubviewOp(operand.value()); - if (!subviewOp) + if (!subviewOp) { continue; + } captures.push_back( SubviewCapture{static_cast(operand.index()), subviewOp}); } - if (captures.empty()) + if (captures.empty()) { return failure(); + } rewriter.startOpModification(op); auto &entryBlock = op.getBody().front(); @@ -3414,8 +3471,9 @@ struct FoldParameterLoadTargetSubviews } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceOffsetsMutable().assign(newSourceOffsets); 
op.getResultSizesMutable().assign(newResultSizes); @@ -3465,8 +3523,9 @@ struct FoldParameterReadTargetSubview needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceOffsetMutable().assign(newSourceOffset); op.getTargetMutable().assign(newTargetResource); @@ -3518,8 +3577,9 @@ struct FoldParameterWriteSourceSubview needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceMutable().assign(newSourceResource); op.getSourceSizeMutable().assign(newSourceSize); @@ -3768,8 +3828,9 @@ struct ElideImmediateTimepointJoinOperands newTimepoints.push_back(timepoint); } } - if (newTimepoints.size() == op.getAwaitTimepoints().size()) + if (newTimepoints.size() == op.getAwaitTimepoints().size()) { return failure(); + } if (newTimepoints.empty()) { // Fully immediate; replace entire join with immediate. rewriter.replaceOpWithNewOp( @@ -3790,8 +3851,9 @@ struct FoldDuplicateTimepointJoinOperands SetVector newTimepoints; newTimepoints.insert(op.getAwaitTimepoints().begin(), op.getAwaitTimepoints().end()); - if (newTimepoints.size() == op.getAwaitTimepoints().size()) + if (newTimepoints.size() == op.getAwaitTimepoints().size()) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointsMutable().assign(newTimepoints.takeVector()); }); @@ -3821,8 +3883,9 @@ struct ExpandTimepointJoinOperands : public OpRewritePattern { newTimepoints.insert(timepoint); } } - if (!didExpand) + if (!didExpand) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getAwaitTimepointsMutable().assign(newTimepoints.takeVector()); }); @@ -3853,8 +3916,9 @@ static bool isSourceImmediatelyResolved(Value resource) { // TODO(benvanik): data flow analysis/at least walk up tied ops. 
For now we // err on the conservative side and only check for a few common scenarios. auto *definingOp = resource.getDefiningOp(); - if (!definingOp) + if (!definingOp) { return false; + } return TypeSwitch(definingOp) .Case( [](auto op) { return true; }) @@ -3902,8 +3966,9 @@ findSourceAwaitOp(Value resource) { } } auto tiedValue = definingOp.getTiedResultOperand(baseResource); - if (!tiedValue) + if (!tiedValue) { break; + } baseResource = tiedValue; } return {nullptr, nullptr}; @@ -3925,8 +3990,9 @@ struct ChainTimepoints : public OpRewritePattern { // Try to find an await op. This may traverse through any number of tied ops // along the way. auto [awaitOp, baseResource] = findSourceAwaitOp(barrierOp.getResource()); - if (!awaitOp) + if (!awaitOp) { return failure(); + } // TODO(benvanik): move this to a pass that can do IPO. Local analysis is // insufficient for this. For now we conservatively ignore any case where @@ -4007,8 +4073,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { // Its possible we are nested in an SCF region. If so the SCF operation // depends on the timepoint as a whole. Operation *owner = use.getOwner(); - while (owner && owner->getParentOp() != op->getParentOp()) + while (owner && owner->getParentOp() != op->getParentOp()) { owner = owner->getParentOp(); + } if (allUsers.insert(owner)) { auto *userBlock = owner->getBlock(); @@ -4019,8 +4086,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { } } } - if (!commonDominator) + if (!commonDominator) { return failure(); + } // Find the first use within the dominator block (if any) so that we // can sink down to it. @@ -4035,8 +4103,9 @@ struct SinkAwaitToFirstConsumer : public OpRewritePattern { // If sinking to `firstUserInDominator` could result in patterns // fighting each other, then don't sink. 
- if (!canStablySinkTo(op, firstUserInDominator)) + if (!canStablySinkTo(op, firstUserInDominator)) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op->moveBefore(firstUserInDominator); }); @@ -4057,8 +4126,9 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern { for (auto operand : llvm::enumerate(op.getResourceOperands())) { auto subviewOp = operand.value().getDefiningOp(); - if (!subviewOp) + if (!subviewOp) { continue; + } didChange = true; unsigned operandIdx = static_cast(operand.index()); @@ -4095,8 +4165,9 @@ struct SinkSubviewsAcrossAwaits : public OpRewritePattern { static bool areAllOperandsDefinedBy(Operation *op, Operation *insertionPoint, DominanceInfo &dominanceInfo) { for (auto operand : op->getOperands()) { - if (!dominanceInfo.dominates(operand, insertionPoint)) + if (!dominanceInfo.dominates(operand, insertionPoint)) { return false; + } } return true; } @@ -4124,15 +4195,19 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { // TODO(benvanik): make this handle joins/ties; today we get blocked // there. We rely on other canonicalizers to sink things such that // (hopefully) we get them directly accessible here. - if (use.getOwner() == op) + if (use.getOwner() == op) { continue; - if (op->getBlock() != use.getOwner()->getBlock()) + } + if (op->getBlock() != use.getOwner()->getBlock()) { continue; - if (dominanceInfo.dominates(use.getOwner(), op)) + } + if (dominanceInfo.dominates(use.getOwner(), op)) { continue; + } auto awaitOp = dyn_cast(use.getOwner()); - if (!awaitOp || awaitOp.getSync()) + if (!awaitOp || awaitOp.getSync()) { continue; + } // Ensure all dependencies of the await op are available. if (!areAllOperandsDefinedBy(awaitOp, op, dominanceInfo)) { // One or more operands is defined after op so we can't merge. 
@@ -4140,8 +4215,9 @@ struct GroupAwaitsByTimepoint : public OpRewritePattern { } coveredOps.push_back(awaitOp); } - if (coveredOps.empty()) + if (coveredOps.empty()) { return failure(); + } coveredOps.push_back(op); // Sort the ops by their definition order; this gives us a deterministic diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp index 6afb02f6d217..cb5249d0c752 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp @@ -167,10 +167,12 @@ static LogicalResult verifyAllResourcesCaptured(Region ®ion) { availableResources.insert(result); } for (auto operand : op.getOperands()) { - if (!operand) + if (!operand) { continue; - if (!isa(operand.getType())) + } + if (!isa(operand.getType())) { continue; + } if (!availableResources.contains(operand)) { return op.emitOpError() << "used resource not listed in explicit " "captures (or produced internally)"; @@ -215,8 +217,9 @@ static void eraseStreamRegionResults(Region ®ion, ArrayRef excludedResultIndices) { for (auto &block : region.getBlocks()) { auto yieldOp = dyn_cast(block.getTerminator()); - if (!yieldOp) + if (!yieldOp) { continue; + } // HACK: there's no good way of updating the operand and size together today // - we should add a helper to the ClosureYieldOpInterface that checks for // size/shape aware traits and does this automatically. 
@@ -316,8 +319,9 @@ static IREE::Util::ValueAccess computeValueAccess(Value rootValue) { DenseSet processedValues; SmallVector worklist; auto enqueueValue = [&](Value value) { - if (processedValues.contains(value)) + if (processedValues.contains(value)) { return; + } processedValues.insert(value); worklist.push_back(value); }; @@ -357,8 +361,9 @@ static IREE::Util::ValueAccess computeValueAccess(Value rootValue) { if (auto tiedOp = dyn_cast(user)) { auto tiedIndices = tiedOp.getTiedResultOperandIndices(); for (int64_t tiedIndex : tiedIndices) { - if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) + if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) { continue; + } auto operand = user->getOperand(tiedIndex); if (operand == value) { // Tied operand. @@ -387,16 +392,19 @@ static ParseResult parseDispatchEntryPoints(OpAsmParser &parser, if (succeeded(parser.parseOptionalLBrace())) { do { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRBrace())) + if (failed(parser.parseRBrace())) { return failure(); + } } else { SymbolRefAttr entryPointAttr; - if (failed(parser.parseAttribute(entryPointAttr))) + if (failed(parser.parseAttribute(entryPointAttr))) { return failure(); + } entryPointAttrs.push_back(entryPointAttr); } entryPointAttrsArray = parser.getBuilder().getArrayAttr(entryPointAttrs); @@ -434,21 +442,24 @@ static ParseResult parseEncodedResourceOperands( TypeAttr resourceEncoding; if (failed(parser.parseOperand(resources.back())) || failed(parser.parseColon()) || - failed(parser.parseAttribute(resourceEncoding))) + failed(parser.parseAttribute(resourceEncoding))) { return failure(); + } resourceEncodingAttrs.push_back(resourceEncoding); if (int64_t dynamicDimCount = cast(resourceEncoding.getValue()).getNumDynamicDims()) { 
if (failed(parser.parseOperandList(resourceEncodingDims, dynamicDimCount, - AsmParser::Delimiter::Braces))) + AsmParser::Delimiter::Braces))) { return failure(); + } } resourceTypes.emplace_back(); resourceSizes.emplace_back(); if (failed(parser.parseKeyword("in")) || failed(parseSizeAwareType(parser, resourceTypes.back(), - resourceSizes.back()))) + resourceSizes.back()))) { return failure(); + } } while (succeeded(parser.parseOptionalComma())); resourceEncodings = parser.getBuilder().getArrayAttr(resourceEncodingAttrs); return success(); @@ -1429,12 +1440,14 @@ static void printResourceRegion(OpAsmPrinter &p, Operation *op, p << ")"; if (!resultTypes.empty()) { p << " -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printShapedResultList(p, op, operands, operandTypes, operandSizes, resultTypes, resultSizes, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } p << " "; p.printRegion(body, /*printEntryBlockArgs=*/false, @@ -1527,8 +1540,9 @@ static ParseResult parsePackSliceRanges( auto indexType = parser.getBuilder().getIndexType(); SmallVector lifetimeRangeValues; do { - if (failed(parser.parseOptionalLSquare())) + if (failed(parser.parseOptionalLSquare())) { break; + } IntegerAttr lifetimeStart; IntegerAttr lifetimeEnd; OpAsmParser::UnresolvedOperand dynamicSliceSize; @@ -1552,8 +1566,9 @@ static void printPackSliceRanges(OpAsmPrinter &p, Operation *op, ArrayAttr lifetimeIntervals, ValueRange dynamicSliceSizes, TypeRange packedOffsetTypes) { - if (packedOffsetTypes.empty()) + if (packedOffsetTypes.empty()) { return; + } for (unsigned i = 0; i < packedOffsetTypes.size(); ++i) { auto lifetimeStart = lifetimeIntervals[i * 2]; auto lifetimeEnd = lifetimeIntervals[i * 2 + 1]; @@ -1565,8 +1580,9 @@ static void printPackSliceRanges(OpAsmPrinter &p, Operation *op, p.printAttributeWithoutType(lifetimeEnd); p << "] = "; p.printOperand(sliceSize); - if (i < packedOffsetTypes.size() - 1) + if (i < 
packedOffsetTypes.size() - 1) { p << ","; + } } p.printNewline(); } @@ -1604,16 +1620,18 @@ static ParseResult parseConstantValueList( static void printConstantValueList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ValueRange resultSizes, ArrayAttr values) { - if (resultTypes.empty()) + if (resultTypes.empty()) { return; + } for (unsigned i = 0; i < resultTypes.size(); ++i) { p.printNewline(); p << " "; printSizeAwareType(p, op, resultTypes[i], resultSizes[i]); p << " = "; p.printAttribute(values[i]); - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ","; + } } } @@ -1667,13 +1685,15 @@ static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser, static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op, Region &body) { - if (body.empty()) + if (body.empty()) { return; + } p << "workgroups("; auto args = body.getArguments(); for (unsigned i = 0; i < args.size(); ++i) { - if (i > 0) + if (i > 0) { p << ", "; + } p.printRegionArgument(args[i]); } p << ")"; @@ -1695,8 +1715,9 @@ ResourceAllocOp::createSuballocations( bool uninitialized, AffinityAttr affinityAttr, OpBuilder &builder) { assert(locs.size() == storageSizes.size() && "expect locs and storageSizes to match"); - if (locs.empty()) + if (locs.empty()) { return {}; + } if (locs.size() == 1) { auto allocOp = IREE::Stream::ResourceAllocOp::create( builder, locs.front(), resourceType, storageSizes.front(), @@ -1750,8 +1771,9 @@ ResourceAllocaOp::createSuballocations(Type timepointType, Type resourceType, OpBuilder &builder) { assert(locs.size() == storageSizes.size() && "expect locs and storageSizes to match"); - if (locs.empty()) + if (locs.empty()) { return {}; + } if (locs.size() == 1) { auto allocaOp = IREE::Stream::ResourceAllocaOp::create( builder, locs.front(), resourceType, timepointType, @@ -2546,12 +2568,14 @@ void AsyncSplatOp::build(OpBuilder &builder, OperationState &state, Type result_type, Value value, Value result_size, Attribute affinity, Value 
await_timepoint) { state.addTypes(result_type); - if (await_timepoint) + if (await_timepoint) { state.addOperands(await_timepoint); + } state.addOperands(value); state.addOperands(result_size); - if (affinity) + if (affinity) { state.addAttribute("affinity", affinity); + } } LogicalResult AsyncSplatOp::verify() { @@ -2748,8 +2772,9 @@ static ParseResult parseCollectiveParam( OpAsmParser &parser, Attribute opAttr, std::optional &optionalParamValue) { const char *keyword = getCollectiveParamKeyword(opAttr); - if (!keyword) + if (!keyword) { return success(); // optional + } OpAsmParser::UnresolvedOperand paramValue; if (failed(parser.parseKeyword(keyword)) || failed(parser.parseLParen()) || failed(parser.parseOperand(paramValue)) || failed(parser.parseRParen())) { @@ -2995,16 +3020,19 @@ static ParseResult parseDispatchOperands( SmallVectorImpl &resourceOffsets, SmallVectorImpl &resourceEnds, SmallVectorImpl &resourceLengths) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } // Handle the case of no operands specially. - if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { return success(); + } do { // All entries at least have an %operand. resourceOperands.emplace_back(); - if (failed(parser.parseOperand(resourceOperands.back()))) + if (failed(parser.parseOperand(resourceOperands.back()))) { return failure(); + } // Resources have a range. 
if (succeeded(parser.parseOptionalLSquare())) { resourceOffsets.emplace_back(); @@ -3020,8 +3048,9 @@ static ParseResult parseDispatchOperands( } } } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -3090,8 +3119,9 @@ void AsyncDispatchOp::getAsyncAccessRanges( unsigned rangeIndex = 0; unsigned tiedOperandBase = getTiedOperandsIndexAndLength().first; for (auto [operandIndex, operand] : llvm::enumerate(getResourceOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = ResourceAccessBitfield::Read; auto tiedResults = getOperandTiedResults(tiedOperandBase + operandIndex); if (!tiedResults.empty()) { @@ -3171,12 +3201,14 @@ void AsyncFuncOp::build(OpBuilder &builder, OperationState &state, bool AsyncFuncOp::isResultTied(int resultIndex) { auto tiedOperandsAttr = getTiedOperandsAttr(); - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } auto indexAttr = dyn_cast_if_present( tiedOperandsAttr.getValue()[resultIndex]); - if (!indexAttr) + if (!indexAttr) { return false; + } return indexAttr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; } @@ -3246,8 +3278,9 @@ AsyncCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } // auto typesCompatible = [](Type actual, Type expected) { auto typesCompatible = [](Type callee, Type call) { - if (callee == call) + if (callee == call) { return true; + } auto calleeResource = dyn_cast(callee); auto callResource = dyn_cast(call); if (calleeResource && callResource) { @@ -3293,8 +3326,9 @@ void AsyncCallOp::getAsyncAccessRanges( unsigned rangeIndex = 0; unsigned tiedOperandBase = getTiedOperandsIndexAndLength().first; for (auto [operandIndex, operand] : llvm::enumerate(getResourceOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = 
ResourceAccessBitfield::Read; auto tiedResults = getOperandTiedResults(tiedOperandBase + operandIndex); if (!tiedResults.empty()) { @@ -3336,8 +3370,9 @@ void AsyncExecuteOp::build(OpBuilder &builder, OperationState &state, state.addOperands(operands); state.addOperands(operandSizes); state.addOperands(resultSizes); - if (awaitTimepoint) + if (awaitTimepoint) { state.addOperands(awaitTimepoint); + } state.addAttributes(attributes); state.attributes.erase(IREE::Util::TiedOpInterface::getStorageAttrName()); state.addAttribute(IREE::Util::TiedOpInterface::getStorageAttrName(), @@ -3400,8 +3435,9 @@ getExecutionAsyncAccessRanges(Op op, for (auto [i, operand, operandSize] : llvm::zip_equal( llvm::seq(0, op.getResourceOperands().size()), op.getResourceOperands(), op.getResourceOperandSizes())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } ResourceAccessBitfield access = ResourceAccessBitfield::Read; auto tiedResults = op.getOperandTiedResults(tiedOperandBase + i); if (!tiedResults.empty()) { @@ -3479,8 +3515,9 @@ AsyncExecuteOp::cloneReplacementExcludingOperandsAndResults( auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -3598,8 +3635,9 @@ AsyncConcurrentOp::cloneReplacementExcludingOperandsAndResults( eraseStreamRegionResults(newBody, excludedResultIndices); auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -3640,8 +3678,9 @@ Value AsyncParameterReadOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterReadOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to target + if (resultIndex == 
0) { + return {0}; // result tied to target + } return std::nullopt; // result_timepoint not tied } @@ -3678,8 +3717,9 @@ Value AsyncParameterWriteOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterWriteOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to source + if (resultIndex == 0) { + return {0}; // result tied to source + } return std::nullopt; // result_timepoint not tied } @@ -3737,10 +3777,11 @@ Value AsyncParameterGatherOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterGatherOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) + if (resultIndex == 0) { return { getSourceOffsets() - .size()}; // result tied to target (after variadic source_offsets) + .size()}; // result tied to target (after variadic source_offsets) + } return std::nullopt; // result_timepoint not tied } @@ -3802,8 +3843,9 @@ Value AsyncParameterScatterOp::getTiedResult(unsigned resultIndex) { ::std::optional AsyncParameterScatterOp::getTiedResultOperandIndex(unsigned resultIndex) { - if (resultIndex == 0) - return {0}; // result tied to source + if (resultIndex == 0) { + return {0}; // result tied to source + } return std::nullopt; // result_timepoint not tied } @@ -4109,8 +4151,9 @@ printDispatchResources(OpAsmPrinter &p, Operation *op, ValueRange resources, p.printOperand(resourceLength); p << "] : "; printSizeAwareType(p, op, resourceType, resourceSize); - if (i < resources.size() - 1) + if (i < resources.size() - 1) { p << ","; + } } } @@ -4196,8 +4239,9 @@ static ParseResult parseDispatchFunctionArgumentList( SmallVector argAttrsVec; do { OpAsmParser::UnresolvedOperand arg; - if (failed(parser.parseOperand(arg))) + if (failed(parser.parseOperand(arg))) { return failure(); + } bool hasOffsetLength = false; OpAsmParser::UnresolvedOperand offsetArg; OpAsmParser::UnresolvedOperand lengthArg; @@ -4272,8 +4316,9 @@ static void 
printDispatchFunctionResultList(OpAsmPrinter &p, Operation *op, p.printOptionalAttrDict(attrs.getValue()); } } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -4284,8 +4329,9 @@ ParseResult parseDispatchFunctionSignature(OpAsmParser &parser, SmallVector args; SmallVector argTypes; SmallVector resultTypes; - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseDispatchFunctionArgumentList(parser, args, argTypes, argAttrs)) || @@ -4318,8 +4364,9 @@ void printDispatchFunctionSignature(OpAsmPrinter &p, Operation *op, auto functionType = cast(functionTypeAttr.getValue()); p << "("; for (size_t argIndex = 0; argIndex < functionType.getNumInputs();) { - if (argIndex) + if (argIndex) { p << ", "; + } int baseArgIndex = argIndex; auto type = functionType.getInput(baseArgIndex); p << "%arg"; @@ -4345,11 +4392,13 @@ void printDispatchFunctionSignature(OpAsmPrinter &p, Operation *op, auto resultTypes = functionType.getResults(); if (!resultTypes.empty()) { p << " -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printDispatchFunctionResultList(p, op, resultTypes, resultAttrs); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } } @@ -4411,11 +4460,13 @@ static ParseResult parseCmdCallOperands( SmallVectorImpl &resourceOffsets, SmallVectorImpl &resourceLengths, ArrayAttr &resourceAccesses) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } // Handle the case of no operands specially. 
- if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { return success(); + } SmallVector accessAttrs; do { StringRef accessStr; @@ -4454,8 +4505,9 @@ static ParseResult parseCmdCallOperands( } } while (succeeded(parser.parseOptionalComma())); resourceAccesses = parser.getBuilder().getArrayAttr(accessAttrs); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -4500,8 +4552,9 @@ static void printCmdCallOperands(OpAsmPrinter &p, Operation *op, // Primitive/custom type. p.printOperand(operand); } - if (i < resourceOperands.size() - 1) + if (i < resourceOperands.size() - 1) { p << ", "; + } } p << ")"; } @@ -4517,8 +4570,9 @@ void CmdExecuteOp::build(OpBuilder &builder, OperationState &state, state.addTypes(IREE::Stream::TimepointType::get(builder.getContext())); state.addOperands(operands); state.addOperands(operandSizes); - if (awaitTimepoint) + if (awaitTimepoint) { state.addOperands(awaitTimepoint); + } state.addAttributes(attributes); state.attributes.erase(getOperandSegmentSizeAttr()); state.addAttribute(getOperandSegmentSizeAttr(), @@ -4552,8 +4606,9 @@ LogicalResult CmdExecuteOp::verify() { return failure(); } for (auto &nestedOp : op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4616,8 +4671,9 @@ CmdExecuteOp::cloneReplacementExcludingOperandsAndResults( newBody.takeBody(getClosureBodyRegion()); auto &block = newBody.front(); BitVector eraseIndices(block.getNumArguments()); - for (auto i : excludedOperandIndices) + for (auto i : excludedOperandIndices) { eraseIndices.set(i); + } block.eraseArguments(eraseIndices); return newOp; } @@ -4629,8 +4685,9 @@ CmdExecuteOp::cloneReplacementExcludingOperandsAndResults( LogicalResult CmdSerialOp::verify() { CmdSerialOp op = *this; for (auto &nestedOp : op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + 
if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4655,8 +4712,9 @@ void CmdSerialOp::getSuccessorRegions( LogicalResult CmdConcurrentOp::verify() { CmdConcurrentOp op = *this; for (auto &nestedOp : op.getBody().front()) { - if (failed(verifyCmdOp(&nestedOp))) + if (failed(verifyCmdOp(&nestedOp))) { return failure(); + } } return success(); } @@ -4768,8 +4826,9 @@ LogicalResult TimepointJoinOp::verify() { Value TimepointJoinOp::join(Location loc, ValueRange timepoints, OpBuilder &builder) { assert(!timepoints.empty() && "must have at least one timepoint"); - if (timepoints.size() == 1) + if (timepoints.size() == 1) { return timepoints.front(); + } return IREE::Stream::TimepointJoinOp::create( builder, loc, builder.getType(), timepoints); } @@ -4967,11 +5026,13 @@ LogicalResult ExecutableExportOp::verify() { mlir::FunctionOpInterface ExecutableExportOp::lookupFunctionRef() { auto executableOp = this->getOperation()->getParentOfType(); - if (!executableOp) + if (!executableOp) { return {}; + } auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { return {}; + } return innerModuleOp.lookupSymbol( getFunctionRef()); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp index 7126df08a411..7bbb612f8b17 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp @@ -100,8 +100,9 @@ bool AsyncAccessRange::mayOverlap(const AsyncAccessRange &lhs, const AsyncAccessRange &rhs) { // Different resources do not overlap for this purpose. They may still alias // at various points but that's beyond the analysis we can do here. - if (lhs.resource != rhs.resource) + if (lhs.resource != rhs.resource) { return false; + } // Check for adjacent but not overlapping. 
if (lhs.end == rhs.start || lhs.start == rhs.end) { @@ -167,8 +168,9 @@ void printParameterReference(AsmPrinter &p, StringAttr scopeAttr, // static Attribute ResourceConfigAttr::parse(AsmParser &p, Type type) { - if (failed(p.parseLess()) || failed(p.parseLBrace())) + if (failed(p.parseLess()) || failed(p.parseLBrace())) { return {}; + } int64_t maxAllocationSize = 0; int64_t minBufferOffsetAlignment = 0; @@ -183,43 +185,53 @@ Attribute ResourceConfigAttr::parse(AsmParser &p, Type type) { return {}; } if (key == "max_allocation_size") { - if (failed(p.parseInteger(maxAllocationSize))) + if (failed(p.parseInteger(maxAllocationSize))) { return {}; + } } else if (key == "min_buffer_offset_alignment") { - if (failed(p.parseInteger(minBufferOffsetAlignment))) + if (failed(p.parseInteger(minBufferOffsetAlignment))) { return {}; + } } else if (key == "max_buffer_range") { - if (failed(p.parseInteger(maxBufferRange))) + if (failed(p.parseInteger(maxBufferRange))) { return {}; + } } else if (key == "min_buffer_range_alignment") { - if (failed(p.parseInteger(minBufferRangeAlignment))) + if (failed(p.parseInteger(minBufferRangeAlignment))) { return {}; + } } else if (key == "index_bits") { - if (failed(p.parseInteger(indexBits))) + if (failed(p.parseInteger(indexBits))) { return {}; + } } else if (key == "alias_mutable_bindings") { StringRef value; - if (failed(p.parseKeyword(&value))) + if (failed(p.parseKeyword(&value))) { return {}; - if (value == "true") + } + if (value == "true") { aliasMutableBindings = true; - else if (value == "false") + } else if (value == "false") { aliasMutableBindings = false; - else + } else { return {}; + } } else if (key == "memory_model") { StringRef value; - if (failed(p.parseKeyword(&value))) + if (failed(p.parseKeyword(&value))) { return {}; + } auto enumValue = symbolizeMemoryModel(value); - if (!enumValue.has_value()) + if (!enumValue.has_value()) { return {}; + } memoryModel = enumValue.value(); } (void)p.parseOptionalComma(); } - if 
(failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } return ResourceConfigAttr::get(p.getContext(), maxAllocationSize, minBufferOffsetAlignment, maxBufferRange, @@ -245,10 +257,12 @@ void ResourceConfigAttr::print(AsmPrinter &p) const { ResourceConfigAttr ResourceConfigAttr::intersectBufferConstraints(ResourceConfigAttr lhs, ResourceConfigAttr rhs) { - if (!lhs) + if (!lhs) { return rhs; - if (!rhs) + } + if (!rhs) { return lhs; + } Builder b(lhs.getContext()); return ResourceConfigAttr::get( b.getContext(), @@ -285,15 +299,17 @@ ResourceConfigAttr ResourceConfigAttr::lookup(Operation *op) { while (op) { // Use an override if specified. auto attr = op->getAttrOfType(attrId); - if (attr) + if (attr) { return attr; + } // See if the affinity specified provides a resource configuration. if (auto affinityOp = dyn_cast(op)) { auto affinityAttr = affinityOp.getAffinityAttr(); if (affinityAttr) { auto attr = affinityAttr.getResourceConfigAttr(); - if (attr) + if (attr) { return attr; + } } } op = op->getParentOp(); @@ -325,13 +341,15 @@ int64_t NamedParameterAttr::getStorageSize() const { Attribute TimepointAttr::parse(AsmParser &p, Type type) { StringRef timeStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (failed(p.parseKeyword(&timeStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } if (timeStr != "immediate") { p.emitError(p.getCurrentLocation(), "only immediate timepoint attrs are supported"); @@ -389,8 +407,9 @@ AffinityAttr AffinityAttr::lookupOrDefault(Operation *fromOp) { // static bool AffinityAttr::areCompatible(AffinityAttr desiredAffinity, AffinityAttr requiredAffinity) { - if (desiredAffinity == requiredAffinity) + if (desiredAffinity == requiredAffinity) { return true; + } if ((desiredAffinity && !requiredAffinity) || (requiredAffinity && !desiredAffinity)) { return true; @@ -401,10 +420,12 @@ bool AffinityAttr::areCompatible(AffinityAttr 
desiredAffinity, // static bool AffinityAttr::canExecuteTogether(AffinityAttr lhs, AffinityAttr rhs) { - if (lhs == rhs) + if (lhs == rhs) { return true; - if ((lhs && !rhs) || (rhs && !lhs)) + } + if ((lhs && !rhs) || (rhs && !lhs)) { return true; + } return lhs.isExecutableWith(rhs); } @@ -429,15 +450,17 @@ AffinityAttr AffinityAttr::joinOR(ArrayRef affinityAttrs) { Attribute PartitioningConfigAttr::parse(AsmParser &p, Type type) { std::string favorStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (succeeded(p.parseOptionalStar())) { favorStr = "size"; } else if (failed(p.parseString(&favorStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } auto favor = symbolizeFavor(favorStr); if (!favor.has_value()) { p.emitError(p.getNameLoc(), "unknown favor value: ") << favorStr; @@ -458,8 +481,9 @@ PartitioningConfigAttr PartitioningConfigAttr::lookup(Operation *op) { auto attrId = StringAttr::get(op->getContext(), "stream.partitioning"); while (op) { auto attr = op->getAttrOfType(attrId); - if (attr) + if (attr) { return attr; + } op = op->getParentOp(); } // No config found; use defaults. 
@@ -499,15 +523,17 @@ static void printLifetime(Lifetime lifetime, llvm::raw_ostream &os) { Type ResourceType::parse(AsmParser &p) { StringRef lifetimeStr; - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } if (succeeded(p.parseOptionalStar())) { lifetimeStr = "*"; } else if (failed(p.parseKeyword(&lifetimeStr))) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } auto lifetime = parseLifetime(lifetimeStr); if (!lifetime.has_value()) { p.emitError(p.getNameLoc(), "unknown lifetime value: ") << lifetimeStr; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp index 891ec95a8584..440c0341b626 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp @@ -137,8 +137,9 @@ struct AnnotateAffinitiesPass // Annotate all ops with derived affinities. 
for (auto &op : getOperation().getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } if (auto globalOp = dyn_cast(op)) { annotateGlobalOp(globalOp, affinityAnalysis); } else if (auto funcOp = dyn_cast(op)) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp index 74104a7a9536..6197f580acd5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchArguments.cpp @@ -216,8 +216,9 @@ ChangeStatus GlobalPVS::updateOperation(IREE::Util::GlobalOp globalOp, auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp); for (auto use : globalInfo->uses) { auto storeOp = dyn_cast(use); - if (!storeOp) + if (!storeOp) { continue; + } auto value = solver.getElementFor( *this, Position::forValue(storeOp.getStoredGlobalValue()), DFX::Resolution::REQUIRED); @@ -275,8 +276,9 @@ class ValueAlignment } static llvm::MaybeAlign computeAlignment(const ValuePVS::SetTy &set) { - if (set.empty()) + if (set.empty()) { return llvm::MaybeAlign(); + } llvm::MaybeAlign alignment; for (auto value : set) { APInt valueDivisor = (value & (~(value - 1))); @@ -373,8 +375,9 @@ class ArgumentAnalysis { ArrayRef getDispatchSites(IREE::Stream::ExecutableExportOp exportOp) { auto it = entryDispatchMap.find(exportOp); - if (it == entryDispatchMap.end()) + if (it == entryDispatchMap.end()) { return {}; + } return it->second; } @@ -383,8 +386,9 @@ class ArgumentAnalysis { llvm::MaybeAlign getAlignmentFor(Value value) { auto element = solver.lookupElementFor(Position::forValue(value)); - if (!element) + if (!element) { return llvm::MaybeAlign(); + } return element->getAssumedAlignment(); } @@ -422,8 +426,9 @@ class ArgumentAnalysis { for (auto dispatchOp : getDispatchSites(exportOp)) { auto element = solver.lookupElementFor( 
Position::forValue(dispatchOp.getUniformOperands()[operandIdx])); - if (!element || !element->isValidState()) + if (!element || !element->isValidState()) { return llvm::MaybeAlign(); + } alignment = commonAlignment(alignment, element->getAssumedAlignment()); } if (alignment.valueOrOne().value() == kMaximumAlignment) { @@ -441,8 +446,9 @@ class ArgumentAnalysis { for (auto dispatchOp : getDispatchSites(exportOp)) { auto element = solver.lookupElementFor( Position::forValue(dispatchOp.getResourceOffsets()[resourceIdx])); - if (!element || !element->isValidState()) + if (!element || !element->isValidState()) { return llvm::MaybeAlign(); + } alignment = commonAlignment(alignment, element->getAssumedAlignment()); } if (alignment.valueOrOne().value() == kMaximumAlignment) { @@ -477,8 +483,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, // Operands/resources on the func are in an arbitrary order; get maps that // lets us go from dispatch site operand/resource to function argument. 
auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp) + if (!funcOp) { return; + } auto operandToArgMap = IREE::Stream::CmdDispatchOp::makeOperandToArgMap(funcOp); auto resourceToArgMap = @@ -502,8 +509,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, llvm::sort(potentialValues, [](Attribute lhs, Attribute rhs) { auto lhsInt = dyn_cast(lhs); auto rhsInt = dyn_cast(rhs); - if (!lhsInt || !rhsInt) + if (!lhsInt || !rhsInt) { return false; + } return lhsInt.getValue().ult(rhsInt.getValue()); }); auto potentialValuesAttr = ArrayAttr::get(context, potentialValues); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp index eb74773f2ffe..378c3bf6b1d5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateDispatchAssumptions.cpp @@ -54,8 +54,9 @@ class ArgumentAnalysis { LogicalResult run() { for (Operation *analysisRoot : analysisRoots) { - if (failed(solver.initializeAndRun(analysisRoot))) + if (failed(solver.initializeAndRun(analysisRoot))) { return failure(); + } } return success(); } @@ -65,8 +66,9 @@ class ArgumentAnalysis { ArrayRef getDispatchSites(IREE::Stream::ExecutableExportOp exportOp) { auto it = entryDispatchMap.find(exportOp); - if (it == entryDispatchMap.end()) + if (it == entryDispatchMap.end()) { return {}; + } return it->second; } @@ -109,8 +111,9 @@ class ArgumentAnalysis { IREE::Util::IntAssumptionAttr::get(context, umin, umax, udiv)); } - if (assumptions.empty()) + if (assumptions.empty()) { return {}; + } return std::make_pair( ArrayAttr::get(context, ArrayRef(assumptions.begin(), @@ -138,8 +141,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, // Operands/resources on the func are in an arbitrary order; get maps that // lets us go from dispatch site 
operand/resource to function argument. auto funcOp = exportOp.lookupFunctionRef(); - if (!funcOp || funcOp.empty()) + if (!funcOp || funcOp.empty()) { return; + } auto operandToArgMap = IREE::Stream::CmdDispatchOp::makeOperandToArgMap(funcOp); auto resourceToArgMap = @@ -156,8 +160,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, unsigned argIdx = operandToArgMap[operandIdx]; Value argValue = funcOp.getArgument(argIdx); Type argType = argValue.getType(); - if (!argType.isIndex() && !argType.isInteger()) + if (!argType.isIndex() && !argType.isInteger()) { continue; + } auto [assumptions, hasNonEmpty] = analysis.getOperandAssumptions(exportOp, operandIdx); @@ -168,8 +173,9 @@ static void annotateExport(IREE::Stream::ExecutableOp executableOp, } } - if (nonEmptyCount == 0) + if (nonEmptyCount == 0) { return; + } // Do the rewrite. OpBuilder builder = OpBuilder::atBlockBegin(&funcOp.front()); @@ -186,8 +192,9 @@ class AnnotateDispatchAssumptionsPass AnnotateDispatchAssumptionsPass> { void runOnOperation() override { ArgumentAnalysis analysis(getOperation()); - if (failed(analysis.run())) + if (failed(analysis.run())) { return signalPassFailure(); + } // Annotate the exported dispatch functions. 
for (auto executableOp : getOperation().getBodyRegion().getOps()) { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 1894b47f9694..86bc008142bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -45,8 +45,9 @@ namespace { static bool doesOperationNeedWrapping(Operation *op) { return llvm::any_of(op->getOperands(), [](Value operand) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { return false; + } return !isa_and_nonnull( operand.getDefiningOp()); }) || diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp index 63c70436ab16..9715f7929fbe 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/DumpStatistics.cpp @@ -127,8 +127,9 @@ struct Statistics { for (auto [name, globalOp] : usageInfo.resourceGlobalOps) { auto globalType = dyn_cast(globalOp.getType()); - if (!globalType) + if (!globalType) { continue; + } // TODO(benvanik): analyze size in UsageInfo where possible. 
switch (globalType.getLifetime()) { case IREE::Stream::Lifetime::Constant: @@ -436,14 +437,16 @@ static void dumpExecutionCSVTable(const UsageInfo &usageInfo, TypeSwitch(op) .Case([&](auto op) { ++depth; - for (auto &nestedOp : op.getBody().front()) + for (auto &nestedOp : op.getBody().front()) { dumpRow(&nestedOp); + } --depth; }) .Case([&](auto op) { ++depth; - for (auto &nestedOp : op.getBody().front()) + for (auto &nestedOp : op.getBody().front()) { dumpRow(&nestedOp); + } --depth; }) .Case([&](auto op) { @@ -462,8 +465,9 @@ static void dumpExecutionCSVTable(const UsageInfo &usageInfo, auto workload = op.getWorkload(); SmallString<32> workloadStr; for (unsigned i = 0; i < workload.size(); ++i) { - if (i > 0) + if (i > 0) { workloadStr.append(";"); + } APInt dimValue; if (matchPattern(workload[i], m_ConstantInt(&dimValue))) { dimValue.toString(workloadStr, 10, /*signed=*/true); @@ -575,8 +579,9 @@ openOutputFile(StringRef filePath) { std::error_code ec; auto result = std::make_unique( filePath, ec, llvm::sys::fs::OF_TextWithCRLF); - if (!ec) + if (!ec) { return result; + } llvm::errs() << "Error opening iree-stream-dump-statistics output file '" << filePath << "'\n"; return std::make_unique(2, false); // stderr. @@ -588,8 +593,9 @@ struct DumpStatisticsPass using IREE::Stream::impl::DumpStatisticsPassBase< DumpStatisticsPass>::DumpStatisticsPassBase; void runOnOperation() override { - if (outputFormat == DumpOutputFormat::None) + if (outputFormat == DumpOutputFormat::None) { return; + } // Open the output file we'll be streaming to. // Since we are processing the entire module at once we overwrite the file. 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp index c25281f0ed33..563a65034c07 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp @@ -173,8 +173,9 @@ class ArgumentSemantics const std::string getAsStr(AsmState &asmState) const override { std::string str; auto append = [&](const char *part) { - if (!str.empty()) + if (!str.empty()) { str += '|'; + } str += part; }; append(this->isAssumed(NOT_MUTATED) ? "immutable" : "mutable"); @@ -190,8 +191,9 @@ class ArgumentSemantics static bool isTiedUse(OpOperand &operand) { if (auto tiedOp = dyn_cast(operand.getOwner())) { - if (tiedOp.isOperandTied(operand.getOperandNumber())) + if (tiedOp.isOperandTied(operand.getOperandNumber())) { return true; + } } return false; } @@ -573,8 +575,9 @@ class ElisionAnalysis { bool isArgMoved(BlockArgument arg) { auto argumentSemantics = solver.lookupElementFor(Position::forValue(arg)); - if (!argumentSemantics) + if (!argumentSemantics) { return false; + } return argumentSemantics->getAssumedByValue(); } @@ -1006,16 +1009,18 @@ static bool isSafeToElideSliceOp(IREE::Stream::AsyncSliceOp sliceOp, SmallVector consumerRanges; SmallVector queryRanges; for (auto user : source.getUsers()) { - if (user == sliceOp) + if (user == sliceOp) { continue; + } if (auto accessOp = dyn_cast(user)) { // Async op consuming part of the resource. We can query it to see what // it's doing to its operands/results and filter to just the accesses of // the source value. 
accessOp.getAsyncAccessRanges(queryRanges); for (auto range : queryRanges) { - if (range.resource == source) + if (range.resource == source) { consumerRanges.push_back(range); + } } queryRanges.clear(); } else { @@ -1058,10 +1063,12 @@ static bool isSafeToElideSliceOp(IREE::Stream::AsyncSliceOp sliceOp, // arith.addi folders are terrible and don't handle adds of 0 so we handle that // here and then avoid doing the folding. static Value addOffset(Value lhs, Value rhs, OpBuilder &builder) { - if (matchPattern(lhs, m_Zero())) + if (matchPattern(lhs, m_Zero())) { return rhs; - if (matchPattern(rhs, m_Zero())) + } + if (matchPattern(rhs, m_Zero())) { return lhs; + } return builder.createOrFold( builder.getFusedLoc(lhs.getLoc(), rhs.getLoc()), lhs, rhs); } @@ -1111,8 +1118,9 @@ static void foldSliceIntoDispatch(IREE::Stream::AsyncSliceOp sliceOp, // Elides a stream.async.slice op (assuming able) by folding it into consumers. static void elideSliceOp(IREE::Stream::AsyncSliceOp sliceOp) { SmallVector> consumers; - for (auto &use : sliceOp.getResult().getUses()) + for (auto &use : sliceOp.getResult().getUses()) { consumers.push_back(std::make_pair(use.getOwner(), use.getOperandNumber())); + } for (auto [owner, operandNumberIt] : consumers) { unsigned operandNumber = operandNumberIt; // need C++20 to avoid this :| TypeSwitch(owner) @@ -1222,8 +1230,9 @@ static bool isSafeToElideUpdateOp(IREE::Stream::AsyncUpdateOp updateOp, // the dispatch fully overwrites our update region. if (auto dispatchOp = dyn_cast(user)) { for (auto &operand : user->getOpOperands()) { - if (operand.get() != result) + if (operand.get() != result) { continue; + } if (dispatchOp.isOperandTied(operand.getOperandNumber())) { // Result is tied to dispatch - check if dispatch fully overwrites // the update region. 
If not, downstream reads might access our diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp index 30e478919cb0..119bd605d7eb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ElideTimepoints.cpp @@ -786,8 +786,9 @@ class TimepointCoverageAnalysis { // Seed FenceCoverage for fence values that might be relevant. // This includes fences from timeline-aware ops and imported fences. for (auto op : block.getOps()) { - if (!op.participatesInTimeline()) + if (!op.participatesInTimeline()) { continue; + } if (Value signalFence = op.getSignalFence()) { solver.getOrCreateElementFor( Position::forValue(signalFence)); @@ -828,8 +829,9 @@ class TimepointCoverageAnalysis { }; for (auto callableOp : getTopLevelOps()) { auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { continue; + } seedRegion(*region); } @@ -846,8 +848,9 @@ class TimepointCoverageAnalysis { // Returns true if |value| is known to be immediately resolved. 
bool isImmediate(Value value) { - if (isDefinedImmediate(value)) + if (isDefinedImmediate(value)) { return true; + } auto &isImmediate = solver.getOrCreateElementFor(Position::forValue(value)); return isImmediate.isValidState() && isImmediate.isKnown(); @@ -881,8 +884,9 @@ class TimepointCoverageAnalysis { bool unionTransitivelyReachedTimepoints(Value value, SetVector &set) { auto coverage = solver.getOrCreateElementFor( Position::forValue(value)); - if (!coverage.isValidState() || coverage.isUndefContained()) + if (!coverage.isValidState() || coverage.isUndefContained()) { return false; + } for (auto reached : coverage.getAssumedSet()) { set.insert(reached); } @@ -914,8 +918,9 @@ buildRequiredCoverageSet(SmallVector possibleTimepoints, if (isValid) { for (auto reachedTimepoint : reachedTimepoints) { // TODO(benvanik): avoid self-references so we don't need this check. - if (reachedTimepoint == possibleTimepoint) + if (reachedTimepoint == possibleTimepoint) { continue; + } ++coverageMap[reachedTimepoint]; } } @@ -1036,8 +1041,9 @@ static bool trySinkAwaitIntoBranch(IREE::Stream::TimepointAwaitOp awaitOp, llvm::dbgs() << "[ElideTimepoints] sinking await into scf.if "; bool first = true; for (Region *region : regionsWithDirectUse) { - if (!first) + if (!first) { llvm::dbgs() << " and "; + } if (region == &ifOp.getThenRegion()) { llvm::dbgs() << "then"; } else { @@ -1066,8 +1072,9 @@ static bool trySinkAwaitIntoBranch(IREE::Stream::TimepointAwaitOp awaitOp, bool first = true; auto caseRegions = switchOp.getCaseRegions(); for (Region *region : regionsWithDirectUse) { - if (!first) + if (!first) { llvm::dbgs() << ", "; + } // Find which case this is. bool foundCase = false; for (auto [idx, caseRegion] : llvm::enumerate(caseRegions)) { @@ -1532,8 +1539,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides |elidedTimepoint| by replacing all its uses by |op| with an // immediate timepoint value. 
auto elideTimepointOperand = [&](Operation *op, Value elidedTimepoint) { - if (isDefinedImmediate(elidedTimepoint)) + if (isDefinedImmediate(elidedTimepoint)) { return; // already immediate + } auto immediateTimepoint = makeImmediate(elidedTimepoint, OpBuilder(op)); elidedTimepoint.replaceUsesWithIf( immediateTimepoint, @@ -1544,10 +1552,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides all timepoint operands of |op| that are immediately resolved. auto elideTimepointOperands = [&](Operation *op) { for (auto operand : llvm::make_early_inc_range(op->getOperands())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; - if (isDefinedImmediate(operand)) + } + if (isDefinedImmediate(operand)) { continue; + } if (analysis.isImmediate(operand)) { LLVM_DEBUG({ llvm::dbgs() << " >>> eliding known-immediate operand "; @@ -1562,10 +1572,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // Elides |elidedTimepoint| by replacing all its uses with an immediate // timepoint value. The original value will end up with zero uses. 
auto elideTimepointResult = [&](Operation *op, Value elidedTimepoint) { - if (elidedTimepoint.use_empty()) + if (elidedTimepoint.use_empty()) { return; // no-op - if (isDefinedImmediate(elidedTimepoint)) + } + if (isDefinedImmediate(elidedTimepoint)) { return; // already immediate + } OpBuilder afterBuilder(op); afterBuilder.setInsertionPointAfterValue(elidedTimepoint); Value immediateTimepoint = IREE::Stream::TimepointImmediateOp::create( @@ -1583,10 +1595,12 @@ static bool tryElideTimepointsInRegion(Region ®ion, // %imm0 = immediate // %imm1 = immediate for (auto result : llvm::reverse(op->getResults())) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; - if (isDefinedImmediate(result)) + } + if (isDefinedImmediate(result)) { continue; + } if (analysis.isImmediate(result)) { LLVM_DEBUG({ llvm::dbgs() << " >>> eliding known-immediate result "; @@ -1604,8 +1618,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, auto processTimelineOp = [&](IREE::Stream::TimelineOpInterface op) { auto resultTimepoint = op.getResultTimepoint(); auto awaitTimepoints = op.getAwaitTimepoints(); - if (awaitTimepoints.empty()) + if (awaitTimepoints.empty()) { return; + } LLVM_DEBUG({ llvm::dbgs() << "[ElideTimepoints] pruning " << op->getName() @@ -1652,8 +1667,9 @@ static bool tryElideTimepointsInRegion(Region ®ion, } // If there's only one timepoint we don't have to worry with coverage. - if (possibleTimepoints.size() <= 1) + if (possibleTimepoints.size() <= 1) { return; + } // Perform the analysis on the possible timepoints to find which are covered // by others and elide all of those known-covered. 
@@ -1761,8 +1777,9 @@ struct ElideTimepointsPass : public IREE::Stream::impl::ElideTimepointsPassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } // Perform whole-program analysis to find for each timepoint what other // timepoints are known to be reached. @@ -1793,8 +1810,9 @@ struct ElideTimepointsPass tryElideTimepointsInRegion(*region, analysis, domInfo) || didChange; } - if (didChange) + if (didChange) { signalFixedPointModified(moduleOp); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp index 9b808c74674b..d1ee6cbb0981 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp @@ -60,8 +60,9 @@ static LogicalResult checkEncoding(Operation *op, RankedTensorType encodingType, static RankedTensorType alignTensorType(RankedTensorType originalType) { Type elementType = originalType.getElementType(); Type alignedType = legalizeStorageElementType(originalType); - if (alignedType == elementType) + if (alignedType == elementType) { return originalType; + } return RankedTensorType::get(originalType.getShape(), alignedType, originalType.getEncoding()); } @@ -79,8 +80,9 @@ static Value makeTensorDim(Location loc, RankedTensorType tensorType, // Map from absolute dimension index to the compact dynamic index. 
unsigned di = 0; for (unsigned j = 0; j < i; ++j) { - if (tensorType.isDynamicDim(j)) + if (tensorType.isDynamicDim(j)) { ++di; + } } return dynamicDims[di]; } @@ -661,8 +663,9 @@ alignDispatchTensorType(IREE::TensorExt::DispatchTensorType originalType) { Type elementType = originalType.getBoundElementType(); Type alignedType = legalizeStorageElementType(originalType.asRankedTensorType()); - if (alignedType == elementType) + if (alignedType == elementType) { return originalType; + } return IREE::TensorExt::DispatchTensorType::get( originalType.getAccess(), originalType.getShape(), alignedType); } @@ -688,8 +691,9 @@ struct EncodeBindingSubspanOp // Align the element type, if needed. IREE::TensorExt::DispatchTensorType alignedType = alignDispatchTensorType(originalType); - if (originalType == alignedType) + if (originalType == alignedType) { return failure(); // already aligned. + } // Directly swap the type with the one, changing all uses in the IR. // This works because @@ -713,8 +717,9 @@ struct EncodeDispatchTensorLoadOp // Align the element type, if needed. RankedTensorType alignedType = alignTensorType(targetType); - if (targetType == alignedType) + if (targetType == alignedType) { return failure(); // already aligned. + } // Loads always truncate from an byte aligned type to a sub-byte one. assert(targetType.getElementTypeBitWidth() < @@ -747,8 +752,9 @@ struct EncodeDispatchTensorStoreOp // Align the element type, if needed. RankedTensorType alignedType = alignTensorType(sourceType); - if (sourceType == alignedType) + if (sourceType == alignedType) { return failure(); // already aligned. + } // Stores always extend from a sub-byte aligned type to a byte aligned one. 
assert(sourceType.getElementTypeBitWidth() < diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp index 50440574d0c8..4a9a1c2deb3f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FoldUniformOperands.cpp @@ -143,8 +143,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, llvm::interleaveComma(deadOperandsMap.set_bits(), llvm::dbgs()); llvm::dbgs() << "\n"; for (auto replacement : llvm::enumerate(argReplacementMap)) { - if (replacement.index() == replacement.value()) + if (replacement.index() == replacement.value()) { continue; + } llvm::dbgs() << " %arg" << replacement.index() << " -> %arg" << replacement.value() << "\n"; } @@ -155,8 +156,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, for (auto replacement : llvm::enumerate(argReplacementMap)) { unsigned deadIdx = replacement.index(); unsigned liveIdx = replacement.value(); - if (deadIdx == liveIdx) + if (deadIdx == liveIdx) { continue; + } deadArgMap.set(deadIdx); entryBlock.getArgument(deadIdx).replaceAllUsesWith( entryBlock.getArgument(liveIdx)); @@ -164,8 +166,9 @@ deduplicateOperands(mlir::FunctionOpInterface funcOp, // Update each dispatch site to remove duplicates. 
SmallVector deadOperands; - for (auto idx : deadOperandsMap.set_bits()) + for (auto idx : deadOperandsMap.set_bits()) { deadOperands.push_back(idx); + } for (auto dispatchOp : dispatchOps) { for (auto idx : llvm::reverse(deadOperands)) { dispatchOp.getUniformOperandsMutable().erase(idx); @@ -202,8 +205,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, llvm::BitVector uniformOperandMap(operandCount, /*t=*/true); for (auto dispatchOp : dispatchOps) { for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!uniformOperandMap.test(idx)) + if (!uniformOperandMap.test(idx)) { continue; + } auto value = dispatchOp.getUniformOperands()[idx]; APInt intValue; if (!matchPattern(value, m_ConstantInt(&intValue))) { @@ -232,8 +236,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, LLVM_DEBUG({ llvm::dbgs() << "inlineUniformConstants for " << funcOp.getName() << "\n"; for (unsigned i = 0; i < operandValues.size(); ++i) { - if (!operandValues[i].has_value()) + if (!operandValues[i].has_value()) { continue; + } llvm::dbgs() << " operand " << i << " = " << operandValues[i].value() << "\n"; } @@ -258,8 +263,9 @@ inlineUniformConstants(mlir::FunctionOpInterface funcOp, // Update each dispatch site to remove duplicates. 
SmallVector deadOperands; - for (auto idx : uniformOperandMap.set_bits()) + for (auto idx : uniformOperandMap.set_bits()) { deadOperands.push_back(idx); + } for (auto dispatchOp : dispatchOps) { for (auto idx : llvm::reverse(deadOperands)) { dispatchOp.getUniformOperandsMutable().erase(idx); @@ -410,8 +416,9 @@ struct FoldUniformOperandsPass for (auto exportOp : executableOp.getOps()) { auto &dispatchOps = entryDispatchMap[exportOp]; - if (dispatchOps.empty()) + if (dispatchOps.empty()) { continue; // no-op if no dispatches + } auto funcOp = exportOp.lookupFunctionRef(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp index 3e182f2b97c1..5ca60ead127b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/FuseDispatchBindings.cpp @@ -148,8 +148,9 @@ findCorrelatedBindings(unsigned bindingCount, llvm::BitVector handledBindings(bindingCount, /*t=*/false); for (unsigned i = 0; i < bindingCount; ++i) { // Ignore bindings we've already covered earlier during iteration. - if (handledBindings.test(i)) + if (handledBindings.test(i)) { continue; + } // Build new binding. 
Binding binding; @@ -316,8 +317,9 @@ fuseDispatchBindings(IREE::Stream::ExecutableOp executableOp, IREE::Stream::ExecutableExportOp exportOp, ArrayRef dispatchOps, MemoizedCmdZeros &memoizedZeros) { - if (dispatchOps.empty()) + if (dispatchOps.empty()) { return; // no-op if no dispatches + } auto anyDispatchOp = dispatchOps.front(); unsigned bindingCount = anyDispatchOp.getResources().size(); @@ -443,8 +445,9 @@ struct FuseDispatchBindingsPass MemoizedCmdZeros memoizedZeros; for (auto executableOp : getOperation().getBodyRegion().getOps()) { - if (!executableOp.getInnerModule()) + if (!executableOp.getInnerModule()) { continue; + } for (auto exportOp : executableOp.getOps()) { fuseDispatchBindings(executableOp, exportOp, entryDispatchMap[exportOp], diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp index d3a992c6fcaa..184c713300e0 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp @@ -180,8 +180,9 @@ packDynamicSlicesConservatively(Location loc, Value baseOffset, SmallVector slices; bool intersects(const Slice &slice) const { for (auto *binSlice : slices) { - if (binSlice->intersects(slice)) + if (binSlice->intersects(slice)) { return true; + } } return false; } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp index b9b1b9589e10..e122c3917497 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeBuiltins.cpp @@ -351,8 +351,9 @@ struct MaterializeBuiltinsPass MaterializeBuiltinsPass> { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { 
return; + } // Find and replace (if needed) ops that we want to turn into builtins // across the entire program. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp index dd9a40d7e80d..65211e9e23b6 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp @@ -41,26 +41,29 @@ namespace { static bool isSafeToElideCOW(Value operand, IREE::Stream::ResourceType type) { // Can't do anything with block args without analysis - we don't know if the // value they carry is the last user (move semantics). - if (isa(operand)) + if (isa(operand)) { return false; + } // If our value is a constant then we need to ensure that we aren't // tied to a constant operand. If we are we need to clone to a // non-constant value. We could make this work in cases where constants are // being initialized, however those are best modeled as transfer operations // where no mutations will occur on the constant transfer target. - if (type.getLifetime() == IREE::Stream::Lifetime::Constant) + if (type.getLifetime() == IREE::Stream::Lifetime::Constant) { return false; + } // If there's more than one user we can't make a local decision. It's // expensive to query relative operation order within a block and within a // region the lifetime of values may vary - all things we can't tell here. Operation *firstUser = nullptr; for (Operation *user : operand.getUsers()) { - if (firstUser == nullptr) + if (firstUser == nullptr) { firstUser = user; - else if (firstUser != user) + } else if (firstUser != user) { return false; + } } // We are the only user and the value is contained entirely within the @@ -80,10 +83,12 @@ static Value materializeOperandCOW(Location loc, OpOperand &operand, // has to wait until a subsequent pass. 
auto resourceType = dyn_cast(operand.get().getType()); - if (!resourceType) + if (!resourceType) { return nullptr; - if (isSafeToElideCOW(operand.get(), resourceType)) + } + if (isSafeToElideCOW(operand.get(), resourceType)) { return nullptr; + } // Materialize a clone operation just for the operand provided. auto sizeAwareType = cast(resourceType); @@ -110,8 +115,9 @@ static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) { auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices(); for (unsigned i = 0; i < tiedOperandIndices.size(); ++i) { int64_t operandIdx = tiedOperandIndices[i]; - if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) + if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) { continue; + } auto &tiedOperand = tiedOp->getOpOperand(operandIdx); // If copy was required and materialized, we should forward it to all @@ -125,8 +131,9 @@ static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) { // TODO(#11249): Support in-place collective operations. if (!isa(tiedOp)) { for (auto &operand : tiedOp->getOpOperands()) { - if (operand.get() == original) + if (operand.get() == original) { operand.set(clone); + } } } } @@ -141,8 +148,9 @@ static bool materializeRegionCOW(Region ®ion) { bool didChange = false; for (auto &block : region.getBlocks()) { for (auto &op : block) { - if (!op.hasTrait()) + if (!op.hasTrait()) { continue; + } didChange = TypeSwitch(&op) .Case values; int64_t offset = 0; for (auto &constantSpan : storageBuffer.spans) { - if (constantSpan.length == 0) + if (constantSpan.length == 0) { continue; + } int64_t start = constantSpan.offset; int64_t end = start + constantSpan.length; @@ -465,8 +466,9 @@ static Value generateSerializedUpload( // will need and where each value will be placed. 
auto storageResources = computePackingMap(slices, resourceConfig, builder.getContext()); - if (storageResources.empty()) + if (storageResources.empty()) { return nullptr; + } // TODO(benvanik): should be able to have a single buffer constant and // subrange it so that we don't need so many files. @@ -551,8 +553,9 @@ static Value generateParameterUpload( storageResources = computePackingMap(slices, resourceConfig, builder.getContext()); } - if (storageResources.empty()) + if (storageResources.empty()) { return nullptr; + } // Sort resources by type so we can batch them. // Loads are only possible if we are using the parameter as a constant and diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp index 54cbb46de0a0..3427890e5ca5 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackDispatchOperands.cpp @@ -251,16 +251,18 @@ static Value recomposeFromI32sAndConvert( // Preserve the arg attrs on either the final op or the function argument // if none was required. if (auto definingOp = value.getDefiningOp()) { - if (oldArgAttr) + if (oldArgAttr) { definingOp->setAttrs(oldArgAttr); + } newArgAttrs.push_back(nullptr); } else { newArgAttrs.push_back(oldArgAttr); } // Note that if we had decomposed the arg we'll expect that there are two attr // dicts for the two new args. 
- if (wasDecomposed) + if (wasDecomposed) { newArgAttrs.push_back(nullptr); + } return value; } @@ -311,8 +313,9 @@ struct PackDispatchOperandsPass for (auto executableOp : getOperation().getOps()) { auto innerModuleOp = executableOp.getInnerModule(); - if (!innerModuleOp) + if (!innerModuleOp) { continue; + } for (auto funcOp : innerModuleOp.getOps()) { if (funcOp.isPublic()) { updateExportFuncOp(funcOp); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp index f1f5eb7b2f7e..3e51b04f75b4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp @@ -64,8 +64,9 @@ static ExpandedGlobalMap expandResourceGlobals(Operation *rootOp, // Gather all of the resource globals in the root. for (auto ®ion : rootOp->getRegions()) { for (auto globalOp : region.getOps()) { - if (!isa(globalOp.getType())) + if (!isa(globalOp.getType())) { continue; + } expandedGlobals[globalOp.getName()].resourceOp = globalOp; } } @@ -113,8 +114,9 @@ static void expandType(Type type, SmallVectorImpl &newTypes) { // Expands resources in the given |types| list to (timepoint, resource). // This could be changed to some iterator magic to avoid the alloc. static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -199,8 +201,9 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, if (auto sizeAwareOp = dyn_cast_if_present( resourceValue.getDefiningOp())) { auto sizeValue = sizeAwareOp.getResultSizeFromValue(resourceValue); - if (sizeValue) + if (sizeValue) { return sizeValue; + } } // Try first to scan uses in the IR. 
Since we carry the shape in most ops we @@ -208,11 +211,13 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, for (auto &use : resourceValue.getUses()) { auto sizeAwareOp = dyn_cast(use.getOwner()); - if (!sizeAwareOp) + if (!sizeAwareOp) { continue; + } auto sizeValue = sizeAwareOp.getOperandSize(use.getOperandNumber()); - if (!sizeValue) + if (!sizeValue) { continue; + } if (sizeValue.getParentRegion()->isProperAncestor( builder.getInsertionBlock()->getParent())) { // Size value found and implicitly captured; we can reuse (could be @@ -242,16 +247,19 @@ static Value makeBlockArgResourceSize(Location loc, Value resourceValue, static void expandRegion(Region ®ion, bool canModifyEntryBlock, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. auto timepointType = IREE::Stream::TimepointType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) + if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) { continue; - if (block.isEntryBlock() && !canModifyEntryBlock) + } + if (block.isEntryBlock() && !canModifyEntryBlock) { continue; + } // Insert and build a list of expanded (timepoint, resource) pairs. 
// Don't add mappings here - we need to check if wrapExpandedBlockArgFn @@ -259,8 +267,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, SmallVector> expansions; for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto resourceArg = block.getArgument(i); - if (!isResourceType(resourceArg.getType())) + if (!isResourceType(resourceArg.getType())) { continue; + } auto timepointArg = block.insertArgument(i + 1, timepointType, resourceArg.getLoc()); expansions.push_back(std::make_pair(timepointArg, resourceArg)); @@ -272,8 +281,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // If the resource already has an associated timepoint mapping from the // region branch expansion (wrapExpandedBlockArgFn), defer awaiting to // the consumer to avoid over-synchronization at block boundaries. - if (resourceTimepointMap.contains(resource)) + if (resourceTimepointMap.contains(resource)) { continue; + } // Add the mapping for this block arg since we're inserting an await. resourceTimepointMap.map(resource, timepoint); // If we can look down the chain and see the size then we can use that. 
@@ -325,8 +335,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; auto timepoint = expandedGlobal.timepointOp.createLoadOp(op.getLoc(), builder) @@ -369,8 +380,9 @@ static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto timepointOperand = consumeTimepoint( op.getLoc(), op.getStoredGlobalValue(), resourceTimepointMap, builder); @@ -433,13 +445,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // stream.timepoint.await %rt, %t static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. 
OpBuilder builder(op); @@ -490,10 +504,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %t, %0 static void expandReturnOp(IREE::Util::ReturnOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), resourceTimepointMap, builder); @@ -514,8 +531,9 @@ static void expandReturnOp(IREE::Util::ReturnOp op, // %1 = stream.timepoint.await %bb_t, %bb_0 static void expandBranchOp(mlir::cf::BranchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getDestOperands(), resourceTimepointMap, builder); @@ -525,8 +543,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -540,8 +559,9 @@ static void expandCondBranchOp(mlir::cf::CondBranchOp op, static void expandSwitchOp(mlir::cf::SwitchOp op, IRMapping &resourceTimepointMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto caseOperands = llvm::to_vector( llvm::map_range(op.getCaseOperands(), [&](ValueRange operands) { @@ -577,8 +597,9 @@ static void expandAwaitOp(IREE::Stream::TimepointAwaitOp op, // mappings to leak between sibling regions (e.g., scf.if then/else // branches), leading to invalid IR where one branch tries to use a // timepoint defined in another branch. 
- if (isa(inputOperand)) + if (isa(inputOperand)) { continue; + } resourceTimepointMap.map(inputOperand, op.getAwaitTimepoint()); } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp index 833ac102a220..c3dfaf6b41dd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/RefineUsage.cpp @@ -132,8 +132,9 @@ struct UsageRefinementPattern : public OpRewritePattern { // Returns true if a change was made. bool applyArgTransition(BlockArgument arg, PatternRewriter &rewriter) const { auto oldType = dyn_cast(arg.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(arg); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -155,8 +156,9 @@ struct UsageRefinementPattern : public OpRewritePattern { bool applyResultTransition(Operation *op, Value result, PatternRewriter &rewriter) const { auto oldType = dyn_cast(result.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(result); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -193,8 +195,9 @@ struct UsageRefinementPattern : public OpRewritePattern { IREE::Stream::AffinityAttr affinityAttr, PatternRewriter &rewriter) const { auto oldType = dyn_cast(result.getType()); - if (!oldType) + if (!oldType) { return false; + } auto newUsage = analysis.lookupResourceUsage(result); auto newLifetime = convertUsageToLifetime(newUsage); auto newType = rewriter.getType(newLifetime); @@ -335,8 +338,9 @@ struct ApplyFuncOp : public UsageRefinementPattern { } // Blocks and nested operations: - if (this->applyRegionTransitions(op, rewriter)) + if (this->applyRegionTransitions(op, rewriter)) { didChange = true; + } return success(didChange); } @@ 
-350,8 +354,9 @@ struct ApplyScfIfOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } @@ -367,8 +372,9 @@ struct ApplyScfForOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } return success(didChange); @@ -383,8 +389,9 @@ struct ApplyScfWhileOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange |= true; + } } } @@ -406,8 +413,9 @@ struct ApplyGenericOp : public UsageRefinementPattern { for (unsigned i = 0; i < op->getNumResults(); ++i) { auto result = op->getResult(i); if (isa(result.getType())) { - if (this->applyResultTransition(op, result, rewriter)) + if (this->applyResultTransition(op, result, rewriter)) { didChange = true; + } } } if (didChange) { @@ -535,8 +543,9 @@ struct RefineUsagePass : public IREE::Stream::impl::RefineUsagePassBase { void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) + if (moduleOp.getBody()->empty()) { return; + } // Run analysis on the entire module. 
ResourceUsageAnalysis analysis(moduleOp); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp index fb11759a5682..2e50bd325a2f 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp @@ -47,8 +47,9 @@ class ValueAliasingSet { SmallVector> getValueAliasSets() const { SmallVector> result; for (auto it = valueAliasing.begin(); it != valueAliasing.end(); ++it) { - if (!(*it)->isLeader()) + if (!(*it)->isLeader()) { continue; // Ignore non-leader sets. + } auto &aliasSet = result.emplace_back(); for (auto mi = valueAliasing.member_begin(**it); mi != valueAliasing.member_end(); ++mi) { @@ -110,8 +111,9 @@ static void computeRegionValueAliases(Operation *regionOp, // Tied results reuse their operand buffer. auto tiedOp = dyn_cast(op); for (auto result : op.getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (tiedOp) { auto tiedOperand = tiedOp.getTiedResultOperand(result); if (tiedOperand) { @@ -181,8 +183,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, SmallPtrSet liveOuts; auto yieldOp = cast(streamBlock->back()); for (auto returnValue : yieldOp.getResourceOperands()) { - if (!isa(returnValue.getType())) + if (!isa(returnValue.getType())) { continue; + } liveOuts.insert(returnValue); } @@ -191,8 +194,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, LivenessIntervalMap valueIntervals; int ordinal = 0; for (Value value : streamBlock->getArguments()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } LivenessInterval interval; interval.start = LIVE_IN; if (liveOuts.contains(value)) { @@ -218,16 +222,19 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, // the duration of the 
region. concurrentOp.walk([&](Operation *op) { for (auto value : op->getResults()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } if (auto tiedOp = dyn_cast(op)) { // Skip tied results as their liveness is determined by the tied // operand. - if (tiedOp.getTiedResultOperand(value)) + if (tiedOp.getTiedResultOperand(value)) { continue; + } } - if (!value.use_empty()) + if (!value.use_empty()) { continue; + } LivenessInterval interval; interval.start = start; interval.end = start; @@ -238,8 +245,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, }); } for (auto value : op.getResults()) { - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } LivenessInterval interval; interval.start = start; if (liveOuts.contains(value)) { @@ -267,8 +275,9 @@ computeExecutionRegionLivenessIntervals(IREE::Stream::AsyncExecuteOp executeOp, // We'd need to update this analysis to handle the nesting in order to // compute the ranges here but that's not (currently) required as all // allocated values roll up to the parent scope by way of the yields. - if (llvm::all_of(aliasSet, isNested)) + if (llvm::all_of(aliasSet, isNested)) { continue; + } assert((llvm::all_of(aliasSet, isNested) || llvm::none_of(aliasSet, isNested)) && @@ -371,8 +380,9 @@ struct AllocationScope { // Returns a memoized ConstantIndexOp of |value|. Value lookupOrCreateIndex(int64_t value) { auto it = indexConstantMap.find(value); - if (it != indexConstantMap.end()) + if (it != indexConstantMap.end()) { return it->second; + } auto constantValue = OpBuilder(rootOp).createOrFold( rootOp->getLoc(), value); indexConstantMap.insert(std::make_pair(value, constantValue)); @@ -382,10 +392,12 @@ struct AllocationScope { // Performs a memoized add (as many adds of offsets or lengths are redundant). Value add(Location loc, Value lhs, Value rhs) { // TODO(benvanik): memoize - if worth it. Needs profiling. 
- if (matchPattern(lhs, m_Zero())) + if (matchPattern(lhs, m_Zero())) { return rhs; - if (matchPattern(rhs, m_Zero())) + } + if (matchPattern(rhs, m_Zero())) { return lhs; + } auto result = OpBuilder(rootOp).createOrFold(loc, lhs, rhs); return result; } @@ -394,8 +406,9 @@ struct AllocationScope { // All aliases of |resource| will also be mapped. void mapResourceRange(Value resource, ResourceRange resourceRange, AsmState *asmState) { - if (resourceRangeMap.count(resource)) + if (resourceRangeMap.count(resource)) { return; + } if (!resourceRange.offset && !resourceRange.length) { resourceRange.offset = lookupOrCreateIndex(0); @@ -957,8 +970,9 @@ applyAsyncAllocations(IREE::Stream::AffinityAttr executionAffinityAttr, auto ops = llvm::map_to_vector(llvm::reverse(block), [&](Operation &op) { return &op; }); for (auto *op : ops) { - if (op->hasTrait()) + if (op->hasTrait()) { continue; + } if (failed(TypeSwitch(op) .Case([&](IREE::Stream::ResourceSubviewOp op) { return applyResourceSubviewOp(op, scope, OpBuilder(op)); @@ -1053,8 +1067,9 @@ allocateLocalTransients(IREE::Stream::AsyncExecuteOp executeOp, auto value = valueInterval.value; assert(value && "must have value for interval"); auto valueType = dyn_cast(value.getType()); - if (!valueType) + if (!valueType) { continue; + } // Only handle transient buffers (created/used/dropped within the stream). if (valueInterval.start == LIVE_IN || valueInterval.end == LIVE_OUT) { @@ -1268,8 +1283,9 @@ struct ConstantAllocation { // Returns true if |value| has one use and it is a stream.yield op. 
static bool isOnlyUseYield(Value value) { for (auto *user : value.getUsers()) { - if (!isa(user)) + if (!isa(user)) { return false; + } } return true; } @@ -1552,8 +1568,9 @@ gatherSubranges(Value derivedValue) { while (auto definingOp = dyn_cast_if_present( baseValue.getDefiningOp())) { auto tiedValue = definingOp.getTiedResultOperand(baseValue); - if (!tiedValue) + if (!tiedValue) { break; + } if (auto subrangeOp = dyn_cast( definingOp.getOperation())) { if (subrangeOp.getSubrangeResource() == tiedValue) { @@ -1580,8 +1597,9 @@ static ResourceRange deriveResourceRangeFromResult(Value resultValue, Value resultSize, OpBuilder &builder) { auto subranges = gatherSubranges(resultValue); - if (subranges.empty()) + if (subranges.empty()) { return ResourceRange(resultValue, resultSize); + } // TODO(benvanik): switch to affine.apply when fully supported. Value offset; @@ -1716,8 +1734,9 @@ allocateExecutionRegion(IREE::Stream::AsyncExecuteOp executeOp, // Replace results of escaping uploads with the upload values. 
for (auto &reservation : constantAllocation.reservations) { auto result = findTiedYieldResult(reservation.constantOp.getResult()); - if (!result) + if (!result) { continue; + } result.replaceAllUsesWith(reservation.resource); handledResults.insert(result); LLVM_DEBUG({ @@ -1954,8 +1973,9 @@ allocateExecutionRegion(IREE::Stream::AsyncExecuteOp executeOp, executeOp.getResultTimepoint().replaceAllUsesWith( newExecuteOp.getResultTimepoint()); for (auto replacement : resultReplacements) { - if (!replacement.second) + if (!replacement.second) { continue; // handled already + } LLVM_DEBUG({ AsmState asmState(newExecuteOp->getParentOp()); llvm::dbgs() << " == replacing region result "; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp index 504a42fef10a..5625923d5c76 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp @@ -54,10 +54,12 @@ struct WavePartitionBuilder { Operation *insertionPt = nullptr; for (auto in : partition->ins) { auto *definingOp = in.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; - if (definingOp->getBlock() != parentBlock) + } + if (definingOp->getBlock() != parentBlock) { continue; + } if (!insertionPt) { insertionPt = definingOp; // first defining op } else if (insertionPt->isBeforeInBlock(definingOp)) { @@ -83,8 +85,9 @@ struct WavePartitionBuilder { resultTypes.push_back(out.getType()); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, out, parentBuilder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } SmallVector operands; SmallVector operandTypes; @@ -93,14 +96,16 @@ struct WavePartitionBuilder { operandTypes.reserve(partition->ins.size()); operandSizes.reserve(partition->ins.size()); for (auto in : partition->ins) { - if 
(!isa(in.getType())) + if (!isa(in.getType())) { continue; + } operands.push_back(in); operandTypes.push_back(in.getType()); auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, in, parentBuilder); - if (operandSize) + if (operandSize) { operandSizes.push_back(operandSize); + } } // TODO(benvanik): tie operands, or leave to canonicalization. @@ -134,8 +139,9 @@ struct WavePartitionBuilder { // // Returns true if the operation was cloned into the partition. bool visit(Operation *op) { - if (!partition->ops.contains(op)) + if (!partition->ops.contains(op)) { return false; + } // Clone the op into the partition and remap it. auto *clonedOp = builder.clone(*op, mapping); @@ -159,8 +165,9 @@ struct WavePartitionBuilder { results.push_back(newResult); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( concurrentOp.getLoc(), newResult, builder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } IREE::Stream::YieldOp::create(builder, concurrentOp.getLoc(), results, resultSizes); @@ -188,8 +195,9 @@ struct ScheduleConcurrencyPass } for (auto executeOp : parentOp.getCallableRegion()->getOps()) { - if (failed(runOnRegion(executeOp))) + if (failed(runOnRegion(executeOp))) { return signalPassFailure(); + } } } @@ -205,10 +213,12 @@ struct ScheduleConcurrencyPass // Compute a set of partitions covering all of the streamable ops in the // execution region. auto waveSet = partitionRegionConcurrency(configAttr, block); - if (waveSet.empty()) + if (waveSet.empty()) { return success(); - if (failed(waveSet.verify(parentOp.getLoc()))) + } + if (failed(waveSet.verify(parentOp.getLoc()))) { return failure(); + } // Create partition builders for each partition. 
// We'll clone ops into each and insert them into the block at the @@ -217,8 +227,9 @@ struct ScheduleConcurrencyPass SmallVector partitionBuilders; partitionBuilders.reserve(waveSet.size()); for (auto partition : llvm::enumerate(waveSet.partitions)) { - if (partition.value().ops.size() == 1) + if (partition.value().ops.size() == 1) { continue; + } partitionBuilders.push_back(WavePartitionBuilder(block, partition.index(), &partition.value(), mapping, &getContext())); @@ -231,8 +242,9 @@ struct ScheduleConcurrencyPass // creates a lot of new IR (up to O(op*partitions)). SetVector deadOps; for (auto &op : *block) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } bool handled = false; for (auto &partitionBuilder : partitionBuilders) { handled = partitionBuilder.visit(&op) || handled; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp index 76976a163177..457e30a03edb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp @@ -55,8 +55,9 @@ struct ExecutePartitionBuilder { // This is at the last op in the partition. 
Operation *insertionPt = nullptr; for (auto *op : partition->ops) { - if (op->getBlock() != parentBlock) + if (op->getBlock() != parentBlock) { continue; + } if (!insertionPt) { insertionPt = op; // first defining op } else if (insertionPt->isBeforeInBlock(op)) { @@ -82,8 +83,9 @@ struct ExecutePartitionBuilder { resultTypes.push_back(out.getType()); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, out, parentBuilder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } SmallVector operands; SmallVector operandTypes; @@ -92,14 +94,16 @@ struct ExecutePartitionBuilder { operandTypes.reserve(partition->ins.size()); operandSizes.reserve(partition->ins.size()); for (auto in : partition->ins) { - if (!isa(in.getType())) + if (!isa(in.getType())) { continue; + } operands.push_back(in); operandTypes.push_back(in.getType()); auto operandSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( fusedLoc, in, parentBuilder); - if (operandSize) + if (operandSize) { operandSizes.push_back(operandSize); + } } // Collect await timepoints from all ops being partitioned and join them. @@ -148,8 +152,9 @@ struct ExecutePartitionBuilder { // // Returns true if the operation was cloned into the partition. bool visit(Operation *op) { - if (!partition->ops.contains(op)) + if (!partition->ops.contains(op)) { return false; + } // Clone the op into the partition and remap it. 
auto *clonedOp = builder.clone(*op, mapping); @@ -197,8 +202,9 @@ struct ExecutePartitionBuilder { results.push_back(newResult); auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( executeOp.getLoc(), newResult, builder); - if (resultSize) + if (resultSize) { resultSizes.push_back(resultSize); + } } IREE::Stream::YieldOp::create(builder, executeOp.getLoc(), results, resultSizes); @@ -228,8 +234,9 @@ static SmallVector sortBlocksInDominanceOrder(Region ®ion) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -322,8 +329,9 @@ LogicalResult processRegion(Location loc, MLIRContext *context, Region ®ion, // creates a lot of new IR (up to O(op*partitions)). SetVector deadOps; for (auto &op : *block) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } for (auto &partitionBuilder : partitionBuilders) { partitionBuilder.visit(&op); } @@ -436,8 +444,9 @@ LogicalResult processRegion(Location loc, MLIRContext *context, Region ®ion, } for (auto &subregion : op.getRegions()) { - if (failed(processRegion(loc, context, subregion, configAttr))) + if (failed(processRegion(loc, context, subregion, configAttr))) { return failure(); + } } } } @@ -479,8 +488,9 @@ struct ScheduleExecutionPass // order so that we are sure if we replace values that dominate other blocks // they see the correct values. auto ®ion = *parentOp.getCallableRegion(); - if (failed(processRegion(parentOp.getLoc(), context, region, configAttr))) + if (failed(processRegion(parentOp.getLoc(), context, region, configAttr))) { return signalPassFailure(); + } // Cleanup the dead ops. // TODO(benvanik): less work here - maybe no patterns to just force folding? 
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp index 38ce8d49b77a..50391ac33946 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp @@ -64,8 +64,9 @@ buildConstantTable(mlir::FunctionOpInterface funcOp, llvm::BitVector constantOperandMap(operandCount, /*t=*/true); for (auto dispatchOp : dispatchOps) { for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!constantOperandMap.test(idx)) + if (!constantOperandMap.test(idx)) { continue; + } auto value = dispatchOp.getUniformOperands()[idx]; Attribute constantValue; if (!matchPattern(value, m_Constant(&constantValue))) { @@ -86,8 +87,9 @@ buildConstantTable(mlir::FunctionOpInterface funcOp, DenseMap typeSets; SmallVector typeOrder; for (unsigned idx = 0; idx < operandCount; ++idx) { - if (!constantOperandMap.test(idx)) + if (!constantOperandMap.test(idx)) { continue; + } auto operandType = anyDispatchOp.getUniformOperands()[idx].getType(); auto &set = typeSets[operandType]; if (!set.type) { @@ -286,15 +288,17 @@ specializeDispatches(IREE::Stream::ExecutableOp executableOp, IREE::Stream::ExecutableExportOp exportOp, SmallVector &dispatchOps, MemoizedCmdConstants &memoizedConstants) { - if (dispatchOps.empty()) + if (dispatchOps.empty()) { return; // no-op if no dispatches + } auto funcOp = exportOp.lookupFunctionRef(); // Build a constant table for unique per-dispatch constant values. 
auto constantTable = buildConstantTable(funcOp, dispatchOps); - if (constantTable.coveredOperands.none()) + if (constantTable.coveredOperands.none()) { return; + } LLVM_DEBUG({ AsmState asmState(executableOp->getParentOp()); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp index 0c92aa38cd02..8d66b30189db 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/UnifyEncodingForGlobals.cpp @@ -813,10 +813,12 @@ struct UnifyEncodingForGlobalsPass [](TensorDispatchOp a, TensorDispatchOp b) { std::string aStr, bStr; llvm::raw_string_ostream aStream(aStr), bStream(bStr); - if (auto aAffinity = a.getAffinityAttr()) + if (auto aAffinity = a.getAffinityAttr()) { aStream << aAffinity; - if (auto bAffinity = b.getAffinityAttr()) + } + if (auto bAffinity = b.getAffinityAttr()) { bStream << bAffinity; + } return aStr < bStr; }); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h index 5d5b972602f7..4cbc1405be19 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Utils.h @@ -25,8 +25,9 @@ SmallVector gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { SmallPtrSet resultSet; for (auto dialect : moduleOp.getContext()->getLoadedDialects()) { auto *dialectInterface = dialect->getRegisteredInterface(); - if (!dialectInterface) + if (!dialectInterface) { continue; + } resultSet.insert(dialectInterface); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp index 5cd52e393a04..dc2be3195fbd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp @@ -60,8 +60,9 @@ struct VerifyAffinitiesPass ? WalkResult::skip() : WalkResult::advance(); }) - .wasInterrupted()) + .wasInterrupted()) { return signalPassFailure(); + } // Preserve all analyses since this is a read-only verification pass. markAllAnalysesPreserved(); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp index a19861592066..3f19b271217c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAsyncAccessRanges.cpp @@ -23,11 +23,13 @@ namespace mlir::iree_compiler::IREE::Stream { namespace { static std::optional matchConstant(Value value) { - if (!value) + if (!value) { return std::nullopt; + } APInt constant; - if (!matchPattern(value, m_ConstantInt(&constant))) + if (!matchPattern(value, m_ConstantInt(&constant))) { return std::nullopt; + } return constant.getSExtValue(); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp index a6deb12024eb..e1defae5de3d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyLowerings.cpp @@ -122,15 +122,17 @@ class Verifier { // Check types for operands/results. 
for (auto operandType : llvm::enumerate(op->getOperandTypes())) { - if (isTypeLegal(operandType.value())) + if (isTypeLegal(operandType.value())) { continue; + } emitIllegalTypeError(op, "operand", operandType.index(), operandType.value()); foundAnyIllegal = true; } for (auto resultType : llvm::enumerate(op->getResultTypes())) { - if (isTypeLegal(resultType.value())) + if (isTypeLegal(resultType.value())) { continue; + } emitIllegalTypeError(op, "result", resultType.index(), resultType.value()); foundAnyIllegal = true; @@ -358,8 +360,9 @@ struct VerifyLoweringToAsyncPass } // Allow metadata ops outside of execution regions. - if (op.isMetadata()) + if (op.isMetadata()) { return Verifier::Legality::LEGAL; + } // TODO(benvanik): execution region interface to make this generic. if (!op->template getParentOfType()) { From 3884ee61c8d94ca5b60bbbab21071eb938f6e114 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:01:34 -0500 Subject: [PATCH 49/71] Add braces in Util, VM, and supporting dialects. NFC. 
5/n (#23147) --- .../Dialect/Encoding/IR/EncodingAttrs.cpp | 6 +- .../Encoding/Utils/ElementPackingUtils.cpp | 12 +- .../IR/AggregatedOpInterfaceImpl.cpp | 6 +- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 30 ++- .../LinalgExt/IR/TilingInterfaceImpl.cpp | 21 +- .../Transforms/ConvertConvToIm2ColOp.cpp | 9 +- .../LinalgExt/Transforms/ReshapeFusion.cpp | 51 ++-- .../LinalgExt/Transforms/RewriteFft.cpp | 3 +- .../LinalgExt/Transforms/SplitReduction.cpp | 6 +- .../Dialect/LinalgExt/Utils/MatchUtils.cpp | 3 +- .../Dialect/LinalgExt/Utils/Utils.cpp | 41 ++- .../TensorExt/IR/TensorExtOpFolders.cpp | 21 +- .../Dialect/TensorExt/IR/TensorExtTypes.cpp | 9 +- .../Dialect/TensorExt/Transforms/Folders.cpp | 6 +- .../Util/Analysis/Attributes/Range.cpp | 6 +- .../Util/Analysis/Constant/ConstExpr.cpp | 54 ++-- .../Util/Analysis/Constant/ConstExpr.h | 6 +- .../Util/Analysis/Constant/OpOracle.cpp | 6 +- .../Dialect/Util/Analysis/DFX/DepGraph.cpp | 3 +- .../Dialect/Util/Analysis/DFX/Element.cpp | 3 +- .../Dialect/Util/Analysis/DFX/Element.h | 3 +- .../Dialect/Util/Analysis/DFX/Solver.cpp | 18 +- .../Dialect/Util/Analysis/DFX/Solver.h | 3 +- .../Dialect/Util/Analysis/DFX/State.cpp | 6 +- .../Dialect/Util/Analysis/DFX/State.h | 30 ++- .../Dialect/Util/Analysis/Explorer.cpp | 87 ++++--- .../Analysis/IntegerDivisibilityAnalysis.cpp | 3 +- .../Util/Conversion/FuncToUtil/Patterns.cpp | 3 +- .../Dialect/Util/IR/ClosureOpUtils.cpp | 9 +- .../compiler/Dialect/Util/IR/UtilAttrs.cpp | 18 +- .../compiler/Dialect/Util/IR/UtilDialect.cpp | 12 +- .../Dialect/Util/IR/UtilOpFolders.cpp | 51 ++-- .../iree/compiler/Dialect/Util/IR/UtilOps.cpp | 246 ++++++++++++------ .../compiler/Dialect/Util/IR/UtilTypes.cpp | 63 +++-- .../iree/compiler/Dialect/Util/IR/UtilTypes.h | 6 +- .../Util/TransformOps/UtilTransformOps.cpp | 26 +- .../Util/Transforms/DropCompilerHints.cpp | 3 +- .../Util/Transforms/FixedPointIterator.cpp | 12 +- .../Dialect/Util/Transforms/FuseGlobals.cpp | 6 +- 
.../Util/Transforms/HoistIntoGlobals.cpp | 24 +- .../compiler/Dialect/Util/Transforms/IPO.cpp | 18 +- .../Util/Transforms/ImportResources.cpp | 3 +- .../Util/Transforms/OptimizeIntArithmetic.cpp | 77 ++++-- .../Dialect/Util/Transforms/Patterns.cpp | 48 ++-- .../Util/Transforms/PropagateSubranges.cpp | 52 ++-- .../Transforms/SimplifyGlobalAccesses.cpp | 21 +- .../VM/Analysis/LinearScan/LiveIntervals.cpp | 18 +- .../Dialect/VM/Analysis/OrdinalAnalysis.cpp | 3 +- .../VM/Analysis/RegisterAllocation.cpp | 60 +++-- .../Dialect/VM/Analysis/RegisterAllocation.h | 9 +- .../Dialect/VM/Analysis/ValueLiveness.cpp | 18 +- .../VM/Conversion/ArithToVM/Patterns.cpp | 18 +- .../Dialect/VM/Conversion/ImportUtils.cpp | 24 +- .../Dialect/VM/Conversion/ImportUtils.h | 9 +- .../VM/Conversion/MathToVM/Patterns.cpp | 6 +- .../VM/Conversion/StandardToVM/Patterns.cpp | 15 +- .../Conversion/UtilToVM/ConvertBufferOps.cpp | 9 +- .../VM/Conversion/UtilToVM/ConvertListOps.cpp | 9 +- .../UtilToVM/ConvertStructuralOps.cpp | 15 +- .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 21 +- .../VM/Conversion/VMToEmitC/EmitCBuilders.cpp | 3 +- .../iree/compiler/Dialect/VM/IR/VMDialect.cpp | 6 +- .../compiler/Dialect/VM/IR/VMOpFolders.cpp | 157 +++++++---- .../src/iree/compiler/Dialect/VM/IR/VMOps.cpp | 49 ++-- .../iree/compiler/Dialect/VM/IR/VMTypes.cpp | 3 +- .../VM/Target/Bytecode/ArchiveWriter.cpp | 3 +- .../VM/Target/Bytecode/BytecodeEncoder.cpp | 12 +- .../Target/Bytecode/BytecodeModuleTarget.cpp | 18 +- .../Target/Bytecode/DebugDatabaseBuilder.cpp | 9 +- .../Dialect/VM/Tools/VMOpEncoderGen.cpp | 6 +- .../VM/Transforms/AnnotateFunctions.cpp | 24 +- .../Dialect/VM/Transforms/Conversion.cpp | 9 +- .../VM/Transforms/ConvertToYieldableCalls.cpp | 6 +- .../VM/Transforms/DeduplicateRodata.cpp | 3 +- .../DropEmptyModuleInitializers.cpp | 6 +- .../VM/Transforms/GlobalInitialization.cpp | 3 +- .../VM/Transforms/MaterializeRefDiscards.cpp | 42 ++- .../VM/Transforms/OrdinalAllocation.cpp | 3 +- 
.../VM/Transforms/ResolveRodataLoads.cpp | 12 +- .../compiler/Dialect/VM/Utils/TypeTable.cpp | 9 +- .../Conversion/VMVXToVM/ConvertVMVXToVM.cpp | 3 +- .../VMVX/Transforms/MaterializeConstants.cpp | 3 +- .../Transforms/ResolveBufferDescriptors.cpp | 15 +- 83 files changed, 1189 insertions(+), 599 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 895fbf40bcb3..3cdde7c627c9 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -320,8 +320,9 @@ SmallVector EncodingAttr::getRootMaps() const { return cast(m).getAffineMap(); } if (auto mapsAttr = dyn_cast(m)) { - if (mapsAttr.empty()) + if (mapsAttr.empty()) { return AffineMap(); + } return cast(mapsAttr[0]).getAffineMap(); } return AffineMap(); @@ -339,8 +340,9 @@ AffineMap EncodingAttr::getLastMapForOperandIndex() const { return mapAttr.getAffineMap(); } if (auto mapsAttr = dyn_cast(indexingMap)) { - if (mapsAttr.empty()) + if (mapsAttr.empty()) { return AffineMap(); + } return cast(mapsAttr[mapsAttr.size() - 1]).getAffineMap(); } return AffineMap(); diff --git a/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp index 10d75a6ae56b..ea00b86f7878 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/Utils/ElementPackingUtils.cpp @@ -63,19 +63,22 @@ static Type legalizeStorageElementTypeImpl(Type elementType, bool isPackedStorage) { // Only handle integers; floats in MLIR all have aligned widths (today). auto intType = dyn_cast(elementType); - if (!intType) + if (!intType) { return elementType; + } // For sub-byte elements, default to pack them into bytes. 
unsigned bitWidth = intType.getWidth(); - if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage)) + if (needToPackSubByteElementBitWidthImpl(bitWidth, isPackedStorage)) { return elementType; + } // Otherwise, extend them to the next power-of-two bit width. unsigned alignedBitWidth = IREE::Util::getRoundedElementByteWidth(intType) * 8; - if (alignedBitWidth == bitWidth) + if (alignedBitWidth == bitWidth) { return elementType; + } return IntegerType::get(elementType.getContext(), alignedBitWidth, intType.getSignedness()); } @@ -115,8 +118,9 @@ Value calculateStorageElementCountInBytes(Location loc, } for (unsigned i = 0; i < shapedType.getRank(); ++i) { - if (!shapedType.isDynamicDim(i)) + if (!shapedType.isDynamicDim(i)) { staticCount *= shapedType.getDimSize(i); + } } // Scale by dynamic dims, if present. auto value = diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index e8c3d90362c0..3cb319dbe17a 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -827,8 +827,9 @@ FailureOr> Im2colOp::decomposeOperation(OpBuilder &b) { } SetVector batchPosSet(getBatchPos().begin(), getBatchPos().end()); for (auto [idx, size] : enumerate(inputSizes)) { - if (batchPosSet.contains(idx)) + if (batchPosSet.contains(idx)) { continue; + } if (mPosSet.contains(idx)) { kBasis.push_back(kernelSize[mKernelIdx[idx]]); continue; @@ -861,8 +862,9 @@ FailureOr> Im2colOp::decomposeOperation(OpBuilder &b) { int delinKIdx = 0; SmallVector invInputKPerm = invertPermutationVector(inputKPerm); for (int i = 0; i < getInputRank(); ++i) { - if (batchPosSet.contains(i)) + if (batchPosSet.contains(i)) { continue; + } if (mPosSet.contains(i)) { windowOffset.push_back(delinKOffset[invInputKPerm[delinKIdx++]]); continue; diff --git 
a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 6d2badfdb75c..197435f69779 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1048,8 +1048,9 @@ LogicalResult FftOp::verify() { // After tiling, it could be dynamic shape. (Because // subview/subtensor does not inference the type correctly // on (1 << x)) cases). - if (ShapedType::isDynamic(length)) + if (ShapedType::isDynamic(length)) { return success(); + } if (length & (length - 1)) { return op->emitOpError("only powers of 2 are handled currently"); } @@ -1287,8 +1288,9 @@ LogicalResult ArgCompareOp::verify() { SmallVector expectedShape; for (int64_t i = 0; i < rank; ++i) { - if (i != dim) + if (i != dim) { expectedShape.push_back(inputType.getDimSize(i)); + } } if (!llvm::equal(expectedShape, outputValueType.getShape())) { return op->emitOpError("output shape must match input shape with reduction " @@ -1392,15 +1394,18 @@ areNotFullTiles(ArrayRef inputShape, DenseMap const &dimAndTileMapping) { int64_t rank = inputShape.size(); for (int64_t dim = 0; dim < rank; dim++) { - if (ShapedType::isDynamic(inputShape[dim])) + if (ShapedType::isDynamic(inputShape[dim])) { continue; + } auto it = dimAndTileMapping.find(dim); if (it != dimAndTileMapping.end()) { std::optional constantTile = getConstantIntValue(it->second); - if (!constantTile) + if (!constantTile) { continue; - if (inputShape[dim] % (*constantTile) != 0) + } + if (inputShape[dim] % (*constantTile) != 0) { return true; + } } } return false; @@ -2127,8 +2132,9 @@ LogicalResult AttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) + if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) { return failure(); + } } int expectedSymbols = 
getQueryMap().getNumInputs(); @@ -2153,14 +2159,16 @@ LogicalResult AttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkDomain("Mask", *maskMap))) + if (failed(checkDomain("Mask", *maskMap))) { return failure(); + } } auto &block = getRegion().front(); auto blockTys = block.getArgumentTypes(); - if (!isa(blockTys[0])) + if (!isa(blockTys[0])) { return attnOp->emitOpError("block argument 0 should be float"); + } auto yieldOp = dyn_cast(block.getTerminator()); if (!yieldOp) { @@ -2304,8 +2312,9 @@ LogicalResult OnlineAttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) + if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap))) { return failure(); + } } int expectedSymbols = getQueryMap().getNumInputs(); @@ -2332,8 +2341,9 @@ LogicalResult OnlineAttentionOp::verify() { // Additional check case if mask exists if (auto maskMap = getMaskMap()) { - if (failed(checkDomain("Mask", *maskMap))) + if (failed(checkDomain("Mask", *maskMap))) { return failure(); + } } Block &block = attnOp.getRegion().front(); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp index 306eff61f43d..acd09029fc9b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp @@ -291,8 +291,9 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, auto dim = dimMap[i]; - if (starts[dim]) + if (starts[dim]) { ret = arith::AddIOp::create(b, loc, ret, starts[dim]); + } starts[dim] = ret; } @@ -441,8 +442,9 @@ LogicalResult GatherOp::generateScalarImplementation(OpBuilder &b, Location loc, Value idx = memref::LoadOp::create(b, loc, getIndices(), loadIndices); Value ret = 
arith::IndexCastOp::create(b, loc, b.getIndexType(), idx); auto dim = dimMap[i]; - if (starts[dim]) + if (starts[dim]) { ret = arith::AddIOp::create(b, loc, ret, starts[dim]); + } starts[dim] = ret; } @@ -1192,11 +1194,13 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, scanBlkArgs.push_back( memref::LoadOp::create(b, loc, getOutput(), indices)); Value i0; - if (!isInclusive) + if (!isInclusive) { i0 = memref::LoadOp::create(b, loc, getInput(), indices); + } indices[scanDim] = iv; - if (isInclusive) + if (isInclusive) { i0 = memref::LoadOp::create(b, loc, getInput(), indices); + } scanBlkArgs.push_back(i0); }); @@ -1292,8 +1296,9 @@ LogicalResult ScanOp::getResultTilePosition( int64_t rank = getOperandRank(); if (rank > 1) { for (auto i : llvm::seq(0, rank)) { - if (i == getDimension()) + if (i == getDimension()) { continue; + } resultOffsets.push_back(offsets[i]); resultSizes.push_back(sizes[i]); } @@ -1617,8 +1622,9 @@ LogicalResult ArgCompareOp::generateScalarImplementation(OpBuilder &b, uint64_t reductionDim = getDimension(); SmallVector parallelIndices; for (size_t i = 0, rank = ivs.size(); i < rank; ++i) { - if (i == reductionDim) + if (i == reductionDim) { continue; + } parallelIndices.push_back(ivs[i]); } @@ -3572,8 +3578,9 @@ static void offsetCustomOpIndices(OpBuilder &b, CustomOp customOp, ArrayRef offsets) { IRRewriter rewriter(b); for (auto indexOp : customOp.getBody()->getOps()) { - if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) + if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) { continue; + } OpBuilder::InsertionGuard guard(b); rewriter.setInsertionPointAfter(indexOp); AffineExpr index, offset; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp index ecc7dd2035e4..2b6076aa0793 100644 --- 
a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp @@ -27,15 +27,17 @@ static bool hasAllOneValues(ArrayRef attr) { static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { bool isInt = isa(x.getType()); - if (isInt) + if (isInt) { return arith::AddIOp::create(builder, loc, x, y); + } return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { bool isInt = isa(x.getType()); - if (isInt) + if (isInt) { return arith::MulIOp::create(builder, loc, x, y); + } return arith::MulFOp::create(builder, loc, x, y); } @@ -153,9 +155,10 @@ class ConvertConvGeneric final auto igemmConvDetailsOrFailure = LinalgExt::getIGEMMGenericConvDetails(linalgOp); - if (failed(igemmConvDetailsOrFailure)) + if (failed(igemmConvDetailsOrFailure)) { return rewriter.notifyMatchFailure(linalgOp, "Failed to extract IGEMM details"); + } LinalgExt::IGEMMGenericConvDetails igemmConvDetails = *igemmConvDetailsOrFailure; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index 252e6204d790..d91a41213e20 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -71,8 +71,9 @@ static SmallVector getDimSizes(Value v) { static bool isIdentityReassoc(const SmallVector &indices) { for (auto &index : indices) { - if (index.size() != 1) + if (index.size() != 1) { return false; + } } return true; }; @@ -240,8 +241,9 @@ LogicalResult ExpansionInfo::compute( SmallVector infos, SmallVector loopRanges, OpOperand *fusableOpOperand, ArrayRef operandReassoc, ArrayRef expandedShape) { - if (operandReassoc.empty()) + if (operandReassoc.empty()) { return failure(); + } // Check that the operand 
dim size matches the iteration space dim size. This // can fail when one is static and the other is dynamic. @@ -307,28 +309,33 @@ CollapsingInfo::initialize(unsigned origNumLoops, llvm::SmallDenseSet processedDims; // Find all the dims that are folded. for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) { - if (foldedIterationDim.empty()) + if (foldedIterationDim.empty()) { continue; + } // If the folded dims contain dims already folded, that's illegal // specification. Repetition within a list is also illegal. for (auto dim : foldedIterationDim) { - if (dim >= origNumLoops) + if (dim >= origNumLoops) { return failure(); - if (processedDims.count(dim)) + } + if (processedDims.count(dim)) { return failure(); + } processedDims.insert(dim); } collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(), foldedIterationDim.end()); } - if (processedDims.size() > origNumLoops) + if (processedDims.size() > origNumLoops) { return failure(); + } // Add all the preserved dims of the original op as single // elements to `collapsedOpToOrigOpIterationDim`. 
for (auto dim : llvm::seq(0, origNumLoops)) { - if (processedDims.count(dim)) + if (processedDims.count(dim)) { continue; + } collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim}); } @@ -339,9 +346,10 @@ CollapsingInfo::initialize(unsigned origNumLoops, origOpToCollapsedOpIterationDim.resize(origNumLoops); for (const auto &foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) { - for (const auto &dim : enumerate(foldedDims.value())) + for (const auto &dim : enumerate(foldedDims.value())) { origOpToCollapsedOpIterationDim[dim.value()] = std::make_pair(foldedDims.index(), dim.index()); + } } return success(); } @@ -387,9 +395,10 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) { indicesInfo.originalShape = getDimSizes(scatterOp.getIndices()); llvm::append_range(indicesInfo.operandToIterationSpace, llvm::seq(0, scatterOp.getBatchRank())); - if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank()) + if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank()) { indicesInfo.operandToIterationSpace.push_back( ReshapeOperandInfo::kNoMapping); + } infos.push_back(std::move(indicesInfo)); ReshapeOperandInfo originalInfo; @@ -420,9 +429,10 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) { indicesInfo.originalShape = getDimSizes(gatherOp.getIndices()); llvm::append_range(indicesInfo.operandToIterationSpace, llvm::seq(0, gatherOp.getBatchRank())); - if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank()) + if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank()) { indicesInfo.operandToIterationSpace.push_back( ReshapeOperandInfo::kNoMapping); + } infos.push_back(std::move(indicesInfo)); ReshapeOperandInfo outputInfo; @@ -846,10 +856,12 @@ struct FoldWithProducerReshapeByExpansion final for (OpOperand &opOperand : op->getOpOperands()) { tensor::CollapseShapeOp reshapeOp = opOperand.get().getDefiningOp(); - if (!reshapeOp) + if (!reshapeOp) { continue; - if (!controlFoldingReshapes(&opOperand)) + } + if 
(!controlFoldingReshapes(&opOperand)) { continue; + } std::optional replacementValue = fuseWithReshapeByExpansion(op, reshapeOp, &opOperand, rewriter); @@ -893,8 +905,9 @@ struct FoldWithConsumerReshapeByExpansion final std::optional replacementValue = fuseWithReshapeByExpansion( op, expandOp, op.getTiedOpOperand(producerResult), rewriter); - if (!replacementValue) + if (!replacementValue) { return failure(); + } rewriter.replaceOp(op, *replacementValue); return success(); } @@ -946,8 +959,9 @@ static Value getCollapsedOpOperand(Location loc, AttentionOp op, // the number of results of the indexing map, then nothing to do for this // operand. Value operand = opOperand->get(); - if (operandReassociation.size() == indexingMap.getNumResults()) + if (operandReassociation.size() == indexingMap.getNumResults()) { return operand; + } // Insert a reshape to collapse the dimensions. if (isa(operand.getType())) { @@ -982,8 +996,9 @@ static void collapseOperandsAndResults(AttentionOp op, outputOperands.push_back(newOutput); // If the op has "buffer semantics", then the init operands are ranked // memrefs and the op has no results. - if (!op.hasPureBufferSemantics()) + if (!op.hasPureBufferSemantics()) { resultTypes.push_back(newOutput.getType()); + } } } @@ -1001,8 +1016,9 @@ getCollapsedOpIndexingMap(AffineMap indexingMap, for (auto expr : indexingMap.getResults()) { unsigned dim = cast(expr).getPosition(); // If the dim is not the first of the collapsed dim, do nothing. - if (origOpToCollapsedOpMapping[dim].second != 0) + if (origOpToCollapsedOpMapping[dim].second != 0) { continue; + } // The next n-dims are guaranteed to be collapsed. So just use the // iteration dimension of the collapsed op. 
resultExprs.push_back( @@ -1067,8 +1083,9 @@ collapseOpIterationDims(AttentionOp op, if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { return foldedDims.size() <= 1; - })) + })) { return failure(); + } CollapsingInfo collapsingInfo; if (failed( diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp index 087ec5a7e729..e34b54c33cc3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp @@ -25,8 +25,9 @@ FailureOr> rewriteFft(Operation *op, Value operand, } // Skip else getBitReversalOrder produces invalid dense elements attr. - if (!operandType.getElementType().isF32()) + if (!operandType.getElementType().isF32()) { return rewriter.notifyMatchFailure(op, "expected F32 types"); + } ImplicitLocOpBuilder b(loc, rewriter); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp index f03682a23394..20baa05c7a3c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp @@ -658,9 +658,10 @@ splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp, Value outVal = args[1]; Value outIdx = args[2]; Value reductionIdx = linalg::IndexOp::create(b, loc, reductionDim + 1); - if (outIdx.getType() != reductionIdx.getType()) + if (outIdx.getType() != reductionIdx.getType()) { reductionIdx = arith::IndexCastOp::create(b, loc, outIdx.getType(), reductionIdx); + } Value inCast = in; Type inType = in.getType(); Type outType = outVal.getType(); @@ -715,8 +716,9 @@ splitArgmaxReduction(RewriterBase &rewriter, linalg::GenericOp genericOp, Value outIdx = inputs[3]; Value 
outer = linalg::IndexOp::create(b, loc, insertSplitDimension); Value offset = arith::MulIOp::create(b, loc, outer, tileSize); - if (offset.getType() != local.getType()) + if (offset.getType() != local.getType()) { offset = arith::IndexCastOp::create(b, loc, local.getType(), offset); + } // gidx = outer * ratio + local. Value gidx = arith::AddIOp::create(b, loc, offset, local); Operation *clonedMax = b.clone(*combinerOps.maxOp); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp index 468ccc8f8bf1..2481289a6904 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.cpp @@ -97,8 +97,9 @@ findPermutationsIndexingOperand(AffineMap indexingMap, if (iterators[d.getPosition()] == iter && llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) { return e.isFunctionOfDim(d.getPosition()); - }) == 1) + }) == 1) { res.insert(d.getPosition()); + } } } return res; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp index 464bf7b4f80d..39a0ed0cf37e 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp @@ -461,8 +461,9 @@ FailureOr getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { auto convDimsOrFailure = linalg::inferConvolutionDims(linalgOp); MLIRContext *ctx = linalgOp->getContext(); - if (failed(convDimsOrFailure)) + if (failed(convDimsOrFailure)) { return failure(); + } const mlir::linalg::ConvolutionDimensions &convDims = *convDimsOrFailure; LLVM_DEBUG({ llvm::dbgs() << "conv: " << linalgOp; @@ -524,8 +525,9 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { LDBG() << "output image or output channel dim not found in output."; return failure(); } - if (outputChannelLastDim.value() < 
outputImageFirstDim.value()) + if (outputChannelLastDim.value() < outputImageFirstDim.value()) { isOutputChannelFirst = true; + } SmallVector filterkPos; for (auto reductionDim : reductionDims) { @@ -620,8 +622,9 @@ getIGEMMGenericConvDetails(linalg::LinalgOp linalgOp) { // Lambda to remap conv dim indices to igemm dimensions. auto remapDims = [&](ArrayRef dims) -> SmallVector { SmallVector mapped; - for (unsigned d : dims) + for (unsigned d : dims) { mapped.push_back(convToIgemmDimMap.at(d)); + } return mapped; }; @@ -721,8 +724,9 @@ static Value getSourceSkipUnary(Value value) { Operation *op = value.getDefiningOp(); while (op && op->getNumOperands() == 1) { auto iface = dyn_cast(op); - if (!iface || !iface.hasNoEffect()) + if (!iface || !iface.hasNoEffect()) { break; + } value = op->getOperand(0); op = value.getDefiningOp(); } @@ -782,13 +786,15 @@ template static bool isPairTemplateImpl(Operation *add, Operation *mul) { static_assert(sizeof...(Args) % 2 == 0, "expected an even number of template arguments"); - if (isa(add) && isa(mul)) + if (isa(add) && isa(mul)) { return true; + } - if constexpr (sizeof...(Args) > 0) + if constexpr (sizeof...(Args) > 0) { return isPairTemplateImpl(add, mul); - else + } else { return false; + } } /// Returns true if the block is a body of a contraction with the kinds of @@ -918,19 +924,22 @@ bool isArgmaxOp(linalg::GenericOp genericOp) { // TODO: Add better affine map checks. auto indexing_maps = genericOp.getIndexingMapsArray(); - if (!indexing_maps[0].isIdentity()) + if (!indexing_maps[0].isIdentity()) { return false; + } // Check that initial value is negative Infinite. // TODO: Move this check to ukernel once we implement // variant to handle non neg-Inf initial value. 
Value initVal = genericOp.getDpsInitOperand(0)->get(); auto fillOp = initVal.getDefiningOp(); - if (!fillOp) + if (!fillOp) { return false; + } Value fillVal = fillOp.getDpsInputOperand(0)->get(); - if (!matchPattern(fillVal, m_NegInfFloat())) + if (!matchPattern(fillVal, m_NegInfFloat())) { return false; + } // Work back from linalg.yield and check body of genericOp. // The genericOp should yield the result of an arith.select, @@ -965,13 +974,15 @@ bool isArgmaxOp(linalg::GenericOp genericOp) { } auto selectOp = cast(producerOutput.getDefiningOp()); Value trueVal = selectOp.getTrueValue(); - if (auto castOp = trueVal.getDefiningOp()) + if (auto castOp = trueVal.getDefiningOp()) { trueVal = castOp.getIn(); + } // Ensure the true value is directly produced by linalg.index. auto indexOp = trueVal.getDefiningOp(); - if (!indexOp) + if (!indexOp) { return false; + } } // Producer of arith.select op is arith.cmpf @@ -1034,11 +1045,13 @@ bool isPureBatchMatmul(Operation *op) { // it requires a single input where the indexing maps are full permutations and // non-equal. 
bool isaTransposeOpInterface(linalg::LinalgOp linalgOp) { - if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) + if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { return false; + } - if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) + if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) { return false; + } auto mapRange = linalgOp.getIndexingMapsArray(); if (mapRange.size() != 2 || !mapRange.front().isPermutation() || !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) { diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp index 93082e5a23ef..04faecbfeb40 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtOpFolders.cpp @@ -22,8 +22,9 @@ struct ReplaceBitCastIfTensorOperandEmpty final : OpRewritePattern { PatternRewriter &rewriter) const override { auto emptyOp = dyn_cast_if_present(op.getSource().getDefiningOp()); - if (!emptyOp) + if (!emptyOp) { return failure(); + } rewriter.replaceOpWithNewOp(op, op.getResult().getType(), op.getResultDims()); return success(); @@ -36,8 +37,9 @@ struct BitCastOfTensorCastStaticInfo final : OpRewritePattern { LogicalResult matchAndRewrite(BitCastOp bitcastOp, PatternRewriter &rewriter) const final { auto tensorCastOp = bitcastOp.getSource().getDefiningOp(); - if (!tensorCastOp) + if (!tensorCastOp) { return failure(); + } auto tensorCastSrcType = dyn_cast(tensorCastOp.getOperand().getType()); if (!tensorCastSrcType) { @@ -66,8 +68,9 @@ struct BitCastOfTensorCastStaticInfo final : OpRewritePattern { // Drop the dynamic dims that become static after incorporating the cast. 
for (auto [castSize, sourceSize] : llvm::zip_equal( tensorCastSrcType.getShape(), intermediateTensorType.getShape())) { - if (!ShapedType::isDynamic(sourceSize)) + if (!ShapedType::isDynamic(sourceSize)) { continue; + } while (!ShapedType::isDynamic(resShape[resDynamicDim])) { ++resDynamicDim; @@ -135,8 +138,9 @@ static bool updateTensorOpDims(RewriterBase &rewriter, Operation *op, MutableOperandRange mutableDimValues) { auto dynamicDimsOr = IREE::Util::findDynamicDims(tensorValue, op->getBlock(), Block::iterator(op)); - if (!dynamicDimsOr.has_value()) + if (!dynamicDimsOr.has_value()) { return false; + } auto dynamicDims = dynamicDimsOr.value(); bool anyChanged = false; OperandRange oldValueRange = mutableDimValues; @@ -235,8 +239,9 @@ canonicalizeSubViewParts(OpTy op, RankedTensorType sliceType, llvm::SmallVector newShape; llvm::SmallBitVector droppedDims = op.getDroppedDims(); for (auto size : llvm::enumerate(mixedSizes)) { - if (droppedDims.test(size.index())) + if (droppedDims.test(size.index())) { continue; + } std::optional staticSize = getConstantIntValue(size.value()); newShape.push_back(staticSize ? staticSize.value() : ShapedType::kDynamic); } @@ -256,8 +261,9 @@ struct DispatchTensorLoadOpWithOffsetSizesAndStridesConstantArgumentFolder final RankedTensorType resultType = loadOp.getType(); auto newResultType = canonicalizeSubViewParts( loadOp, resultType, mixedOffsets, mixedSizes, mixedStrides); - if (failed(newResultType)) + if (failed(newResultType)) { return failure(); + } // We need to resolve the new inferred type with the specified type. 
Location loc = loadOp.getLoc(); @@ -355,8 +361,9 @@ struct DispatchTensorStoreOpWithOffsetSizesAndStridesConstantArgumentFolder RankedTensorType valueType = storeOp.getValueType(); auto newValueType = canonicalizeSubViewParts( storeOp, valueType, mixedOffsets, mixedSizes, mixedStrides); - if (failed(newValueType)) + if (failed(newValueType)) { return failure(); + } Value value = storeOp.getValue(); Location loc = storeOp.getLoc(); diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp index d6f62207d57c..b70ff74860c3 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/IR/TensorExtTypes.cpp @@ -58,8 +58,9 @@ int64_t DispatchTensorType::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); auto shape = getShape(); int64_t num = 1; - for (auto dim : shape) + for (auto dim : shape) { num *= dim; + } return num; } @@ -197,10 +198,12 @@ void printType(DispatchTensorType &type, DialectAsmPrinter &p) { Type IREETensorExtDialect::parseType(DialectAsmParser &parser) const { StringRef mnemonic; - if (parser.parseKeyword(&mnemonic)) + if (parser.parseKeyword(&mnemonic)) { return {}; - if (mnemonic == "dispatch.tensor") + } + if (mnemonic == "dispatch.tensor") { return DispatchTensorType::parse(parser); + } parser.emitError(parser.getCurrentLocation()) << "unknown TensorExt type: " << mnemonic; return {}; diff --git a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp index 72dd9f38d568..591fc681a6a3 100644 --- a/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp +++ b/compiler/src/iree/compiler/Dialect/TensorExt/Transforms/Folders.cpp @@ -23,8 +23,9 @@ struct FoldTensorLoadWithExtractSlice auto dispatchTensorLoadOp = extractSliceOp.getSource() 
.getDefiningOp(); - if (!dispatchTensorLoadOp) + if (!dispatchTensorLoadOp) { return failure(); + } SmallVector offsets, sizes, strides; // `tensor.extract_slice` (i.e. the producer) folds **into** @@ -56,8 +57,9 @@ struct FoldInsertSliceWithTensorStoreOp PatternRewriter &rewriter) const override { auto insertSliceOp = dispatchTensorStoreOp.getValue().getDefiningOp(); - if (!insertSliceOp) + if (!insertSliceOp) { return failure(); + } SmallVector offsets, sizes, strides; // `tensor.insert_slice` (i.e. the producer) folds **into** diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp index 62c5f9ac8b23..074e028fbf3b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Attributes/Range.cpp @@ -43,8 +43,9 @@ void FloatRangeStats::addDomainValue(double value) { } std::string FloatRangeStats::getAsStr(AsmState &asmState) const { - if (!valid) + if (!valid) { return std::string("<>"); + } std::string s("["); s += std::to_string(minValue); s += ", "; @@ -192,8 +193,9 @@ ChangeStatus FloatRangeValueElement::updateValue(Value value, newState ^= inner; // Stop traversal if tied OpOperand is not used in the op body. 
if (!linalgOp.payloadUsesValueFromOperand( - linalgOp.getDpsInitOperand(result.getResultNumber()))) + linalgOp.getDpsInitOperand(result.getResultNumber()))) { return WalkResult::skip(); + } return WalkResult::advance(); } else if (auto minfOp = dyn_cast(definingOp)) { auto lhs = solver.getElementFor( diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp index 586809176887..b03e8116379d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp @@ -24,8 +24,9 @@ namespace mlir::iree_compiler::IREE::Util { static OpOperand *findOperandFor(Operation *op, Value input) { for (OpOperand &operand : op->getOpOperands()) { - if (operand.get() == input) + if (operand.get() == input) { return &operand; + } } return nullptr; } @@ -33,10 +34,12 @@ static OpOperand *findOperandFor(Operation *op, Value input) { bool ConstExprAnalysis::isConstExprOperation(Operation *queryOp) const { if (queryOp->getNumResults() == 0) { bool hasNoMemoryEffects = false; - if (auto effectOp = dyn_cast(queryOp)) + if (auto effectOp = dyn_cast(queryOp)) { hasNoMemoryEffects = effectOp.hasNoEffect(); - if (hasNoMemoryEffects && queryOp->hasTrait()) + } + if (hasNoMemoryEffects && queryOp->hasTrait()) { return true; + } return false; } // NOTE: this only checks the first result as all results are added to the map @@ -79,25 +82,31 @@ ConstExprAnalysis::ConstExprAnalysis(Operation *rootOp) // such as if they are initialized based on values only available at runtime. explorer.forEachGlobal([&](const Explorer::GlobalInfo *info) { // Rely on globals having been canonicalized to immutable correctly. 
- if (info->isIndirect || info->op.isGlobalMutable()) + if (info->isIndirect || info->op.isGlobalMutable()) { return; - if (!isLegalConstExprRootType(info->op.getGlobalType())) + } + if (!isLegalConstExprRootType(info->op.getGlobalType())) { return; - for (auto loadOp : info->getLoads()) + } + for (auto loadOp : info->getLoads()) { constantRoots[loadOp.getLoadedGlobalValue()] = loadOp; + } }); // Populate the constant roots for all inline constants in the program. explorer.forEachFunctionLikeOp([&](FunctionOpInterface funcOp) { funcOp.walk([&](Operation *op) { - if (!op->hasTrait()) + if (!op->hasTrait()) { return; + } for (auto resultType : op->getResultTypes()) { - if (!isLegalConstExprRootType(resultType)) + if (!isLegalConstExprRootType(resultType)) { return; + } } - for (auto result : op->getResults()) + for (auto result : op->getResults()) { constantRoots[result] = op; + } }); }); @@ -135,8 +144,9 @@ ConstExprAnalysis::ConstExprAnalysis(Operation *rootOp) iterWorklist.clear(); iterWorklist.swap(worklist); for (ConstValueInfo *info : iterWorklist) { - if (info->state != ConstValueInfo::UNKNOWN) + if (info->state != ConstValueInfo::UNKNOWN) { continue; + } bool allConstants = true; for (ConstValueInfo *producerInfo : info->producers) { assert(producerInfo->state != ConstValueInfo::UNANALYZED && @@ -220,12 +230,14 @@ void ConstExprAnalysis::expandToOpStep( ConstExprOpInfo opInfo = ConstExprOpInfo::getForOp(op); for (auto result : op->getResults()) { auto *valueInfo = constInfoMap.lookup(result); - if (valueInfo && valueInfo->state != ConstValueInfo::UNANALYZED) + if (valueInfo && valueInfo->state != ConstValueInfo::UNANALYZED) { continue; + } // Generate new info record. - if (!valueInfo) + if (!valueInfo) { valueInfo = addInfo(result); + } // Update the producers first as we might early-return below. 
for (Value producer : opInfo.producers) { @@ -288,8 +300,9 @@ void ConstExprAnalysis::expandToOpStep( void ConstExprAnalysis::print(raw_ostream &os) const { os << "[ConstExprAnalysis] found constants:\n"; for (auto &info : allocedConstInfos) { - if (info->state != ConstValueInfo::CONSTANT || info->isRoot) + if (info->state != ConstValueInfo::CONSTANT || info->isRoot) { continue; + } if (!info->roots.empty()) { os << "\n[ConstExprAnalysis] constexpr "; info->constValue.print(os, asmState); @@ -334,8 +347,9 @@ void ConstExprHoistingPolicy::initialize() { for (auto &it : analysis.allocedConstInfos) { auto *info = it.get(); // Skip unanalyzed values. - if (info->state == ConstExprAnalysis::ConstValueInfo::UNANALYZED) + if (info->state == ConstExprAnalysis::ConstValueInfo::UNANALYZED) { continue; + } worklist.push_back(info); } @@ -366,8 +380,9 @@ void ConstExprHoistingPolicy::initialize() { bool madeChange = false; for (auto *info : worklist) { Decision *decision = getDecision(info); - if (decision->getOutcome() != UNDECIDED) + if (decision->getOutcome() != UNDECIDED) { continue; + } makeDecision(info, decision); if (decision->getOutcome() != UNDECIDED) { @@ -481,8 +496,9 @@ void ConstExprHoistingPolicy::makeDecision( if (!hasLegalEscape) { for (auto *consumerInfo : info->consumers) { Decision *consumerDecision = getDecision(consumerInfo); - if (consumerDecision->getOutcome() != DISABLE_HOIST) + if (consumerDecision->getOutcome() != DISABLE_HOIST) { continue; + } Operation *consumerOp = consumerInfo->getOperation(); OpOperand *consumerOperand = findOperandFor(consumerOp, info->constValue); @@ -544,13 +560,15 @@ struct DOTGraphTraits getNodeAttributes(const ConstExprAnalysis::ConstValueInfo *Node, const ConstExprHoistingPolicy *g) { // Roots are colored red. - if (Node->isRoot) + if (Node->isRoot) { return "fillcolor=red,style=filled"; + } // Hoisted values are colored green. 
ConstExprHoistingPolicy::Outcome outcome = g->getOutcome(Node); - if (outcome == ConstExprHoistingPolicy::Outcome::ENABLE_HOIST) + if (outcome == ConstExprHoistingPolicy::Outcome::ENABLE_HOIST) { return "fillcolor=green,style=filled"; + } return ""; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h index 9c0bb4290767..235743357bf9 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h @@ -44,8 +44,9 @@ class ConstExprAnalysis { // uninitialized. If they are all initialized, then they will either be all // const-expr or all non const-expr, so just return the first result's info. const ConstValueInfo *lookup(Operation *queryOp) const { - if (queryOp->getNumResults() == 0) + if (queryOp->getNumResults() == 0) { return nullptr; + } if (llvm::any_of(queryOp->getResults(), [&](Value v) { return !lookup(v); })) { return nullptr; @@ -56,8 +57,9 @@ class ConstExprAnalysis { // Returns true if the given value is only derived from immutable inputs. 
bool isConstExprValue(Value queryValue) const { ConstValueInfo *found = constInfoMap.lookup(queryValue); - if (!found) + if (!found) { return false; + } return found->state == ConstValueInfo::CONSTANT; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp index da22db8f0bcd..4045bace9b0a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp @@ -97,8 +97,9 @@ static bool isEligibleConstExpr(Operation *op) { Operation *parent = op; while (auto hoistableParent = parent->getParentOfType()) { - if (hoistableParent.isAtomicallyHoistableOp()) + if (hoistableParent.isAtomicallyHoistableOp()) { return false; + } parent = hoistableParent; } @@ -154,8 +155,9 @@ bool isHoistableConstExprLeaf(const ConstExprAnalysis::ConstValueInfo *info) { // If implementing the HoistableOpInterface, check whether the op is legal to // hoist. We still need to check for type legality afterwards though. 
if (auto hoistableOp = dyn_cast(op)) { - if (!hoistableOp.isHoistableLeafOp()) + if (!hoistableOp.isHoistableLeafOp()) { return false; + } } // If implementing the HoistableTypeInterface, at this point we can just diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp index d1a5a116e9cf..5361ad5fa02a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/DepGraph.cpp @@ -49,8 +49,9 @@ void DepGraph::dumpGraph() { std::error_code ec; llvm::raw_fd_ostream file(filename, ec, llvm::sys::fs::OF_TextWithCRLF); - if (!ec) + if (!ec) { llvm::WriteGraph(file, this); + } callTimes++; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp index ca8d7329460b..4724a5c26d82 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.cpp @@ -15,8 +15,9 @@ namespace mlir::iree_compiler::DFX { ChangeStatus AbstractElement::update(Solver &solver) { ChangeStatus changeStatus = ChangeStatus::UNCHANGED; - if (getState().isAtFixpoint()) + if (getState().isAtFixpoint()) { return changeStatus; + } LLVM_DEBUG({ llvm::dbgs() << "[Solver] updating: "; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h index 0a7d95fc28aa..803feb60958f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Element.h @@ -143,8 +143,9 @@ struct TypedOperationElement : public AbstractElement { ChangeStatus updateImpl(Solver &solver) override { if (isOperation()) { auto op = dyn_cast(getOperation()); - if (op) + if (op) { return updateOperation(op, solver); + } } return 
getState().indicatePessimisticFixpoint(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp index adbe01cf13f6..80df8a1736ef 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp @@ -113,8 +113,9 @@ LogicalResult Solver::runTillFixpoint(int maxIterations) { // Use the invalidElements vector to propagate invalid states fast // transitively without requiring updates. - if (!elementState.isValidState()) + if (!elementState.isValidState()) { invalidElements.insert(element); + } } // Add elements to the changed set if they have been created in the last @@ -140,8 +141,9 @@ LogicalResult Solver::runTillFixpoint(int maxIterations) { SmallPtrSet visitedElements; for (size_t i = 0; i < changedElements.size(); i++) { auto *changedElement = changedElements[i]; - if (!visitedElements.insert(changedElement).second) + if (!visitedElements.insert(changedElement).second) { continue; + } auto &elementState = changedElement->getState(); if (!elementState.isAtFixpoint()) { @@ -183,8 +185,9 @@ ChangeStatus Solver::updateElement(AbstractElement &element) { // will not change and we can indicate that right away. elementState.indicateOptimisticFixpoint(); } - if (!elementState.isAtFixpoint()) + if (!elementState.isAtFixpoint()) { rememberDependencies(); + } // Verify the stack is balanced by ensuring we pop the vector we pushed above. 
auto *poppedDependencies = dependencyStack.pop_back_val(); @@ -198,15 +201,18 @@ ChangeStatus Solver::updateElement(AbstractElement &element) { void Solver::recordDependency(const AbstractElement &fromElement, const AbstractElement &toElement, Resolution resolution) { - if (resolution == Resolution::NONE) + if (resolution == Resolution::NONE) { return; + } // If we are outside of an update, thus before the actual fixpoint iteration // started (= when we create elements), we do not track dependencies because // we will put all elements into the initial worklist anyway. - if (dependencyStack.empty()) + if (dependencyStack.empty()) { return; - if (fromElement.getState().isAtFixpoint()) + } + if (fromElement.getState().isAtFixpoint()) { return; + } // NOTE: this may record several of the same dependency as there is no // deduplication. Deduplication is more expensive than the rarer case of // duplication, though, so we deal with it. diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h index d340812ecd83..8bec39cd311c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h @@ -193,8 +193,9 @@ class Solver { // Lookup the abstract element of type ElementT and if found return it after // registering a dependence of queryingElement on the one returned element. auto *elementPtr = elementMap.lookup({&ElementT::ID, pos}); - if (!elementPtr) + if (!elementPtr) { return nullptr; + } auto *element = static_cast(elementPtr); // Do not register a dependence on an element with an invalid state. 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp index d765f4c838c4..3966328941ec 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.cpp @@ -24,10 +24,12 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, if (!S.isValidState()) { os << "full-set"; } else { - for (auto &it : S.getAssumedSet()) + for (auto &it : S.getAssumedSet()) { os << it << ", "; - if (S.isUndefContained()) + } + if (S.isUndefContained()) { os << "undef "; + } } os << "} >)"; return os; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h index 75ac3c11e831..5782b9a3d272 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/DFX/State.h @@ -195,12 +195,14 @@ struct BooleanState : public IntegerStateBase { private: void handleNewKnownValue(base_t value) override { - if (value) + if (value) { known = (assumed = value); + } } void handleNewAssumedValue(base_t value) override { - if (!value) + if (!value) { assumed = known; + } } void joinOR(base_t assumedValue, base_t knownValue) override { @@ -423,12 +425,15 @@ struct PotentialValuesState : AbstractState { } bool operator==(const PotentialValuesState &rhs) const { - if (isValidState() != rhs.isValidState()) + if (isValidState() != rhs.isValidState()) { return false; - if (!isValidState() && !rhs.isValidState()) + } + if (!isValidState() && !rhs.isValidState()) { return true; - if (isUndefContained() != rhs.isUndefContained()) + } + if (isUndefContained() != rhs.isUndefContained()) { return false; + } return set == rhs.getAssumedSet(); } @@ -487,8 +492,9 @@ struct PotentialValuesState : AbstractState { // Inserts an element into this set. 
void insert(const MemberTy &c) { - if (!isValidState()) + if (!isValidState()) { return; + } set.insert(c); checkAndInvalidate(); } @@ -496,15 +502,17 @@ struct PotentialValuesState : AbstractState { // Takes union with |rhs|. void unionWith(const PotentialValuesState &rhs) { // If this is a full set, do nothing. - if (!isValidState()) + if (!isValidState()) { return; + } // If rhs is full set, change L to a full set. if (!rhs.isValidState()) { indicatePessimisticFixpoint(); return; } - for (const MemberTy &c : rhs.set) + for (const MemberTy &c : rhs.set) { set.insert(c); + } undefIsContained |= rhs.isUndefContained(); checkAndInvalidate(); } @@ -518,8 +526,9 @@ struct PotentialValuesState : AbstractState { // Takes intersection with |rhs|. void intersectWith(const PotentialValuesState &rhs) { // If rhs is a full set, do nothing. - if (!rhs.isValidState()) + if (!rhs.isValidState()) { return; + } // If this is a full set, change this to rhs. if (!isValidState()) { *this = rhs; @@ -527,8 +536,9 @@ struct PotentialValuesState : AbstractState { } SetTy intersectSet; for (const MemberTy &c : set) { - if (rhs.set.count(c)) + if (rhs.set.count(c)) { intersectSet.insert(c); + } } set = intersectSet; undefIsContained &= rhs.isUndefContained(); diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index bcbd72cb1918..eb7621c10bf1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -33,8 +33,9 @@ static std::optional mapSuccessorOperand(BranchOpInterface branchOp, // I don't know if there's a better way to do this - the interface doesn't // help. 
auto operandRange = branchOp.getSuccessorOperands(successorIdx); - if (operandRange.empty()) + if (operandRange.empty()) { return std::nullopt; + } unsigned beginIdx = operandRange.getForwardedOperands().getBeginOperandIndex(); if (operandIdx >= beginIdx && operandIdx < beginIdx + operandRange.size()) { @@ -187,8 +188,9 @@ void Explorer::initializeInverseCallGraph() { const Explorer::GlobalInfo * Explorer::getGlobalInfo(IREE::Util::GlobalOpInterface globalOp) { auto it = globalInfos.find(globalOp); - if (it == globalInfos.end()) + if (it == globalInfos.end()) { return nullptr; + } return it->second.get(); } @@ -198,11 +200,13 @@ const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, auto &symbolTable = symbolTables.getSymbolTable(symbolTableOp); auto op = symbolTable.lookupNearestSymbolFrom( from, StringAttr::get(from->getContext(), globalName)); - if (!op) + if (!op) { return nullptr; + } auto it = globalInfos.find(op); - if (it == globalInfos.end()) + if (it == globalInfos.end()) { return nullptr; + } return it->second.get(); } @@ -259,8 +263,9 @@ void Explorer::forEachFunctionLikeOp( } bool Explorer::mayValuesAlias(Value a, Value b) { - if (a == b) + if (a == b) { return true; + } bool mayAlias = false; auto traversalResult = walkTransitiveUses(a, [&](OpOperand &value) { mayAlias = value.get() == b; @@ -287,8 +292,9 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << "? entering scc slice with " << scc.size() << " callables\n"); for (auto *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } // Ensure we want to step into this region. 
// Note that SCC returns every function like in the whole program, @@ -296,8 +302,9 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { auto &callableRegion = *node->getCallableRegion(); auto *callableOp = callableRegion.getParentOp(); auto action = getTraversalAction(callableOp); - if (action == TraversalAction::IGNORE) + if (action == TraversalAction::IGNORE) { continue; + } bool validInPlace = true; for (auto *parentOp = callableOp->getParentOp(); parentOp != rootOp; parentOp = parentOp->getParentOp()) { @@ -315,10 +322,12 @@ TraversalResult Explorer::walk(OperationWalkFn fn) { LLVM_DEBUG(llvm::dbgs() << " + entering callable region @" << getRegionName(callableRegion) << "\n"); auto emitResult = recursiveWalk(callableOp, fn); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { break; - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { continue; + } } } @@ -338,10 +347,12 @@ WalkResult Explorer::recursiveWalk(Operation *parentOp, LLVM_DEBUG(llvm::dbgs() << " == emitting op " << getOpName(parentOp) << "\n"); auto emitResult = fn(parentOp); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { return WalkResult::interrupt(); - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { return WalkResult::advance(); + } if (parentOp->getNumRegions() == 0 || parentAction != TraversalAction::RECURSE) { @@ -355,8 +366,9 @@ WalkResult Explorer::recursiveWalk(Operation *parentOp, for (auto &block : region.getBlocks()) { for (auto &op : block) { auto opResult = recursiveWalk(&op, fn); - if (opResult.wasInterrupted()) + if (opResult.wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -374,8 +386,9 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, LLVM_DEBUG(llvm::dbgs() << "? entering scc slice with " << scc.size() << " callables\n"); for (auto *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } // Ensure we want to step into this region. 
// Note that SCC returns every function like in the whole program, @@ -383,8 +396,9 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, auto &callableRegion = *node->getCallableRegion(); auto *callableOp = callableRegion.getParentOp(); auto action = getTraversalAction(callableOp); - if (action == TraversalAction::IGNORE) + if (action == TraversalAction::IGNORE) { continue; + } bool validInPlace = true; for (auto *parentOp = callableOp->getParentOp(); parentOp != rootOp; parentOp = parentOp->getParentOp()) { @@ -403,10 +417,12 @@ TraversalResult Explorer::walkAllValues(ValueWalkFn fn, << getRegionName(callableRegion) << "\n"); auto emitResult = recursiveWalkValues(callableOp, visitedValues, fn, typeID); - if (emitResult.wasInterrupted()) + if (emitResult.wasInterrupted()) { break; - if (emitResult.wasSkipped()) + } + if (emitResult.wasSkipped()) { continue; + } } } @@ -442,16 +458,18 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, LLVM_DEBUG(llvm::dbgs() << " + processing op results " << getOpName(parentOp) << "\n"); for (auto result : parentOp->getResults()) { - if (typeID.has_value() && result.getType().getTypeID() != *typeID) + if (typeID.has_value() && result.getType().getTypeID() != *typeID) { continue; + } if (visitedValues.insert(result).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting value "; result.printAsOperand(llvm::dbgs(), asmState); llvm::dbgs() << "\n"; }); - if (fn(result).wasInterrupted()) + if (fn(result).wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -473,23 +491,26 @@ WalkResult Explorer::recursiveWalkValues(Operation *parentOp, llvm::dbgs() << " arguments\n"; }); for (auto arg : block.getArguments()) { - if (typeID.has_value() && arg.getType().getTypeID() != *typeID) + if (typeID.has_value() && arg.getType().getTypeID() != *typeID) { continue; + } if (visitedValues.insert(arg).second) { LLVM_DEBUG({ llvm::dbgs() << " == emitting block arg "; arg.printAsOperand(llvm::dbgs(), asmState); llvm::dbgs() 
<< "\n"; }); - if (fn(arg).wasInterrupted()) + if (fn(arg).wasInterrupted()) { return WalkResult::interrupt(); + } } } } for (auto &op : block) { auto opResult = recursiveWalkValues(&op, visitedValues, fn, typeID); - if (opResult.wasInterrupted()) + if (opResult.wasInterrupted()) { return WalkResult::interrupt(); + } } } } @@ -502,8 +523,9 @@ Explorer::walkIncomingCalls(CallableOpInterface callableOp, auto it = callGraphInv.find(callableOp.getCallableRegion()); if (it != callGraphInv.end()) { for (auto &callOp : it->second) { - if (fn(callOp).wasInterrupted()) + if (fn(callOp).wasInterrupted()) { break; + } } } bool isPublic = false; @@ -560,8 +582,9 @@ TraversalResult Explorer::walkReturnOps(Operation *parentOp, return WalkResult::advance(); }; for (auto ®ion : regionOp->getRegions()) { - if (enumerateTerminatorOps(region).wasInterrupted()) + if (enumerateTerminatorOps(region).wasInterrupted()) { break; + } } } else if (auto parentFuncOp = dyn_cast(parentOp)) { @@ -582,8 +605,9 @@ TraversalResult Explorer::walkReturnOps(Operation *parentOp, terminatorOp->print(llvm::dbgs(), asmState); llvm::dbgs() << "\n"; }); - if (fn(terminatorOp).wasInterrupted()) + if (fn(terminatorOp).wasInterrupted()) { break; + } } } } @@ -711,8 +735,9 @@ TraversalResult Explorer::walkOutgoingBranchOperandArguments( ++successorIdx) { auto successorOperandIdx = mapSuccessorOperand(branchOp, successorIdx, operandIdx); - if (!successorOperandIdx.has_value()) + if (!successorOperandIdx.has_value()) { continue; + } auto *targetBlock = branchOp->getSuccessor(successorIdx); auto blockArg = targetBlock->getArgument(*successorOperandIdx); if (fn(targetBlock, blockArg).wasInterrupted()) { @@ -833,8 +858,9 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn, << loadOp.getGlobalName() << ":\n"); for (auto *user : globalInfo->uses) { auto storeOp = dyn_cast(user); - if (!storeOp) + if (!storeOp) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " + queuing stored value from "; 
storeOp.print(llvm::dbgs(), asmState); @@ -886,8 +912,9 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn, do { // Pop the next work item; avoiding processing values more than once. auto work = worklist.pop_back_val(); - if (!processedValues.insert(work.getAsOpaquePointer()).second) + if (!processedValues.insert(work.getAsOpaquePointer()).second) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " ? working on "; @@ -1115,8 +1142,9 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn, << storeOp.getGlobalName() << ":\n"); for (auto *user : globalInfo->uses) { auto loadOp = dyn_cast(user); - if (!loadOp) + if (!loadOp) { continue; + } LLVM_DEBUG({ llvm::dbgs() << " + queuing loaded value from "; loadOp.print(llvm::dbgs(), asmState); @@ -1143,8 +1171,9 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn, // times!). for (auto &use : work.getUses()) { auto *ownerOp = use.getOwner(); - if (!processedValues.insert(&use).second) + if (!processedValues.insert(&use).second) { continue; + } auto action = getTraversalAction(ownerOp); if (action == TraversalAction::IGNORE) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp index f87864ebb0ad..2b17acd1bcb7 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp @@ -37,8 +37,9 @@ LogicalResult IntegerDivisibilityAnalysis::visitOperation( }); auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) { auto result = dyn_cast(v); - if (!result) + if (!result) { return; + } assert(llvm::is_contained(op->getResults(), result)); LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n"); diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp 
b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp index a3e19b82aa48..e66e96252bb1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/FuncToUtil/Patterns.cpp @@ -98,8 +98,9 @@ struct FuncFuncOpPattern : public OpConversionPattern { for (auto retainAttrName : retainedAttributes) { StringRef attrName(retainAttrName); Attribute attr = srcOp->getAttr(attrName); - if (attr) + if (attr) { newFuncOp->setAttr(attrName, attr); + } } // Copy all arg/result attrs. We could filter these. diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp index d58a5c1cb9aa..de119b161b51 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/ClosureOpUtils.cpp @@ -155,8 +155,9 @@ static SmallVector findDuplicateRegionResults(Region ®ion) { auto uniformDupeIndexMap = llvm::to_vector(llvm::seq(0u, resultCount)); // old -> new for (unsigned idx = 0; idx < resultCount; ++idx) { - if (deadResultsMap.test(idx)) + if (deadResultsMap.test(idx)) { continue; + } // Each bit represents a result that duplicates the result at idx. // We walk all the sites and AND their masks together to get the safe // set of duplicate results. @@ -257,15 +258,17 @@ static void inlineClosureOperands(const ClosureOptimizationOptions &options, for (auto opArg : llvm::enumerate(closureOp.getClosureOperands())) { auto outerValue = opArg.value(); auto *sourceOp = outerValue.getDefiningOp(); - if (!sourceOp) + if (!sourceOp) { continue; // can't clone block arguments into closures + } // We cannot just simply inline and replace all users if this is an // argument that can be written; for example, the region might perform // work after loading a initial constant from the argument and then // write back. 
- if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) + if (!closureOp.getOperandAccess(opArg.index()).isReadOnly()) { continue; + } if (closureOp.canClosureContainOp(sourceOp) && shouldInlineIntoClosure(options, outerValue)) { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp index 12b2127db67a..097fc4ad3395 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp @@ -258,15 +258,17 @@ class PackedWriter { : logicalBitWidth(logicalBitWidth), endian(endian), os(os) {} void write(const uint64_t value) { - if (bitOffset + logicalBitWidth > physicalBitWidth) + if (bitOffset + logicalBitWidth > physicalBitWidth) { flush(); + } physicalBuffer |= value << bitOffset; bitOffset += logicalBitWidth; } void flush() { - if (bitOffset == 0) + if (bitOffset == 0) { return; + } physicalType physicalValue = llvm::support::endian::byte_swap(physicalBuffer, endian); os.write((const char *)&physicalValue, sizeof(physicalValue)); @@ -533,8 +535,9 @@ LogicalResult BytePatternAttr::serializeToStream(Location loc, //===----------------------------------------------------------------------===// Attribute ByteRangeAttr::parse(AsmParser &p, Type type) { - if (failed(p.parseLess())) + if (failed(p.parseLess())) { return {}; + } // TODO(benvanik): support the range syntax; the dialect asm parser fights // with it though by checking for proper []/() nesting. @@ -573,8 +576,9 @@ Attribute ByteRangeAttr::parse(AsmParser &p, Type type) { return {}; } - if (failed(p.parseGreater())) + if (failed(p.parseGreater())) { return {}; + } start = startInclusive ? start : start + 1; end = endInclusive ? 
end : end - 1; @@ -912,8 +916,9 @@ void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, } } } - if (auto *parentOp = fromOp->getParentOp()) + if (auto *parentOp = fromOp->getParentOp()) { gatherHoistableAttrs(parentOp, dialectAttrs); + } } // static @@ -923,8 +928,9 @@ void HoistableAttrInterface::gatherHoistableAttrs(Operation *fromOp, // precedence over any from ancestors. We also want to preserve any // non-hoistable attrs when we reassign the dialect attrs. NamedAttrList dialectAttrs; - for (auto attr : toOp->getDialectAttrs()) + for (auto attr : toOp->getDialectAttrs()) { dialectAttrs.push_back(attr); + } // Gather attributes from the op and its parents, only adding ones not already // set on the op. diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp index 86ed3a8281a5..d116e5905e12 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp @@ -54,8 +54,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = callable->getAttrOfType( "inlining_policy")) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } // Check any extended inlining policies that may come from dialect @@ -64,8 +65,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = dyn_cast( attr.getValue())) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } } @@ -86,8 +88,9 @@ struct UtilInlinerInterface : public DialectInlinerInterface { } void handleTerminator(Operation *op, Block *newDest) const final { - if (!op->hasTrait()) + if (!op->hasTrait()) { return; + } OpBuilder builder(op); if (auto returnOp = dyn_cast(op)) { @@ -159,8 +162,9 @@ struct FoldDimOp : public OpRewritePattern { } auto 
shapeAwareOp = dyn_cast_if_present(source.getDefiningOp()); - if (!shapeAwareOp) + if (!shapeAwareOp) { return failure(); + } // We only support static dimension indices today (as in general we only // support ranked shapes). If we find dynamic indices sneaking in we will diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp index 59f09570eb48..44cd74d55d91 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp @@ -81,14 +81,16 @@ static LogicalResult canonicalizeAssumeIntOp(AssumeIntOp op, needsRewrite = true; } } - if (!needsRewrite) + if (!needsRewrite) { return failure(); + } // Need to rewrite the assumption. auto normalizeAssumptions = [](Attribute row, bool &madeChange) { auto rowArray = cast(row); - if (rowArray.size() <= 1) + if (rowArray.size() <= 1) { return rowArray; + } bool allSame = true; for (unsigned i = 1; i < rowArray.size(); ++i) { @@ -98,8 +100,9 @@ static LogicalResult canonicalizeAssumeIntOp(AssumeIntOp op, } } - if (!allSame) + if (!allSame) { return rowArray; + } // All entries are the same: compress down to a single column. 
madeChange = true; @@ -350,8 +353,9 @@ struct FoldCastIntoNullOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto nullOp = dyn_cast_if_present(castOp.getOperand().getDefiningOp()); - if (!nullOp) + if (!nullOp) { return failure(); + } rewriter.replaceOpWithNewOp(castOp, castOp.getResult().getType()); return success(); } @@ -425,8 +429,9 @@ static OpFoldResult foldRangeOp(Type type, ValueRange operands, int64_t value = initialValue; for (auto operand : attrOperands) { auto intValue = dyn_cast_if_present(operand); - if (!intValue) + if (!intValue) { return {}; + } value = expr(value, intValue.getValue().getSExtValue()); } return IntegerAttr::get(type, value); @@ -566,8 +571,9 @@ struct FoldConstantRanges : public OpRewritePattern { lengths.push_back(length); } } - if (offsets.size() == op.getOffsets().size()) + if (offsets.size() == op.getOffsets().size()) { return failure(); + } // Preserve dynamic ranges. Value min; @@ -627,8 +633,9 @@ struct ExpandSimpleRangeExtentsOp : public OpRewritePattern { op.getLengths().back(), one, rewriter); maxValue = arith::MaxUIOp::create(rewriter, loc, endLhs, endRhs); } - if (!minValue || !maxValue) + if (!minValue || !maxValue) { return failure(); + } rewriter.replaceOp(op, {minValue, maxValue}); return success(); } @@ -645,8 +652,9 @@ struct DeduplicateRangeExtentsOp : public OpRewritePattern { for (auto range : llvm::zip_equal(op.getOffsets(), op.getLengths())) { ranges.insert(range); } - if (ranges.size() == op.getOffsets().size()) + if (ranges.size() == op.getOffsets().size()) { return failure(); + } // Recreate with the deduplicated ranges. SmallVector offsets; @@ -702,8 +710,9 @@ static bool isAlignedTo(Value value, Value alignment) { // If the value is produced by an align op we can check that. if (auto sourceAlignOp = value.getDefiningOp()) { // Check for same exact alignment - even if dynamic. 
- if (sourceAlignOp.getAlignment() == alignment) + if (sourceAlignOp.getAlignment() == alignment) { return true; + } // If the alignments are constant we can compare them inline. APInt sourceAlignment; @@ -762,8 +771,9 @@ static bool isAlignedTo(Value value, Value alignment) { OpFoldResult AlignOp::fold(FoldAdaptor operands) { // If aligning an already-aligned value then fold if this is provably a // no-op. We can check this for equality even with dynamic alignments. - if (isAlignedTo(getValue(), getAlignment())) + if (isAlignedTo(getValue(), getAlignment())) { return getValue(); + } // If values are static we can perform the alignment here. APInt staticValue; @@ -992,8 +1002,9 @@ struct DropEmptyInitializerOp : public OpRewritePattern { LogicalResult matchAndRewrite(InitializerOp op, PatternRewriter &rewriter) const override { - if (op.getBody().getBlocks().size() != 1) + if (op.getBody().getBlocks().size() != 1) { return failure(); + } auto &block = op.getBody().front(); // Empty block or block with only a ReturnLike terminator. 
if (block.empty() || (block.getOperations().size() == 1 && @@ -1128,8 +1139,9 @@ struct FoldBufferSubspanOps : public OpRewritePattern { LogicalResult matchAndRewrite(BufferSubspanOp op, PatternRewriter &rewriter) const override { auto parentOp = BufferSubspanOp::findSubspanOp(op.getSource()); - if (!parentOp) + if (!parentOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({parentOp.getLoc(), op.getLoc()}); auto newOffset = rewriter.createOrFold( fusedLoc, parentOp.getSourceOffset(), op.getSourceOffset()); @@ -1159,8 +1171,9 @@ struct FoldBufferSubspanOpsIntoConsumers for (auto &use : llvm::make_early_inc_range(op.getResult().getUses())) { auto subrangeOp = dyn_cast(use.getOwner()); - if (!subrangeOp) + if (!subrangeOp) { continue; + } didUpdateAny = true; rewriter.setInsertionPoint(subrangeOp); auto oldRange = subrangeOp.getSubrangeOperand(use.getOperandNumber()); @@ -1193,14 +1206,16 @@ struct SinkSubspanAcrossSelectOps using Base::Base; LogicalResult matchAndRewrite(mlir::arith::SelectOp op, PatternRewriter &rewriter) const override { - if (!isa(op.getType())) + if (!isa(op.getType())) { return failure(); + } auto trueSubspan = dyn_cast_if_present( op.getTrueValue().getDefiningOp()); auto falseSubspan = dyn_cast_if_present( op.getFalseValue().getDefiningOp()); - if (!trueSubspan || !falseSubspan) + if (!trueSubspan || !falseSubspan) { return failure(); + } if (trueSubspan.getSource() != falseSubspan.getSource() || trueSubspan.getResultSize() != falseSubspan.getResultSize()) { return failure(); @@ -1275,8 +1290,9 @@ struct SelectBufferSizeOp : public OpRewritePattern { LogicalResult matchAndRewrite(BufferSizeOp op, PatternRewriter &rewriter) const override { auto selectOp = op.getOperand().getDefiningOp(); - if (!selectOp) + if (!selectOp) { return failure(); + } auto trueSize = rewriter.createOrFold( op.getLoc(), selectOp.getTrueValue()); auto falseSize = rewriter.createOrFold( @@ -1313,8 +1329,9 @@ struct FoldSubspansIntoStorageOp : public 
OpRewritePattern { LogicalResult matchAndRewrite(BufferStorageOp op, PatternRewriter &rewriter) const override { auto subspanOp = BufferSubspanOp::findSubspanOp(op.getOperand()); - if (!subspanOp) + if (!subspanOp) { return failure(); + } auto fusedLoc = rewriter.getFusedLoc({subspanOp.getLoc(), op.getLoc()}); rewriter.setInsertionPointAfter(op); auto newOffset = rewriter.createOrFold( diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 9f6b6f4adb2e..1baaaf55b6c1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -66,8 +66,9 @@ Value buildIfElseTree( ArrayAttr deduplicateArrayElements(ArrayAttr arrayAttr) { SetVector attrsSet(arrayAttr.begin(), arrayAttr.end()); - if (attrsSet.size() == arrayAttr.size()) + if (attrsSet.size() == arrayAttr.size()) { return arrayAttr; + } return ArrayAttr::get(arrayAttr.getContext(), attrsSet.takeVector()); } @@ -202,8 +203,9 @@ void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, needsSpace = true; // subsequent attr value needs a space separator } if (attr) { - if (needsSpace) + if (needsSpace) { p << ' '; + } p << "= "; p.printAttribute(attr); } @@ -249,12 +251,14 @@ void printSymbolAlias(OpAsmPrinter &p, Operation *op, StringAttr sym_name, ParseResult parseTypeAlias(OpAsmParser &parser, TypeAttr &encodingTypeAttr, Type &storageType) { Type encodingType; - if (failed(parser.parseType(encodingType))) + if (failed(parser.parseType(encodingType))) { return failure(); + } storageType = encodingType; if (succeeded(parser.parseOptionalKeyword("as"))) { - if (failed(parser.parseType(storageType))) + if (failed(parser.parseType(storageType))) { return failure(); + } } encodingTypeAttr = TypeAttr::get(encodingType); return success(); @@ -356,18 +360,22 @@ void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, Value size) { ParseResult 
parseOperandTypeList(OpAsmParser &parser, SmallVectorImpl &operandTypes) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); - if (succeeded(parser.parseOptionalRParen())) + } + if (succeeded(parser.parseOptionalRParen())) { return success(); // empty + } do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } operandTypes.push_back(type); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } return success(); } @@ -403,8 +411,9 @@ parseTiedResultList(OpAsmParser &parser, } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. type = operandTypes[tiedOperandIndex]; @@ -443,8 +452,9 @@ void printTiedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, if (printType) { p.printType(resultType); } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -476,8 +486,9 @@ parseTiedFunctionResultListImpl(OpAsmParser &parser, } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = arguments[tiedOperandIndex].type; @@ -566,11 +577,13 @@ void printTiedFunctionResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, TypeRange operandTypes, TypeRange resultTypes, ArrayAttr tiedOperands) { - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printTiedResultList(p, op, operands, operandTypes, resultTypes, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } //===----------------------------------------------------------------------===// @@ -583,8 +596,9 @@ parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl &types, SmallVectorImpl &dims) { do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } if (auto shapedType = dyn_cast(type)) { if (!shapedType.hasStaticShape()) { SmallVector dynamicDims; @@ -639,8 +653,9 @@ ParseResult parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl &types0, SmallVectorImpl &types1, SmallVectorImpl &dims) { - if (failed(parseShapedTypeList(parser, types0, dims))) + if (failed(parseShapedTypeList(parser, types0, dims))) { return failure(); + } types1 = types0; return success(); } @@ -672,11 +687,13 @@ ParseResult parseShapedTiedResult( int64_t tiedOperandIndex = IREE::Util::TiedOpInterface::kUntiedIndex; if (res.has_value() && succeeded(res.value())) { tiedOperandIndex = 0; - if (failed(parser.parseKeyword("as"))) + if (failed(parser.parseKeyword("as"))) { return failure(); + } } - if (failed(parser.parseType(resultType))) + if (failed(parser.parseType(resultType))) { return failure(); + } if (auto shapedType = dyn_cast(resultType)) { if (!shapedType.hasStaticShape()) { SmallVector dynamicDims; @@ -766,8 +783,9 @@ ParseResult parseShapedResultList( } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = operandTypes[tiedOperandIndex]; @@ -848,8 +866,9 @@ void printShapedResultList(OpAsmPrinter &p, Operation *op, ValueRange operands, p << "}"; resultDims = resultDims.drop_front(1); } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -865,16 +884,18 @@ ParseResult parseShapedFunctionType( SmallVectorImpl &resultTypes, SmallVectorImpl &resultDims, ArrayAttr &tiedOperands) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseShapedTypeList(parser, operandTypes, operandDims)) || failed(parser.parseRParen())) { return failure(); } } - if (failed(parser.parseArrow())) + if (failed(parser.parseArrow())) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (succeeded(parser.parseOptionalRParen())) { // Empty list/no results `()`. @@ -905,12 +926,14 @@ void printShapedFunctionType(OpAsmPrinter &p, Operation *op, p << "("; printShapedTypeList(p, op, operandTypes, operandDims); p << ") -> "; - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << "("; + } printShapedResultList(p, op, operands, operandTypes, operandDims, resultTypes, resultDims, tiedOperands); - if (resultTypes.size() != 1) + if (resultTypes.size() != 1) { p << ")"; + } } //===----------------------------------------------------------------------===// @@ -961,8 +984,9 @@ static ParseResult parseShapedFunctionResultList( } if (succeeded(parser.parseOptionalKeyword("as"))) { // Type _may_ differ from the operand. - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } } else { // Use the operands type. 
type = argTypes[tiedOperandIndex]; @@ -1016,8 +1040,9 @@ static void printShapedFunctionResultList(OpAsmPrinter &p, Operation *op, p.printOptionalAttrDict(attrs.getValue()); } } - if (i < resultTypes.size() - 1) + if (i < resultTypes.size() - 1) { p << ", "; + } } } @@ -1029,8 +1054,9 @@ ParseResult parseShapedFunctionSignature(OpAsmParser &parser, SmallVector args; SmallVector argTypes; SmallVector resultTypes; - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } if (failed(parser.parseOptionalRParen())) { if (failed(parseShapedFunctionArgumentList(parser, args, argTypes, argAttrs)) || @@ -1074,8 +1100,9 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, if (argAttrs) { auto attrs = dyn_cast_if_present(argAttrs.getValue()[argIndex]); - if (attrs && !attrs.empty()) + if (attrs && !attrs.empty()) { p.printOptionalAttrDict(attrs.getValue()); + } } ++argIndex; }); @@ -1087,12 +1114,14 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, resultAttrs && !resultAttrs.empty() && llvm::any_of(resultAttrs.getAsValueRange(), [](auto attr) { return !attr.empty(); }); - if (resultTypes.size() != 1 || anyResultAttrs) + if (resultTypes.size() != 1 || anyResultAttrs) { p << "("; + } printShapedFunctionResultList(p, op, functionType.getInputs(), resultTypes, resultAttrs, tiedOperands); - if (resultTypes.size() != 1 || anyResultAttrs) + if (resultTypes.size() != 1 || anyResultAttrs) { p << ")"; + } } } @@ -1121,24 +1150,27 @@ void AlignOp::inferResultRanges(ArrayRef argRanges, auto align = [&](APInt value, bool &invalid) -> APInt { APInt aligned = (value + alignmentM1) & alignmentM1Inv; // Detect overflow, which commonly happens at max range. 
- if (aligned.ult(value)) + if (aligned.ult(value)) { invalid = true; + } return aligned; }; bool invalid = false; auto alignedUmin = align(umin, invalid); auto alignedUmax = align(umax, invalid); - if (!invalid) + if (!invalid) { setResultRange(getResult(), ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax)); + } } } void AlignOp::inferResultDivisibility(ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { auto alignmentDiv = argDivs[1]; - if (alignmentDiv.isUninitialized()) + if (alignmentDiv.isUninitialized()) { return; + } setResultDivs(getResult(), alignmentDiv.getValue()); } @@ -1186,8 +1218,9 @@ AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { static bool isConstantZero(IntAssumptionAttr assumption) { std::optional umin = assumption.getUmin(); std::optional umax = assumption.getUmax(); - if (!umin || !umax) + if (!umin || !umax) { return false; + } return *umin == 0 && *umax == 0; } @@ -1199,14 +1232,16 @@ AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) { auto divisor = assumption.getUdiv(); if (!divisor) { // Constant zero is divisible by anything - if (isConstantZero(assumption)) + if (isConstantZero(assumption)) { continue; + } return std::nullopt; } - if (divisorUnion) + if (divisorUnion) { divisorUnion = std::gcd(*divisor, *divisorUnion); - else + } else { divisorUnion = *divisor; + } } return divisorUnion; } @@ -1216,19 +1251,22 @@ void AssumeIntOp::inferResultRanges(ArrayRef argRanges, for (auto [index, result] : llvm::enumerate(getResults())) { Type type = result.getType(); unsigned bitWidth; - if (isa(type)) + if (isa(type)) { bitWidth = 64; - else if (auto intType = dyn_cast(type)) + } else if (auto intType = dyn_cast(type)) { bitWidth = intType.getWidth(); - else + } else { continue; + } auto [umin, umax] = getUnionedUnsignedRange(index); auto uminAp = APInt::getMinValue(bitWidth); auto umaxAp = APInt::getMaxValue(bitWidth); - if (umin) + if (umin) { uminAp = APInt(bitWidth, *umin); - if (umax) + } + if (umax) 
{ umaxAp = APInt(bitWidth, *umax); + } setResultRange(result, ConstantIntRanges::fromUnsigned(uminAp, umaxAp)); } @@ -1238,8 +1276,9 @@ void AssumeIntOp::inferResultDivisibility(ArrayRef argDivs, SetIntDivisibilityFn setResultDivs) { for (auto [index, result] : llvm::enumerate(getResults())) { Type type = result.getType(); - if (!isa(type) && !isa(type)) + if (!isa(type) && !isa(type)) { continue; + } auto udiv = getUnionedUnsignedDivisor(index); if (udiv) { setResultDivs(result, @@ -1261,8 +1300,9 @@ void AssumeIntOp::build(OpBuilder &builder, OperationState &state, ArrayRef operands, ArrayRef assumptions) { state.addOperands(operands); - for (auto operand : operands) + for (auto operand : operands) { state.addTypes({operand.getType()}); + } state.addAttribute("assumptions", ArrayAttr::get(builder.getContext(), ArrayRef(assumptions.begin(), @@ -1282,12 +1322,14 @@ LogicalResult AssumeIntOp::verify() { llvm::enumerate(allOperandAssumptions)) { auto operandAssumptions = cast(operandAssumptionsAttr); // We always allow a single row to broadcast to any requested size. - if (operandAssumptions.size() == 1) + if (operandAssumptions.size() == 1) { continue; - if (rank && *rank != operandAssumptions.size()) + } + if (rank && *rank != operandAssumptions.size()) { return emitOpError() << "expected operand #" << index << " to have " << *rank << " assumptions but it has " << operandAssumptions.size(); + } rank = operandAssumptions.size(); } @@ -1304,29 +1346,35 @@ ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand &parsedOperand = parsedOperands.back(); SmallVector operandAssumptions; - if (parser.parseOperand(parsedOperand)) + if (parser.parseOperand(parsedOperand)) { return failure(); + } // Parse as a single assumption or a list. if (failed(parser.parseOptionalLSquare())) { // Single assumption. 
IntAssumptionAttr singleAssumption; - if (parser.parseCustomAttributeWithFallback(singleAssumption)) + if (parser.parseCustomAttributeWithFallback(singleAssumption)) { return failure(); + } operandAssumptions.push_back(singleAssumption); } else { // Multiple assumptions. if (failed(parser.parseOptionalRSquare())) { if (parser.parseCommaSeparatedList([&]() { IntAssumptionAttr singleAssumption; - if (parser.parseCustomAttributeWithFallback(singleAssumption)) + if (parser.parseCustomAttributeWithFallback( + singleAssumption)) { return failure(); + } operandAssumptions.push_back(singleAssumption); return success(); - })) + })) { return failure(); - if (parser.parseRSquare()) + } + if (parser.parseRSquare()) { return failure(); + } } } @@ -1335,22 +1383,26 @@ ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) { parser.getBuilder().getArrayAttr(operandAssumptions)); return success(); - })) + })) { return failure(); + } // Parse `:` type. - if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes)) + if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes)) { return failure(); + } result.addTypes(parsedOperandTypes); if (parser.resolveOperands(parsedOperands, parsedOperandTypes, - parser.getNameLoc(), result.operands)) + parser.getNameLoc(), result.operands)) { return failure(); + } result.attributes.append( "assumptions", parser.getBuilder().getArrayAttr(allOperandAssumptions)); - if (parser.parseOptionalAttrDict(result.attributes)) + if (parser.parseOptionalAttrDict(result.attributes)) { return failure(); + } return success(); } @@ -1425,15 +1477,17 @@ ParseResult UnfoldableConstantOp::parse(OpAsmParser &parser, OperationState &state) { Attribute valueAttr; if (parser.parseOptionalAttrDict(state.attributes) || - parser.parseAttribute(valueAttr, "value", state.attributes)) + parser.parseAttribute(valueAttr, "value", state.attributes)) { return failure(); + } // If the attribute is a symbol reference, then we expect a trailing 
type. Type type; - if (!isa(valueAttr)) + if (!isa(valueAttr)) { type = cast(valueAttr).getType(); - else if (parser.parseColonType(type)) + } else if (parser.parseColonType(type)) { return failure(); + } // Add the attribute type to the list. return parser.addTypeToList(type, state.types); @@ -1444,13 +1498,15 @@ void UnfoldableConstantOp::print(OpAsmPrinter &p) { p << " "; p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); - if (op->getAttrs().size() > 1) + if (op->getAttrs().size() > 1) { p << ' '; + } p << getValue(); // If the value is a symbol reference, print a trailing type. - if (isa(getValue())) + if (isa(getValue())) { p << " : " << getType(); + } } //===----------------------------------------------------------------------===// @@ -1458,8 +1514,9 @@ void UnfoldableConstantOp::print(OpAsmPrinter &p) { //===----------------------------------------------------------------------===// bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { - if (inputs.size() != 1 || outputs.size() != 1) + if (inputs.size() != 1 || outputs.size() != 1) { return false; + } Type a = inputs.front(), b = outputs.front(); if (a == b) { // Both types are the same. @@ -1509,8 +1566,9 @@ SmallVector CastOp::getTiedResultOperandIndices() { std::optional> NumericOptionalNarrowOp::getIntegerRange() { - if (!getMinValue() || !getMaxValue()) + if (!getMinValue() || !getMaxValue()) { return {}; + } bool signExtend = isSigned(); // Note: Cannot sign extend 0 bit values. int64_t minValue = signExtend && getMinValue()->getBitWidth() > 0 @@ -1621,22 +1679,26 @@ parseFunctionArgumentList(OpAsmParser &parser, auto argPresent = parser.parseOptionalArgument( argument, /*allowType=*/true, /*allowAttrs=*/true); if (argPresent.has_value()) { - if (failed(argPresent.value())) + if (failed(argPresent.value())) { return failure(); // Present but malformed. 
- if (!arguments.empty() && arguments.back().ssaName.name.empty()) + } + if (!arguments.empty() && arguments.back().ssaName.name.empty()) { return parser.emitError(argument.ssaName.location, "expected type instead of SSA identifier"); + } } else { argument.ssaName.location = parser.getCurrentLocation(); - if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) { return parser.emitError(argument.ssaName.location, "expected SSA identifier"); + } NamedAttrList attrs; if (parser.parseType(argument.type) || parser.parseOptionalAttrDict(attrs) || - parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) { return failure(); + } argument.attrs = attrs.getDictionary(parser.getContext()); } arguments.push_back(argument); @@ -1648,52 +1710,61 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); StringAttr symVisibilityAttr; - if (failed(parseSymbolVisibility(parser, symVisibilityAttr))) + if (failed(parseSymbolVisibility(parser, symVisibilityAttr))) { return failure(); - if (symVisibilityAttr) + } + if (symVisibilityAttr) { result.addAttribute(SymbolTable::getVisibilityAttrName(), symVisibilityAttr); + } StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - result.attributes)) + result.attributes)) { return failure(); + } SmallVector arguments; - if (parseFunctionArgumentList(parser, arguments)) + if (parseFunctionArgumentList(parser, arguments)) { return failure(); + } SmallVector resultTypes; SmallVector resultAttrs; ArrayAttr tiedOperands; if (succeeded(parser.parseOptionalArrow())) { if (failed(parseTiedFunctionResultList(parser, arguments, resultTypes, - resultAttrs, tiedOperands))) + resultAttrs, tiedOperands))) { return failure(); + } } - if (tiedOperands) + if (tiedOperands) { result.addAttribute("tied_operands", tiedOperands); + 
} SmallVector argumentTypes; - for (auto argument : arguments) + for (auto argument : arguments) { argumentTypes.push_back(argument.type); + } result.addAttribute("function_type", TypeAttr::get(builder.getFunctionType( argumentTypes, resultTypes))); NamedAttrList parsedAttributes; SMLoc attributeDictLocation = parser.getCurrentLocation(); - if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) { return failure(); + } for (StringRef disallowed : { SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), StringRef("function_type"), }) { - if (parsedAttributes.get(disallowed)) + if (parsedAttributes.get(disallowed)) { return parser.emitError(attributeDictLocation, "'") << disallowed << "' is an inferred attribute and should not be specified in the " "explicit attribute dictionary"; + } } result.attributes.append(parsedAttributes); @@ -1707,10 +1778,12 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto parseResult = parser.parseOptionalRegion(*body, arguments, /*enableNameShadowing=*/false); if (parseResult.has_value()) { - if (failed(*parseResult)) + if (failed(*parseResult)) { return failure(); - if (body->empty()) + } + if (body->empty()) { return parser.emitError(loc, "expected non-empty function body"); + } } return success(); } @@ -1745,8 +1818,9 @@ bool IREE::Util::FuncOp::canDiscardOnUseEmpty() { bool IREE::Util::FuncOp::hasAnyTiedOperands() { auto tiedOperandsAttr = getTiedOperandsAttr(); - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } return llvm::any_of( tiedOperandsAttr.getAsRange(), [](IntegerAttr attr) { return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; @@ -1773,8 +1847,9 @@ void IREE::Util::FuncOp::expandSignature( expandArgument(oldIndex, argType, newArgumentTypes); size_t expandedCount = newArgumentTypes.size() - newIndex; for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { - if 
(adjustedTiedOperands[i] == oldIndex) + if (adjustedTiedOperands[i] == oldIndex) { adjustedTiedOperands[i] = newIndex; + } } newArgumentAttrs.push_back(oldArgumentAttrs[oldIndex]); newArgumentAttrs.append(expandedCount - 1, @@ -1819,8 +1894,9 @@ FunctionType CallOp::getCalleeType() { static bool areTiedOperandsEqual(ArrayAttr a, ArrayAttr b) { auto hasAnyTied = [](ArrayAttr tiedOperandsAttr) { - if (!tiedOperandsAttr) + if (!tiedOperandsAttr) { return false; + } return llvm::any_of( tiedOperandsAttr.getAsRange(), [](IntegerAttr attr) { return attr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex; @@ -1828,10 +1904,12 @@ static bool areTiedOperandsEqual(ArrayAttr a, ArrayAttr b) { }; bool hasAnyTiedA = hasAnyTied(a); bool hasAnyTiedB = hasAnyTied(b); - if (hasAnyTiedA != hasAnyTiedB) + if (hasAnyTiedA != hasAnyTiedB) { return false; - if (!a || !b) + } + if (!a || !b) { return true; + } return a == b; } @@ -1877,8 +1955,9 @@ IREE::Util::CallOp IREE::Util::CallOp::cloneAndExpand( size_t newIndex = newOperands.size(); expandOperand(oldIndex, operand, newOperands); for (size_t i = 0; i < adjustedTiedOperands.size(); ++i) { - if (adjustedTiedOperands[i] == oldIndex) + if (adjustedTiedOperands[i] == oldIndex) { adjustedTiedOperands[i] = newIndex; + } } } @@ -2026,8 +2105,9 @@ void GlobalLoadOp::getEffects( SmallVectorImpl &effects) { // HACK: mlir doesn't have symbol side effects so we have to mark as a global // read if not immutable and not in an initializer. 
- if (!isGlobalImmutable()) + if (!isGlobalImmutable()) { effects.emplace_back(MemoryEffects::Read::get()); + } } LogicalResult diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index 1ea0d7f75794..c96075d5c2ba 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -168,10 +168,12 @@ bool isValueUsableForOp(Value value, Block *block, } if (definingBlock == block) { // Defined in the same block; ensure block order. - if (isa(value)) + if (isa(value)) { return true; - if (insertionPoint == block->end()) + } + if (insertionPoint == block->end()) { return true; + } if (value.getDefiningOp()->isBeforeInBlock(&*insertionPoint)) { return true; } @@ -255,12 +257,14 @@ Operation *materializeConstant(OpBuilder &builder, Location loc, bool isPublicOrExternal(CallableOpInterface callableOp) { if (auto symbolOp = dyn_cast(callableOp.getOperation())) { - if (symbolOp.isPublic()) + if (symbolOp.isPublic()) { return true; + } } auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { return true; + } return false; } @@ -396,22 +400,27 @@ std::optional detail::getTiedResultOperandIndex(Operation *op, unsigned resultIndex) { auto storageAttr = op->getAttrOfType( IREE::Util::TiedOpInterface::getStorageAttrName()); - if (!storageAttr) + if (!storageAttr) { return std::nullopt; + } auto valueAttrs = storageAttr.getValue(); - if (valueAttrs.empty()) + if (valueAttrs.empty()) { return std::nullopt; + } if (auto tiedOp = dyn_cast(op)) { auto indexAndLength = tiedOp.getTiedResultsIndexAndLength(); - if (resultIndex < indexAndLength.first) + if (resultIndex < indexAndLength.first) { return std::nullopt; + } resultIndex -= indexAndLength.first; - if (resultIndex >= indexAndLength.second) + if (resultIndex >= indexAndLength.second) { return std::nullopt; + } } int64_t value = 
cast(valueAttrs[resultIndex]).getInt(); - if (value == IREE::Util::TiedOpInterface::kUntiedIndex) + if (value == IREE::Util::TiedOpInterface::kUntiedIndex) { return std::nullopt; + } if (auto tiedOp = dyn_cast(op)) { unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; return tiedOperandsOffset + static_cast(value); @@ -436,8 +445,9 @@ void detail::setTiedResultOperandIndex(Operation *op, unsigned resultIndex, // returned by `getTiedOperandsIndexAndLength`. unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; for (auto &index : indices) { - if (index != TiedOpInterface::kUntiedIndex) + if (index != TiedOpInterface::kUntiedIndex) { index -= tiedOperandsOffset; + } } } @@ -451,11 +461,13 @@ SmallVector detail::getTiedResultOperandIndices(Operation *op) { SmallVector indices; auto storageAttr = op->getAttrOfType( IREE::Util::TiedOpInterface::getStorageAttrName()); - if (!storageAttr) + if (!storageAttr) { return indices; + } auto valueAttrs = storageAttr.getValue(); - if (valueAttrs.empty()) + if (valueAttrs.empty()) { return indices; + } auto tiedOp = cast(op); auto resultRange = tiedOp.getTiedResultsIndexAndLength(); unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first; @@ -475,8 +487,9 @@ Value TiedOpInterface::findTiedBaseValue(Value derivedValue) { while (auto definingOp = dyn_cast_if_present( baseValue.getDefiningOp())) { auto tiedValue = definingOp.getTiedResultOperand(baseValue); - if (!tiedValue) + if (!tiedValue) { break; + } baseValue = tiedValue; } return baseValue; @@ -503,8 +516,9 @@ bool detail::isOperandTied(Operation *op, unsigned operandIndex) { SmallVector detail::getOperandTiedResults(Operation *op, unsigned operandIndex) { auto tiedOp = dyn_cast(op); - if (!tiedOp) + if (!tiedOp) { return {}; + } auto resultRange = tiedOp.getTiedResultsIndexAndLength(); SmallVector results; auto tiedIndices = tiedOp.getTiedResultOperandIndices(); @@ -518,8 +532,9 @@ SmallVector 
detail::getOperandTiedResults(Operation *op, LogicalResult detail::verifyTiedOp(IREE::Util::TiedOpInterface tiedOp) { auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices(); - if (tiedOperandIndices.empty()) + if (tiedOperandIndices.empty()) { return success(); + } auto resultRange = tiedOp.getTiedResultsIndexAndLength(); if (tiedOperandIndices.size() != resultRange.second) { return tiedOp.emitError("op results/tied operand indices mismatch"); @@ -566,8 +581,9 @@ void excludeTiedOperandAndResultIndices( // Count up the number of removed operands prior to this one. unsigned offset = 0; for (unsigned i = 0; i < tiedOperandIndex; ++i) { - if (i < excludedOperands.size() && excludedOperands[i]) + if (i < excludedOperands.size() && excludedOperands[i]) { ++offset; + } } tiedOperandIndex -= offset; @@ -591,16 +607,18 @@ Value SizeAwareTypeInterface::findSizeValue(Value resourceValue, Block *block, while (!worklist.empty()) { auto value = worklist.pop_back_val(); auto *definingOp = value.getDefiningOp(); - if (!definingOp) + if (!definingOp) { continue; + } if (auto sizeAwareOp = dyn_cast(definingOp)) { return sizeAwareOp.getResultSizeFromValue(value); } if (auto tiedOp = dyn_cast(definingOp)) { auto tiedOperand = tiedOp.getTiedResultOperand(value); - if (tiedOperand) + if (tiedOperand) { worklist.push_back(tiedOperand); + } } } @@ -663,8 +681,9 @@ std::optional findDynamicDims(Value workValue) { // {|block|, |insertionPoint|} implicitly. while (workValue) { auto workOp = workValue.getDefiningOp(); - if (!workOp) + if (!workOp) { break; + } if (auto shapeAwareOp = dyn_cast(workOp)) { return shapeAwareOp.getResultDynamicDimsFromValue(workValue); @@ -708,8 +727,9 @@ std::optional findDynamicDims(Value shapedValue, Block *block, // Look up the use-def chain: always safe, as any value we reach dominates // {|block|, |insertionPoint|} implicitly. 
auto upwardRange = findDynamicDims(shapedValue); - if (upwardRange.has_value()) + if (upwardRange.has_value()) { return upwardRange.value(); + } // Look down the use-def chain: not safe at some point because we'll move past // where {|block|, |insertionPoint|} is dominated. This is often fine for a @@ -747,8 +767,9 @@ ValueRange findDynamicDimsInList(unsigned idx, ValueRange values, } else if (isa(value.getType())) { dynamicDimCount = 1; } - if (!dynamicDimCount) + if (!dynamicDimCount) { return ValueRange{}; + } // Find where the dynamic dims start in the flattened list. unsigned offset = 0; diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h index 6b1307a55f72..ecc630c043db 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h @@ -233,10 +233,12 @@ class IntegerDivisibility { static IntegerDivisibility join(const IntegerDivisibility &lhs, const IntegerDivisibility &rhs) { - if (lhs.isUninitialized()) + if (lhs.isUninitialized()) { return rhs; - if (rhs.isUninitialized()) + } + if (rhs.isUninitialized()) { return lhs; + } return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue())); } diff --git a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp index 7cbae6aea599..c10a59b3252c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.cpp @@ -75,11 +75,13 @@ IREE::Util::transform_dialect::CreateSerializedModuleOp::apply( DiagnosedSilenceableFailure result = state.applyTransform(cast(transform)); // TODO: Support better error propagation. 
- if (result.isSilenceableFailure()) + if (result.isSilenceableFailure()) { return DiagnosedSilenceableFailure::definiteFailure(); + } // Pass through the error message from definite failures. - if (result.isDefiniteFailure()) + if (result.isDefiniteFailure()) { return result; + } } // Serialize the module as bytecode to a string. @@ -280,13 +282,15 @@ DiagnosedSilenceableFailure IREE::Util::transform_dialect::CastAndCallOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector inputs; - if (getInputs()) + if (getInputs()) { llvm::append_range(inputs, state.getPayloadValues(getInputs())); + } SetVector outputs; if (getOutputs()) { - for (auto output : state.getPayloadValues(getOutputs())) + for (auto output : state.getPayloadValues(getOutputs())) { outputs.insert(output); + } // Verify that the set of output values to be replaced is unique. if (outputs.size() != @@ -386,10 +390,11 @@ DiagnosedSilenceableFailure IREE::Util::transform_dialect::CastAndCallOp::apply( } } - if (insertAfter) + if (insertAfter) { rewriter.setInsertionPointAfter(insertionPoint); - else + } else { rewriter.setInsertionPoint(insertionPoint); + } for (auto [input, type] : llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { @@ -504,12 +509,15 @@ LogicalResult IREE::Util::transform_dialect::CastAndCallOp::verify() { void IREE::Util::transform_dialect::CastAndCallOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getInsertionPointMutable(), effects); - if (getInputs()) + if (getInputs()) { transform::onlyReadsHandle(getInputsMutable(), effects); - if (getOutputs()) + } + if (getOutputs()) { transform::onlyReadsHandle(getOutputsMutable(), effects); - if (getFunction()) + } + if (getFunction()) { transform::onlyReadsHandle(getFunctionMutable(), effects); + } transform::producesHandle(getOperation()->getOpResults(), effects); transform::modifiesPayload(effects); } diff --git 
a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index c7d64381bad3..dfa3fdeaade2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -33,8 +33,9 @@ struct DropCompilerHintsPass // undone. If LLVMGPU wants to keep the hints it should have its own // codegen op that carries the information. DropCompilerHints is meant // to drop all compiler hints. - if (keepAssumeInt) + if (keepAssumeInt) { return; + } op.replaceAllUsesWith(op.getOperands()); op.erase(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp index 9690b148edd7..641c280fd74a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FixedPointIterator.cpp @@ -61,10 +61,12 @@ FixedPointIteratorPass::FixedPointIteratorPass(OpPassManager pipeline) LogicalResult FixedPointIteratorPass::initializeOptions( StringRef options, function_ref errorHandler) { - if (failed(Pass::initializeOptions(options, errorHandler))) + if (failed(Pass::initializeOptions(options, errorHandler))) { return failure(); - if (pipeline) + } + if (pipeline) { return success(); + } // Pipelines are expected to be of the form `()`. // TODO: This was lifted from the Inliner pass. 
We should provide a parse @@ -73,12 +75,14 @@ LogicalResult FixedPointIteratorPass::initializeOptions( // See: https://github.com/llvm/llvm-project/issues/52813 StringRef pipelineSr = pipelineStr; size_t pipelineStart = pipelineSr.find_first_of('('); - if (pipelineStart == StringRef::npos || !pipelineSr.consume_back(")")) + if (pipelineStart == StringRef::npos || !pipelineSr.consume_back(")")) { return failure(); + } StringRef opName = pipelineSr.take_front(pipelineStart); OpPassManager pm(opName); - if (failed(parsePassPipeline(pipelineSr.drop_front(1 + pipelineStart), pm))) + if (failed(parsePassPipeline(pipelineSr.drop_front(1 + pipelineStart), pm))) { return failure(); + } pipeline = std::move(pm); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp index 38c4a801035f..6721c172b6fc 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp @@ -80,8 +80,9 @@ class FuseGlobalsPass : public impl::FuseGlobalsPassBase { llvm::dbgs() << ":\n"; }); auto *region = callableOp.getCallableRegion(); - if (!region) + if (!region) { continue; + } for (auto &block : *region) { DenseMap> valueStores; @@ -93,8 +94,9 @@ class FuseGlobalsPass : public impl::FuseGlobalsPassBase { storeOp.print(llvm::dbgs(), *asmState); llvm::dbgs() << "; candidate=" << global.isCandidate() << "\n"; }); - if (!global.isCandidate()) + if (!global.isCandidate()) { continue; + } valueStores[storeOp.getStoredGlobalValue()].push_back(storeOp); } for (auto valueStore : valueStores) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index 98fd7b413c14..e19758e4d9bb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ 
b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -49,8 +49,9 @@ static std::string getHoistedName(Type type) { type.print(os); } str = sanitizeSymbolName(str); - if (str.substr(str.size() - 1) == "_") + if (str.substr(str.size() - 1) == "_") { str = str.substr(0, str.size() - 1); // strip trailing _ + } return str; } @@ -107,14 +108,16 @@ class HoistIntoGlobalsPass // yet. for (auto funcOp : getOperation().getOps()) { // Ignore initializers. - if (isa(funcOp.getOperation())) + if (isa(funcOp.getOperation())) { continue; + } auto walkRes = funcOp.walk([&](Operation *iterOp) { // We only want to look at const-expr ops (non roots) since they may // have interesting escapes. Early exit here for efficiency. auto *iterInfo = constExprs.lookup(iterOp); - if (!iterInfo) + if (!iterInfo) { return WalkResult::advance(); + } for (Value constExprResult : iterOp->getResults()) { auto *resultInfo = constExprs.lookup(constExprResult); assert(resultInfo && "must have const-expr info"); @@ -129,8 +132,9 @@ class HoistIntoGlobalsPass } return WalkResult::advance(); }); - if (walkRes.wasInterrupted()) + if (walkRes.wasInterrupted()) { return signalPassFailure(); + } } // Apply any remaining RAUW cleanups. We have to do these at the cleanup @@ -167,8 +171,9 @@ class HoistIntoGlobalsPass Operation *getTopLevelOp(Operation *childOp) { auto *moduleBlock = getOperation().getBody(); auto *op = childOp; - while (op->getBlock() != moduleBlock) + while (op->getBlock() != moduleBlock) { op = op->getParentOp(); + } return op; } @@ -176,8 +181,9 @@ class HoistIntoGlobalsPass SymbolTable &moduleSymbols, const ConstExprAnalysis &constExprs) { IREE::Util::GlobalOp existingGlobal = hoistedMap.lookup(originalValue); - if (existingGlobal) + if (existingGlobal) { return success(); + } // Gather any dialect attributes we may need to preserve. 
auto *topLevelOp = getTopLevelOp(originalValue.getDefiningOp()); @@ -213,8 +219,9 @@ class HoistIntoGlobalsPass const ConstExprAnalysis::ConstValueInfo *producerInfo, HoistedValueMap &hoistedMap, IRMapping &cloneMapping, const ConstExprAnalysis &constExprs) { - if (cloneMapping.contains(producerInfo->constValue)) + if (cloneMapping.contains(producerInfo->constValue)) { return; + } // We either have a global associated already or we need to traverse // down and materialize producers. @@ -331,8 +338,9 @@ class HoistIntoGlobalsPass // longer be valid after this point. for (auto funcOp : getOperation().getOps()) { // Ignore initializers. - if (isa(funcOp.getOperation())) + if (isa(funcOp.getOperation())) { continue; + } funcOp.walk( [&](Operation *iterOp) { if (allOps.contains(iterOp) && iterOp->use_empty()) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp index 553742032256..691882708f9d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp @@ -191,8 +191,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // Walk callee arguments. for (auto [i, value] : llvm::enumerate(funcOp.getArguments())) { - if (value.use_empty()) + if (value.use_empty()) { analysis.calleeUsedArgs.reset(i); + } } // Walk all return sites in the function. @@ -327,8 +328,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // Note that we need to track unused results as an AND such that all callers // need to not use them. We'll flip the bits below so that `used = true`. for (auto [i, value] : llvm::enumerate(callOp.getResults())) { - if (!value.use_empty()) + if (!value.use_empty()) { callerUnusedResults.reset(i); + } } } if (!analysis.callOps.empty()) { @@ -376,8 +378,9 @@ static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp, // we know all callers will stop passing them. 
for (unsigned i = 0; i < resultCount; ++i) { int argIndex = analysis.passthroughResultArgs[i]; - if (argIndex == kUnassigned) + if (argIndex == kUnassigned) { continue; + } auto arg = funcOp.getArgument(argIndex); bool onlyReturnUsers = true; for (auto user : arg.getUsers()) { @@ -518,14 +521,16 @@ static bool applyFuncChanges(FuncAnalysis &analysis, } // Early out if no changes. - if (deadArgs.none() && deadResults.none()) + if (deadArgs.none() && deadResults.none()) { return false; + } // Erase dead results from all return sites. funcOp.walk([&](IREE::Util::ReturnOp returnOp) { for (int i = deadResults.size() - 1; i >= 0; --i) { - if (deadResults.test(i)) + if (deadResults.test(i)) { returnOp.getOperandsMutable().erase(i); + } } }); @@ -612,8 +617,9 @@ static bool applyCallChanges(FuncAnalysis &analysis, } // Early out if no changes. - if (deadOperands.none() && deadResults.none()) + if (deadOperands.none() && deadResults.none()) { return false; + } // Fully replace call op because we may have changed result count. // TODO(benvanik): update tied operands, arg_attrs, and res_attrs. 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp index 966ef7f02843..0d8bc5882804 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp @@ -102,8 +102,9 @@ class ImportResourcesPass } } } - if (updated) + if (updated) { op->setAttrs(attrs); + } }); LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n"); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 647183f3f85b..49107588a034 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -99,8 +99,9 @@ struct ConvertOpToUnsigned : public OpRewritePattern { LogicalResult matchAndRewrite(Signed op, PatternRewriter &rewriter) const override { - if (failed(staticallyLegalToConvertToUnsignedOp(solver, op))) + if (failed(staticallyLegalToConvertToUnsignedOp(solver, op))) { return failure(); + } rewriter.replaceOpWithNewOp(op, op->getResultTypes(), op->getOperands(), op->getAttrs()); return success(); @@ -135,15 +136,18 @@ struct ConvertUnsignedI64IndexCastProducerToIndex PatternRewriter &rewriter) const override { Type inType = origIndexOp.getIn().getType(); Type outType = origIndexOp.getOut().getType(); - if (!inType.isSignlessInteger(64) || !isa(outType)) + if (!inType.isSignlessInteger(64) || !isa(outType)) { return failure(); + } Operation *producer = origIndexOp.getIn().getDefiningOp(); - if (!producer) + if (!producer) { return failure(); + } auto producerResult = producer->getResult(0); - if (!producerResult.hasOneUse()) + if (!producerResult.hasOneUse()) { return failure(); + } auto pred = [&](Value v) -> bool { auto *result = solver.lookupState(v); @@ -163,17 
+167,20 @@ struct ConvertUnsignedI64IndexCastProducerToIndex if (!isa_and_present(producer)) + arith::RemUIOp, arith::SubIOp>(producer)) { return failure(); - if (!isOpStaticallyLegal(producer)) + } + if (!isOpStaticallyLegal(producer)) { return failure(); + } // Make modifications. rewriter.modifyOpInPlace(producer, [&]() { rewriter.setInsertionPoint(producer); for (auto &operand : producer->getOpOperands()) { - if (operand.get().getType() != inType) + if (operand.get().getType() != inType) { continue; + } Value newOperand = arith::IndexCastUIOp::create( rewriter, producer->getLoc(), outType, operand.get()); operand.set(newOperand); @@ -204,20 +211,24 @@ struct RemoveIndexCastForAssumeOfI32 PatternRewriter &rewriter) const override { llvm::SmallBitVector needNarrowing(op.getNumOperands(), false); for (auto [idx, arg] : llvm::enumerate(op.getOperands())) { - if (!arg.getType().isIndex()) + if (!arg.getType().isIndex()) { continue; + } auto castOp = arg.getDefiningOp(); - if (!castOp) + if (!castOp) { continue; + } Value castIn = castOp.getIn(); Type intType = castIn.getType(); - if (intType.getIntOrFloatBitWidth() > 32) + if (intType.getIntOrFloatBitWidth() > 32) { continue; + } needNarrowing[idx] = true; } - if (needNarrowing.none()) + if (needNarrowing.none()) { return failure(); + } SmallVector newArgs; newArgs.reserve(op.getNumOperands()); @@ -267,22 +278,27 @@ struct NarrowSCFForIvToI32 : public OpRewritePattern { Location loc = forOp.getLoc(); Value iv = forOp.getInductionVar(); Type srcType = iv.getType(); - if (!srcType.isIndex() && !srcType.isInteger(64)) + if (!srcType.isIndex() && !srcType.isInteger(64)) { return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64"); - if (!staticallyLegalToConvertToUnsigned(solver, iv)) + } + if (!staticallyLegalToConvertToUnsigned(solver, iv)) { return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative"); - if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep())) + } + if 
(!staticallyLegalToConvertToUnsigned(solver, forOp.getStep())) { return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative"); + } auto *ivState = solver.lookupState(iv); - if (ivState->getValue().getValue().smax().getActiveBits() > 31) + if (ivState->getValue().getValue().smax().getActiveBits() > 31) { return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32"); + } Type i32 = rewriter.getI32Type(); auto doCastDown = [&](Value v) -> Value { - if (srcType.isIndex()) + if (srcType.isIndex()) { return arith::IndexCastUIOp::create(rewriter, loc, i32, v); - else + } else { return arith::TruncIOp::create(rewriter, loc, i32, v); + } }; Value newLb = doCastDown(forOp.getLowerBound()); Value newUb = doCastDown(forOp.getUpperBound()); @@ -322,9 +338,10 @@ static LogicalResult getDivisibility(DataFlowSolver &solver, Operation *op, Value value, PatternRewriter &rewriter, ConstantIntDivisibility &out) { auto *div = solver.lookupState(value); - if (!div || div->getValue().isUninitialized()) + if (!div || div->getValue().isUninitialized()) { return rewriter.notifyMatchFailure(op, "divisibility could not be determined"); + } out = div->getValue().getValue(); LLVM_DEBUG(dbgs() << " * Resolved divisibility: " << out << "\n"); @@ -338,17 +355,20 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern { LogicalResult matchAndRewrite(arith::RemUIOp op, PatternRewriter &rewriter) const override { APInt rhsConstant; - if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant))) + if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant))) { return rewriter.notifyMatchFailure(op, "rhs is not constant"); + } ConstantIntDivisibility lhsDiv; - if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv))) + if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv))) { return failure(); + } uint64_t rhsValue = rhsConstant.getZExtValue(); if (rhsValue > 0 && lhsDiv.udiv() > 0) { - if (lhsDiv.udiv() % rhsValue != 0) + if (lhsDiv.udiv() % 
rhsValue != 0) { return rewriter.notifyMatchFailure(op, "rhs does not divide lhs"); + } rewriter.replaceOpWithNewOp( op, rewriter.getZeroAttr(op.getResult().getType())); @@ -397,10 +417,12 @@ struct ElideTruncOfIndexCast : public OpRewritePattern { LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { Operation *producer = truncOp.getOperand().getDefiningOp(); - if (!producer) + if (!producer) { return failure(); - if (!isa(producer)) + } + if (!isa(producer)) { return failure(); + } rewriter.replaceOpWithNewOp( truncOp, truncOp.getResult().getType(), producer->getOperand(0)); return success(); @@ -418,8 +440,9 @@ class DataFlowListener : public RewriterBase::Listener { protected: void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); - for (Value res : op->getResults()) + for (Value res : op->getResults()) { s.eraseState(res); + } } void notifyOperationModified(Operation *op) override { @@ -463,8 +486,9 @@ class OptimizeIntArithmeticPass // Populate canonicalization patterns. auto arithDialect = ctx->getOrLoadDialect(); for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) { - if (&name.getDialect() == arithDialect) + if (&name.getDialect() == arithDialect) { name.getCanonicalizationPatterns(patterns, ctx); + } } // General optimization patterns. 
@@ -513,8 +537,9 @@ class OptimizeIntArithmeticPass return signalPassFailure(); } - if (!changed) + if (!changed) { break; + } } } }; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp index 6c1d0305106c..01a9d460dcb7 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp @@ -118,11 +118,13 @@ struct FoldBlockArgumentsPattern using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(CallableOpInterface op, PatternRewriter &rewriter) const override { - if (!op.getCallableRegion()) + if (!op.getCallableRegion()) { return failure(); + } auto ®ion = *op.getCallableRegion(); - if (region.empty() || region.hasOneBlock()) + if (region.empty() || region.hasOneBlock()) { return failure(); + } // Analyze all branches in the op to compute the information we'll need to // analyze across branch sources. @@ -171,11 +173,13 @@ struct FoldBlockArgumentsPattern for (auto &block : llvm::make_range(++region.getBlocks().begin(), region.getBlocks().end())) { unsigned numArgs = block.getNumArguments(); - if (numArgs == 0) + if (numArgs == 0) { continue; + } auto blockSources = llvm::ArrayRef(blockSourceMap[&block]); - if (blockSources.size() == 0) + if (blockSources.size() == 0) { continue; + } // Which args we'll end up erasing. 
// We need to do the actual removal after we've done the remapping below @@ -263,11 +267,13 @@ struct ElideBranchOperandsPattern using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(CallableOpInterface op, PatternRewriter &rewriter) const override { - if (!op.getCallableRegion()) + if (!op.getCallableRegion()) { return failure(); + } auto ®ion = *op.getCallableRegion(); - if (region.empty()) + if (region.empty()) { return failure(); + } DominanceInfo dominance(op); // Analyze all branches to build a map of blocks to their sources. @@ -298,11 +304,13 @@ struct ElideBranchOperandsPattern for (auto &block : llvm::make_range(++region.getBlocks().begin(), region.getBlocks().end())) { unsigned numArgs = block.getNumArguments(); - if (numArgs == 0) + if (numArgs == 0) { continue; + } auto blockSources = llvm::ArrayRef(blockSourceMap[&block]); - if (blockSources.size() == 0) + if (blockSources.size() == 0) { continue; + } // Which args we'll end up erasing. // We need to do the actual removal after we've done the remapping below @@ -342,8 +350,9 @@ struct ElideBranchOperandsPattern uniformValue = nullptr; break; } - if (!uniformValue) + if (!uniformValue) { continue; + } // See if the uniform value dominates this block; if so we can use it. if (!uniformValue.getDefiningOp() || @@ -354,8 +363,9 @@ struct ElideBranchOperandsPattern elidedArgs.set(argIndex); } } - if (elidedArgs.none()) + if (elidedArgs.none()) { continue; + } // Erase all the block arguments we remapped. 
for (auto &blockSource : blockSources) { @@ -407,8 +417,9 @@ struct IndexSwitchToIfPattern : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(scf::IndexSwitchOp switchOp, PatternRewriter &rewriter) const override { - if (switchOp.getNumCases() != 1) + if (switchOp.getNumCases() != 1) { return failure(); + } Value caseValue = arith::ConstantIndexOp::create( rewriter, switchOp.getLoc(), switchOp.getCases().front()); Value isCaseValue = rewriter.createOrFold( @@ -472,16 +483,19 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // Inspect the previous op to see if it's also a switch. auto prevOp = dyn_cast_if_present(nextOp->getPrevNode()); - if (!prevOp) + if (!prevOp) { return failure(); + } // Require that the cases line up exactly. There's probably some merging // we could do in other cases but it'd be best to leave other patterns to // hoist/CSE cases/etc instead. - if (prevOp.getNumCases() != nextOp.getNumCases()) + if (prevOp.getNumCases() != nextOp.getNumCases()) { return rewriter.notifyMatchFailure(nextOp, "number of cases differ"); - if (!llvm::equal(prevOp.getCases(), nextOp.getCases())) + } + if (!llvm::equal(prevOp.getCases(), nextOp.getCases())) { return rewriter.notifyMatchFailure(nextOp, "case values differ"); + } // Create a new switch to replace nextOp that contains the same cases but // combined results from both ops. @@ -518,8 +532,9 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // values for the particular case. auto yieldA = *regionA.getOps().begin(); for (auto &op : regionA.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } // Clone each op and map its original value to the new local value. targetBuilder.clone(op, localMapping); } @@ -534,8 +549,9 @@ struct MergeIndexSwitchPattern : public OpRewritePattern { // Clone regionB into target. 
auto yieldB = *regionB.getOps().begin(); for (auto &op : regionB.getOps()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } // Clone each op and map its original value to the new local value. targetBuilder.clone(op, localMapping); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp index 15735e9a0243..bacff11291e6 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp @@ -65,8 +65,9 @@ static ExpandedGlobalMap expandResourceGlobals(Operation *rootOp, // Gather all of the resource globals in the root. for (auto ®ion : rootOp->getRegions()) { for (auto globalOp : region.getOps()) { - if (!isResourceType(globalOp.getType())) + if (!isResourceType(globalOp.getType())) { continue; + } expandedGlobals[globalOp.getName()].resourceOp = globalOp; } } @@ -127,8 +128,9 @@ static void expandType(Type type, SmallVectorImpl &newTypes) { // Expands resources in the given |types| list to (resource, size, offset, len). // This could be changed to some iterator magic to avoid the alloc. static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -221,14 +223,16 @@ static void expandSubranges(Operation *op, SymbolTable &symbolTable, static void expandRegion(Region ®ion, bool canModifyEntryBlock, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap subrangeMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. 
auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) + if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) { continue; + } // Entry blocks that we can't modify are fully handled by // MutableRegionBranchOpInterface (via wrapExpandedBlockArgFn callback). @@ -245,8 +249,9 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // Insert new arguments for each resource argument. for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto arg = block.getArgument(i); - if (!isResourceType(arg.getType())) + if (!isResourceType(arg.getType())) { continue; + } Subrange subrange; subrange.resource = arg; subrange.resourceSize = @@ -306,10 +311,12 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, // Ignore ops that are already in the map (we likely inserted them ourselves // earlier). auto resultResource = op.getSubrangeResult(); - if (!resultResource) + if (!resultResource) { return; - if (subrangeMap.count(resultResource)) + } + if (subrangeMap.count(resultResource)) { return; + } // Get the subrange of the source resource which we should have by way of the // other insertions (func/block args, etc). @@ -317,8 +324,9 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, builder.setInsertionPointAfter(op); auto sourceSubrange = consumeSubrange(op.getLoc(), op.getSubrangeResource(), subrangeMap, indexSet, builder); - if (op.getSubrangeResource() == sourceSubrange.resource) + if (op.getSubrangeResource() == sourceSubrange.resource) { return; + } // Update the subrange in the map by adding the source offset and the local // offset from the op. 
Future ops that consume subranges will reference back @@ -347,8 +355,9 @@ static void updateSubrangeOp(IREE::Util::SubrangeOpInterface op, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; @@ -386,8 +395,9 @@ static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto subrange = consumeSubrange(op.getLoc(), op.getStoredGlobalValue(), @@ -460,13 +470,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // %2 = stream.resource.subview %r[%ro] : {%rsz} -> {%rl} static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. 
OpBuilder builder(op); @@ -518,10 +530,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %0, %sz, %o, %l static void expandReturnOp(IREE::Util::ReturnOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), subrangeMap, indexSet, builder); @@ -551,8 +566,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, IndexSet &indexSet, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -568,8 +584,9 @@ static ValueRange asValueRange(ArrayRef values) { return values; } static void expandSwitchOp(mlir::cf::SwitchOp op, IndexSet &indexSet, SubrangeMap &subrangeMap) { - if (!usesResources(op)) + if (!usesResources(op)) { return; + } OpBuilder builder(op); auto caseOperands = llvm::to_vector( llvm::map_range(op.getCaseOperands(), [&](ValueRange operands) { @@ -742,8 +759,9 @@ class PropagateSubrangesPass // NOTE: the callable may be empty (like when an extern) - we still want // to process it but don't need an IndexSet. 
auto *region = callableOp.getCallableRegion(); - if (!region || region->empty()) + if (!region || region->empty()) { continue; + } IndexSet indexSet(callableOp.getLoc(), OpBuilder::atBlockBegin(®ion->front())); SubrangeMap subrangeMap; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index 62f84fdff442..cddc2e7ee12b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -50,8 +50,9 @@ static void hoistImmutableLoads(Region ®ion, auto ops = llvm::to_vector<8>(block.getOps()); for (auto &op : ops) { - if (!immutableGlobals.contains(op.getGlobalName())) + if (!immutableGlobals.contains(op.getGlobalName())) { continue; + } auto globalRef = cast(op.getGlobalAttr()); auto it = loadOps.find(globalRef); if (it == loadOps.end()) { @@ -89,8 +90,9 @@ static bool doesOpBlockMotion(Operation *op) { static SetVector getOpsThatBlockMotion(Block &block) { SetVector ops; for (auto &op : block.getOperations()) { - if (doesOpBlockMotion(&op)) + if (doesOpBlockMotion(&op)) { ops.insert(&op); + } } return ops; } @@ -100,12 +102,14 @@ static void moveOpUpInBlock(Block &block, Operation *op, // Find the earliest node that does not block op motion then move before it. mlir::Operation *earliestValidNode = op; while (earliestValidNode->getPrevNode()) { - if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) + if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) { break; + } earliestValidNode = earliestValidNode->getPrevNode(); } - if (earliestValidNode != op) + if (earliestValidNode != op) { op->moveBefore(earliestValidNode); + } } static void @@ -114,12 +118,14 @@ moveOpDownInBlock(Block &block, Operation *op, // Find the latest node that does not block op motion then move after it. 
mlir::Operation *latestValidNode = op; while (latestValidNode->getNextNode()) { - if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) + if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) { break; + } latestValidNode = latestValidNode->getNextNode(); } - if (latestValidNode != op) + if (latestValidNode != op) { op->moveAfter(latestValidNode); + } } // Optimizes the load/store ops for each given bucket. @@ -176,8 +182,9 @@ optimizeBuckets(Block &block, didRemoveAny = true; } } - if (ops.empty()) + if (ops.empty()) { continue; + } if (auto loadOp = dyn_cast(ops.front())) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp index 7d0a940e5ed7..1241ee0e43a5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/LinearScan/LiveIntervals.cpp @@ -36,8 +36,9 @@ LogicalResult LiveIntervals::annotateIR(IREE::VM::FuncOp funcOp) { // Annotate each block with its instruction range. 
for (auto *block : liveIntervals.getBlockOrder()) { - if (block->empty()) + if (block->empty()) { continue; + } uint32_t blockStart = liveIntervals.getInstructionIndex(&block->front()); uint32_t blockEnd = liveIntervals.getInstructionIndex(&block->back()); @@ -55,8 +56,9 @@ LogicalResult LiveIntervals::annotateIR(IREE::VM::FuncOp funcOp) { uint32_t opIndex = liveIntervals.getInstructionIndex(&op); op.setAttr("op_index", builder.getI32IntegerAttr(opIndex)); - if (op.getNumResults() == 0) + if (op.getNumResults() == 0) { continue; + } SmallVector intervalStrs; for (auto result : op.getResults()) { @@ -141,8 +143,9 @@ LogicalResult LiveIntervals::build(IREE::VM::FuncOp funcOp) { const LiveInterval *LiveIntervals::getInterval(Value value) const { auto it = valueToInterval_.find(value); - if (it == valueToInterval_.end()) + if (it == valueToInterval_.end()) { return nullptr; + } return &intervals_[it->second]; } @@ -168,8 +171,9 @@ void LiveIntervals::sortBlocksInDominanceOrder(IREE::VM::FuncOp funcOp) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -201,8 +205,9 @@ void LiveIntervals::buildIntervals(ValueLiveness &liveness) { for (auto *block : blockOrder_) { // Process block arguments. for (auto blockArg : block->getArguments()) { - if (valueToInterval_.count(blockArg)) + if (valueToInterval_.count(blockArg)) { continue; + } // Block arguments are "defined" at the start of the block. // We use the first op's index as the start. 
@@ -228,8 +233,9 @@ void LiveIntervals::buildIntervals(ValueLiveness &liveness) { uint32_t opIndex = opToIndex_[&op]; for (auto result : op.getResults()) { - if (valueToInterval_.count(result)) + if (valueToInterval_.count(result)) { continue; + } uint32_t start = opIndex; uint32_t end = findLastUse(result, liveness); diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp index 8cbb34d1b8d0..605eca8bfad1 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/OrdinalAnalysis.cpp @@ -58,8 +58,9 @@ OrdinalAnalysis::OrdinalAnalysis(IREE::VM::ModuleOp moduleOp) { int globalBytes = 0; for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) { size_t storageSize = sizeGlobalOps.index(); - if (sizeGlobalOps.value().empty()) + if (sizeGlobalOps.value().empty()) { continue; + } nextGlobalBytesOrdinal = llvm::alignTo(nextGlobalBytesOrdinal, storageSize); for (auto &globalOp : sizeGlobalOps.value()) { ordinals_[globalOp] = nextGlobalBytesOrdinal; diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp index 49c8510917c8..717a2d648724 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp @@ -62,8 +62,9 @@ LogicalResult RegisterAllocation::annotateIR(IREE::VM::FuncOp funcOp) { registerAllocation.remapSuccessorRegisters(&op, i); auto succOperands = branchOp.getSuccessorOperands(i).getForwardedOperands(); - if (succOperands.empty()) + if (succOperands.empty()) { continue; + } unsigned baseIdx = succOperands.getBeginOperandIndex(); // remapSuccessorRegisters only returns pairs where src != dst. // For display, we need ALL operands with correct MOVE bits. 
@@ -114,8 +115,9 @@ LogicalResult RegisterAllocation::annotateIR(IREE::VM::FuncOp funcOp) { op.setAttr("operand_registers", getStrArrayAttr(builder, operandRegStrs)); } - if (op.getNumResults() == 0) + if (op.getNumResults() == 0) { continue; + } SmallVector regStrs; regStrs.reserve(op.getNumResults()); for (auto result : op.getResults()) { @@ -167,8 +169,9 @@ sortBlocksInDominanceOrder(IREE::VM::FuncOp funcOp) { } llvm::SmallSetVector markedBlocks; std::function visit = [&](Block *block) { - if (markedBlocks.count(block) > 0) + if (markedBlocks.count(block) > 0) { return; + } for (auto *childBlock : dominanceInfo.getNode(block)->children()) { visit(childBlock->getBlock()); } @@ -369,15 +372,18 @@ LogicalResult RegisterAllocation::recalculate(IREE::VM::FuncOp funcOp) { llvm::DenseMap coalesceSource; auto recordCoalesceCandidate = [&](Value dest, Value src) { - if (dest.getType() != src.getType()) + if (dest.getType() != src.getType()) { return; + } auto srcInterval = liveIntervals.getInterval(src); auto destInterval = liveIntervals.getInterval(dest); - if (!srcInterval || !destInterval) + if (!srcInterval || !destInterval) { return; + } // Only coalesce if intervals meet exactly (hand-off). - if (srcInterval->end != destInterval->start) + if (srcInterval->end != destInterval->start) { return; + } coalesceSource[dest] = src; }; @@ -386,17 +392,20 @@ LogicalResult RegisterAllocation::recalculate(IREE::VM::FuncOp funcOp) { // Block arguments can coalesce with branch operands from predecessors. 
for (auto *pred : block->getPredecessors()) { auto branchOp = dyn_cast(pred->getTerminator()); - if (!branchOp) + if (!branchOp) { continue; + } for (unsigned succIdx = 0; succIdx < pred->getTerminator()->getNumSuccessors(); ++succIdx) { - if (pred->getTerminator()->getSuccessor(succIdx) != block) + if (pred->getTerminator()->getSuccessor(succIdx) != block) { continue; + } OperandRange operands = branchOp.getSuccessorOperands(succIdx).getForwardedOperands(); for (auto [idx, operand] : llvm::enumerate(operands)) { - if (idx >= block->getNumArguments()) + if (idx >= block->getNumArguments()) { break; + } recordCoalesceCandidate(block->getArgument(idx), operand); } } @@ -481,8 +490,9 @@ void RegisterAllocation::computeElidableDiscards(IREE::VM::FuncOp funcOp) { for (auto &block : funcOp.getBlocks()) { for (auto &op : block.getOperations()) { auto discardOp = dyn_cast(&op); - if (!discardOp) + if (!discardOp) { continue; + } SmallVector operandElidability; for (Value ref : discardOp.getRefs()) { @@ -510,8 +520,9 @@ void RegisterAllocation::computeElidableDiscards(IREE::VM::FuncOp funcOp) { break; } } - if (hasPrecedingMoveUse) + if (hasPrecedingMoveUse) { break; + } } operandElidability.push_back(hasPrecedingMoveUse); } @@ -633,18 +644,21 @@ struct FeedbackArcSet { SmallVector outEdges; outEdges.reserve(node->outdegree); for (auto &edge : edges) { - if (edge.sink == node) + if (edge.sink == node) { inEdges.push_back(edge); - if (edge.source == node) + } + if (edge.source == node) { outEdges.push_back(edge); + } } bool collectInEdges = node->indegree <= node->outdegree; bool collectOutEdges = !collectInEdges; SmallVector results; for (auto &edge : inEdges) { - if (edge.source == node) + if (edge.source == node) { continue; + } if (collectInEdges) { results.push_back({edge.source->id, edge.sink->id}); } @@ -654,8 +668,9 @@ struct FeedbackArcSet { assignBucket(edge.source); } for (auto &edge : outEdges) { - if (edge.sink == node) + if (edge.sink == node) { continue; + } 
if (collectOutEdges) { results.push_back({edge.source->id, edge.sink->id}); } @@ -681,11 +696,13 @@ struct FeedbackArcSet { ends.erase(ends.begin()); removeNode(node); } - if (remainingNodes.empty()) + if (remainingNodes.empty()) { break; + } for (ssize_t i = buckets.size() - 1; i >= 0; --i) { - if (buckets[i].empty()) + if (buckets[i].empty()) { continue; + } auto *bucket = buckets[i].front(); buckets[i].erase(buckets[i].begin()); auto feedbackEdges = removeNode(bucket); @@ -715,11 +732,13 @@ struct FeedbackArcSet { llvm::SmallSetVector unmarkedNodes = acyclicNodes; llvm::SmallSetVector markedNodes; std::function visit = [&](NodeID node) { - if (markedNodes.count(node) > 0) + if (markedNodes.count(node) > 0) { return; + } for (auto &edge : acyclicEdges) { - if (edge.first != node) + if (edge.first != node) { continue; + } visit(edge.second); } markedNodes.insert(node); @@ -729,8 +748,9 @@ struct FeedbackArcSet { } for (auto node : markedNodes.takeVector()) { for (auto &edge : acyclicEdges) { - if (edge.first != node) + if (edge.first != node) { continue; + } result.acyclicEdges.push_back({edge.first, edge.second}); } } diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h index 1af00b0d8775..919a04100d2d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h @@ -227,8 +227,9 @@ class RegisterAllocation : public VMRegisterAllocation { // operands have already been released via MOVE on preceding operations. 
bool isDiscardElidable(Operation *op) const { auto it = discardOperandElidability_.find(op); - if (it == discardOperandElidability_.end()) + if (it == discardOperandElidability_.end()) { return false; + } return llvm::all_of(it->second, [](bool b) { return b; }); } @@ -238,10 +239,12 @@ class RegisterAllocation : public VMRegisterAllocation { bool isDiscardOperandElidable(Operation *op, unsigned operandIndex) const override { auto it = discardOperandElidability_.find(op); - if (it == discardOperandElidability_.end()) + if (it == discardOperandElidability_.end()) { return false; - if (operandIndex >= it->second.size()) + } + if (operandIndex >= it->second.size()) { return false; + } return it->second[operandIndex]; } diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp index 5385554ca8bd..ff7aeab3e89d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp @@ -260,14 +260,17 @@ LogicalResult ValueLiveness::computeLiveIntervals(IREE::VM::FuncOp funcOp) { // Handle values entering the block and dying within. for (auto value : blockSets.liveIn) { - if (blockSets.liveOut.count(value)) + if (blockSets.liveOut.count(value)) { continue; + } Operation *lastUse = &block.front(); for (auto &use : value.getUses()) { - if (use.getOwner()->getBlock() != &block) + if (use.getOwner()->getBlock() != &block) { continue; - if (lastUse == use.getOwner()) + } + if (lastUse == use.getOwner()) { continue; + } if (lastUse->isBeforeInBlock(use.getOwner())) { lastUse = use.getOwner(); } @@ -277,14 +280,16 @@ LogicalResult ValueLiveness::computeLiveIntervals(IREE::VM::FuncOp funcOp) { // Handle values defined within the block and not escaping. 
for (auto value : blockSets.defined) { - if (blockSets.liveOut.count(value)) + if (blockSets.liveOut.count(value)) { continue; + } Operation *firstUse = value.getDefiningOp() ? value.getDefiningOp() : &block.front(); Operation *lastUse = firstUse; for (auto &use : value.getUses()) { - if (use.getOwner()->getBlock() != &block) + if (use.getOwner()->getBlock() != &block) { continue; + } if (lastUse->isBeforeInBlock(use.getOwner())) { lastUse = use.getOwner(); } @@ -386,8 +391,9 @@ bool ValueLiveness::isLastRealValueUse(Value value, Operation *useOp, break; } } - if (valueIsSuccessorOperand) + if (valueIsSuccessorOperand) { break; + } } } // Check if the value escapes to any successor blocks. diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index 678c8c41b2ce..027e6f8d222f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -100,8 +100,9 @@ struct CmpI32OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isInteger(32)) + if (!adaptor.getLhs().getType().isInteger(32)) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpIPredicate::eq: @@ -155,8 +156,9 @@ struct CmpI64OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isInteger(64)) + if (!adaptor.getLhs().getType().isInteger(64)) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpIPredicate::eq: @@ -210,8 +212,9 @@ struct CmpF32OpConversion : public OpConversionPattern { 
LogicalResult matchAndRewrite(arith::CmpFOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isF32()) + if (!adaptor.getLhs().getType().isF32()) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: // 0 @@ -300,8 +303,9 @@ struct CmpF64OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(arith::CmpFOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getLhs().getType().isF64()) + if (!adaptor.getLhs().getType().isF64()) { return failure(); + } auto returnType = rewriter.getIntegerType(32); switch (srcOp.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: // 0 @@ -623,13 +627,15 @@ struct ExtendFOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto srcType = dyn_cast_if_present(srcOp.getIn().getType()); auto resultType = dyn_cast_if_present(srcOp.getType()); - if (!srcType || !resultType) + if (!srcType || !resultType) { return failure(); + } auto dstType = getTypeConverter()->convertType(resultType); auto srcBits = srcType.getWidth(); auto resultBits = resultType.getWidth(); - if (srcBits != 32 || resultBits != 64) + if (srcBits != 32 || resultBits != 64) { return rewriter.notifyMatchFailure(srcOp, "unsupported extf conversion"); + } rewriter.replaceOpWithNewOp(srcOp, dstType, adaptor.getIn()); return success(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index 20bfb1bd9df0..95c219e49430 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -63,8 +63,9 @@ LogicalResult ImportTable::build(Operation *rootOp, std::optional ImportTable::find(StringRef symbolName) { auto it = symbols.find(symbolName); - 
if (it == symbols.end()) + if (it == symbols.end()) { return std::nullopt; + } return it->second; } @@ -110,8 +111,9 @@ LogicalResult appendImportModule(StringRef importModuleSrc, Value castToImportType(Value value, Type targetType, OpBuilder &builder) { auto sourceType = value.getType(); - if (sourceType == targetType) + if (sourceType == targetType) { return value; + } bool sourceIsInteger = isa(sourceType); // Allow bitcast between same width float/int types. This is used for @@ -202,8 +204,9 @@ std::optional> rewriteAttrToOperands(Location loc, for (auto elementAttr : arrayAttr) { auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, inputType, builder); - if (!flattenedValues) + if (!flattenedValues) { return std::nullopt; + } allValues.append(flattenedValues->begin(), flattenedValues->end()); } return allValues; @@ -226,8 +229,9 @@ std::optional> rewriteAttrToOperands(Location loc, int ordinal = 0; LogicalResult walkStatus = conversionInterface->walkAttributeStorage( attrValue, [&](Attribute elementAttr) { - if (anyFailed) + if (anyFailed) { return; + } auto elementType = tupleTypes[ordinal++]; auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, elementType, builder); @@ -237,14 +241,16 @@ std::optional> rewriteAttrToOperands(Location loc, } allValues.append(flattenedValues->begin(), flattenedValues->end()); }); - if (failed(walkStatus)) + if (failed(walkStatus)) { return std::nullopt; + } } else { // Custom dialect type maps into zero or more input types (ala arrays). 
LogicalResult walkStatus = conversionInterface->walkAttributeStorage( attrValue, [&](Attribute elementAttr) { - if (anyFailed) + if (anyFailed) { return; + } auto flattenedValues = rewriteAttrToOperands(loc, elementAttr, inputType, builder); if (!flattenedValues) { @@ -253,11 +259,13 @@ std::optional> rewriteAttrToOperands(Location loc, } allValues.append(flattenedValues->begin(), flattenedValues->end()); }); - if (failed(walkStatus)) + if (failed(walkStatus)) { return std::nullopt; + } } - if (anyFailed) + if (anyFailed) { return std::nullopt; + } return allValues; } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h index c5d557f5b42d..5093a686b7c1 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h @@ -102,8 +102,9 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, if (auto attrValue = op->getAttr(inputName)) { auto flattenedAttrs = detail::rewriteAttrToOperands( op.getLoc(), attrValue, inputType, builder); - if (!flattenedAttrs) + if (!flattenedAttrs) { return std::nullopt; + } state.addOperands(*flattenedAttrs); if (importOp.isFuncArgumentVariadic(input.index())) { segmentSizes.push_back(flattenedAttrs->size() / @@ -162,8 +163,9 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, for (auto [result, targetType] : llvm::zip_equal(callOp->getResults(), operation->getResultTypes())) { targetType = typeConverter.convertType(targetType); - if (!targetType) + if (!targetType) { return std::nullopt; + } results.push_back(castFromImportType(result, targetType, builder)); } return results; @@ -185,8 +187,9 @@ class VMImportOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto results = rewriteToCall(op, adaptor, importOp, *this->getTypeConverter(), rewriter); - if (!results.has_value()) + if 
(!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp index 1a16eacffe91..e74176e95b84 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/Patterns.cpp @@ -30,8 +30,9 @@ class UnaryArithmeticOpConversion : public OpConversionPattern { matchAndRewrite(SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(benvanik): support vectors. - if (isa(srcOp.getResult().getType())) + if (isa(srcOp.getResult().getType())) { return failure(); + } switch (adaptor.getOperand().getType().getIntOrFloatBitWidth()) { case 32: @@ -57,8 +58,9 @@ class BinaryArithmeticOpConversion : public OpConversionPattern { matchAndRewrite(SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(benvanik): support vectors. - if (isa(srcOp.getResult().getType())) + if (isa(srcOp.getResult().getType())) { return failure(); + } switch (adaptor.getLhs().getType().getIntOrFloatBitWidth()) { case 32: diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp index efcea8fba614..e651bc9b995b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp @@ -115,16 +115,18 @@ struct FuncOpConversion : public OpConversionPattern { matchAndRewrite(func::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by import-specific conversion. - if (srcOp.isExternal()) + if (srcOp.isExternal()) { return failure(); + } // Convert function signature. 
TypeConverter::SignatureConversion signatureConversion( srcOp.getNumArguments()); auto newFuncType = convertFuncSignature(srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(newFuncType)) + if (failed(newFuncType)) { return failure(); + } // Create new function with converted argument and result types. // Note that attributes are dropped. Consider preserving some if needed. @@ -189,8 +191,9 @@ struct ExternalFuncOpConversion : public OpConversionPattern { matchAndRewrite(func::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by internal-specific conversion. - if (!srcOp.isExternal()) + if (!srcOp.isExternal()) { return failure(); + } // If the user declared an intended signature then we can use that instead // of running conversion ourselves. This can be used in cases where the @@ -210,8 +213,9 @@ struct ExternalFuncOpConversion : public OpConversionPattern { srcOp.getNumArguments()); auto convertedSignature = convertFuncSignature( srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(convertedSignature)) + if (failed(convertedSignature)) { return failure(); + } newSignature = *convertedSignature; } @@ -354,8 +358,9 @@ struct CallOpConversion : public OpConversionPattern { rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, resultTypes, importTable, rewriter); - if (failed(fallbackResults)) + if (failed(fallbackResults)) { return failure(); + } IREE::VM::BranchOp::create(rewriter, loc, exitBlock, *fallbackResults); return exitResults; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp index f7f6ab226c34..4a6cd3f941ad 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp +++ 
b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertBufferOps.cpp @@ -20,15 +20,17 @@ namespace mlir::iree_compiler { namespace { static Value castToI64(Value value, OpBuilder &builder) { - if (value.getType().isInteger(64)) + if (value.getType().isInteger(64)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI64Type(), value); } static Value castToIndex(Value value, OpBuilder &builder) { - if (value.getType().isIndex()) + if (value.getType().isIndex()) { return value; + } return builder.createOrFold( value.getLoc(), builder.getIndexType(), value); } @@ -161,8 +163,9 @@ struct BufferCompareOpConversion static Value unscaleOffset(Location loc, Value offset, int64_t scale, OpBuilder &builder) { - if (scale == 1) + if (scale == 1) { return offset; + } return builder.createOrFold( loc, offset.getType(), offset, IREE::VM::ConstI64Op::create(builder, loc, scale)); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp index 1aefdfab5a80..1aec588a7b53 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertListOps.cpp @@ -20,15 +20,17 @@ namespace mlir::iree_compiler { namespace { static Value castToI32(Value value, OpBuilder &builder) { - if (value.getType().isInteger(32)) + if (value.getType().isInteger(32)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI32Type(), value); } static Value castToIndex(Value value, OpBuilder &builder) { - if (value.getType().isIndex()) + if (value.getType().isIndex()) { return value; + } return builder.createOrFold( value.getLoc(), builder.getIndexType(), value); } @@ -200,8 +202,9 @@ void populateUtilListToVMPatterns(MLIRContext *context, } else { elementType = typeConverter.convertType(type.getElementType()); } - if (!elementType) + if (!elementType) { 
return std::nullopt; + } return IREE::VM::RefType::get(IREE::VM::ListType::get(elementType)); }); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp index b9b2803fa9f7..65f723cf6513 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp @@ -89,16 +89,18 @@ class FuncOpConversion : public OpConversionPattern { matchAndRewrite(IREE::Util::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by import-specific conversion. - if (srcOp.isExternal()) + if (srcOp.isExternal()) { return failure(); + } // Convert function signature. TypeConverter::SignatureConversion signatureConversion( srcOp.getNumArguments()); auto newFuncType = convertFuncSignature(srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(newFuncType)) + if (failed(newFuncType)) { return failure(); + } // Create new function with converted argument and result types. // Note that attributes are dropped. Consider preserving some if needed. @@ -165,8 +167,9 @@ class ExternalFuncOpConversion matchAndRewrite(IREE::Util::FuncOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Handled by internal-specific conversion. - if (!srcOp.isExternal()) + if (!srcOp.isExternal()) { return failure(); + } // If the user declared an intended signature then we can use that instead // of running conversion ourselves. 
This can be used in cases where the @@ -186,8 +189,9 @@ class ExternalFuncOpConversion srcOp.getNumArguments()); auto convertedSignature = convertFuncSignature( srcOp, *getTypeConverter(), signatureConversion, rewriter); - if (failed(convertedSignature)) + if (failed(convertedSignature)) { return failure(); + } newSignature = *convertedSignature; } @@ -328,8 +332,9 @@ struct CallOpConversion : public OpConversionPattern { rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, resultTypes, rewriter); - if (failed(fallbackResults)) + if (failed(fallbackResults)) { return failure(); + } IREE::VM::BranchOp::create(rewriter, loc, exitBlock, *fallbackResults); return exitResults; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 109b0ef21e1b..e7e456b32313 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -131,8 +131,9 @@ LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp, } if (failed( - funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp))) + funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp))) { return funcOp.emitError() << "unable to update symbol name in module"; + } return success(); } @@ -1186,8 +1187,9 @@ createModuleStructure(IREE::VM::ModuleOp moduleOp, rodataOp.getAlignment() ? 
static_cast(rodataOp.getAlignment().value()) : 0; - if (alignment == 0) + if (alignment == 0) { alignment = kDefaultRodataAlignment; + } std::string bufferName = moduleOp.getName().str() + "_" + rodataOp.getName().str(); @@ -1196,8 +1198,9 @@ createModuleStructure(IREE::VM::ModuleOp moduleOp, ") static const uint8_t " + bufferName + "[] = {"; size_t index = 0; for (char value : byteBuffer) { - if (index++ > 0) + if (index++ > 0) { stmt += ", "; + } stmt += std::to_string( static_cast(static_cast(value))); } @@ -2785,11 +2788,13 @@ class CallOpConversion : public EmitCConversionPattern { IREE::VM::ImportOp importOp = lookupSymbolRef(op.getOperation(), "callee"); - if (!funcOp && !importOp) + if (!funcOp && !importOp) { return op.emitError() << "lookup of callee failed"; + } - if (funcOp && importOp) + if (funcOp && importOp) { return op.emitError() << "lookup of callee ambiguous"; + } const bool isImported = importOp != nullptr; @@ -2881,8 +2886,9 @@ class CallOpConversion : public EmitCConversionPattern { return failure(); } - if (!funcName.has_value()) + if (!funcName.has_value()) { return op->emitError() << "Couldn't build name to imported function"; + } auto callee = moduleOp.lookupSymbol(funcName.value()); if (callee == nullptr) { @@ -3823,8 +3829,9 @@ class BranchTableOpConversion { OpBuilder::InsertionGuard guard(rewriter); auto *nextBlock = rewriter.getInsertionBlock()->getNextNode(); - for (size_t i = 0; i < caseDestinations.size(); ++i) + for (size_t i = 0; i < caseDestinations.size(); ++i) { caseBlocks.push_back(rewriter.createBlock(nextBlock)); + } caseBlocks.push_back(rewriter.createBlock(nextBlock)); // default } IREE::VM::BranchOp::create(rewriter, op.getLoc(), caseBlocks.front()); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp index 4f36b525ab11..a4bc38fa0a36 100644 --- 
a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCBuilders.cpp @@ -261,8 +261,9 @@ void structDefinition(OpBuilder builder, Location location, std::string decl = std::string("struct ") + structName.str() + " {"; for (auto &field : fields) { decl += field.type + " " + field.name; - if (field.isArray()) + if (field.isArray()) { decl += "[" + std::to_string(field.arraySize.value()) + "]"; + } decl += ";"; } decl += "};"; diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp index dfb76203a164..378e77ce9c9a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp @@ -74,8 +74,9 @@ struct VMInlinerInterface : public DialectInlinerInterface { if (auto inliningPolicy = callable->getAttrOfType( "inlining_policy")) { - if (!inliningPolicy.isLegalToInline(call, callable)) + if (!inliningPolicy.isLegalToInline(call, callable)) { return false; + } } // Sure! 
return true; @@ -259,8 +260,9 @@ void VMDialect::printType(Type type, DialectAsmPrinter &os) const { Operation *VMDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { auto typedValue = dyn_cast(value); - if (!typedValue) + if (!typedValue) { return nullptr; + } if (ConstI32Op::isBuildableWith(typedValue, type)) { auto convertedValue = ConstI32Op::convertConstValue(typedValue); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index 381fc19e8bb4..93454ff077cd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -44,8 +44,9 @@ Attribute oneOfType(Type type) { } else if (isa(type)) { auto vtType = cast(type); auto element = oneOfType(vtType.getElementType()); - if (!element) + if (!element) { return {}; + } return DenseElementsAttr::get(vtType, element); } return {}; @@ -64,8 +65,9 @@ struct DropEmptyInitializerOp : public OpRewritePattern { using Base::Base; LogicalResult matchAndRewrite(InitializerOp op, PatternRewriter &rewriter) const override { - if (op.getBody().getBlocks().size() != 1) + if (op.getBody().getBlocks().size() != 1) { return failure(); + } auto &block = op.getBody().front(); if (block.empty() || isa(block.front())) { rewriter.eraseOp(op); @@ -85,12 +87,14 @@ struct InlineConstGlobalInitializer : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector deadOps; op.walk([&](Operation *op) { - if (!isGlobalStoreOp(op)) + if (!isGlobalStoreOp(op)) { return; + } auto value = op->getOperand(0); Attribute valueAttr; - if (!matchPattern(value, m_Constant(&valueAttr))) + if (!matchPattern(value, m_Constant(&valueAttr))) { return; + } auto globalRefAttr = op->getAttrOfType("global"); assert(globalRefAttr); auto globalOp = @@ -100,10 +104,12 @@ struct InlineConstGlobalInitializer : public OpRewritePattern { globalOp, [&]() { 
globalOp.setGlobalInitialValue(valueAttr); }); deadOps.push_back(op); }); - if (deadOps.empty()) + if (deadOps.empty()) { return failure(); - for (auto deadOp : deadOps) + } + for (auto deadOp : deadOps) { rewriter.eraseOp(deadOp); + } return success(); } @@ -135,14 +141,17 @@ struct DropDefaultConstGlobalOpInitializer : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { - if (!op.getInitialValue().has_value()) + if (!op.getInitialValue().has_value()) { return failure(); + } if (auto value = dyn_cast(op.getInitialValueAttr())) { - if (value.getValue() != 0) + if (value.getValue() != 0) { return failure(); + } } else if (auto value = dyn_cast(op.getInitialValueAttr())) { - if (value.getValue().isNonZero()) + if (value.getValue().isNonZero()) { return failure(); + } } auto visibility = op.getVisibility(); auto newOp = rewriter.replaceOpWithNewOp( @@ -488,8 +497,9 @@ static Attribute constFoldUnaryOp(Attribute rawOperand, dyn_cast_if_present(rawOperand)) { auto elementResult = constFoldUnaryOp( {operand.getSplatValue()}, calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(operand.getType(), elementResult); } else if (auto operand = dyn_cast_if_present(rawOperand)) { return cast(operand).mapValues( @@ -511,8 +521,9 @@ constFoldFloatUnaryOp(Attribute rawOperand, dyn_cast_if_present(rawOperand)) { auto elementResult = constFoldFloatUnaryOp({operand.getSplatValue()}, calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(operand.getType(), elementResult); } else if (auto operand = dyn_cast_if_present(rawOperand)) { return cast(operand).mapValues( @@ -535,33 +546,38 @@ static TypedAttr constFoldBinaryOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { 
return {}; + } return AttrElementT::get(lhs.getType(), calculate(lhs.getValue(), rhs.getValue())); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { // TODO(benvanik): handle splat/otherwise. auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto elementResult = constFoldBinaryOp( lhs.getSplatValue(), rhs.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(lhs.getType(), elementResult); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto lhsIt = lhs.getValues().begin(); auto rhsIt = rhs.getValues().begin(); SmallVector resultAttrs(lhs.getNumElements()); for (int64_t i = 0; i < lhs.getNumElements(); ++i) { resultAttrs[i] = constFoldBinaryOp(*lhsIt, *rhsIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++lhsIt; ++rhsIt; } @@ -597,8 +613,9 @@ static Attribute constFoldTernaryOp(Attribute rawA, Attribute rawB, auto elementResult = constFoldTernaryOp( a.getSplatValue(), b.getSplatValue(), c.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } return DenseElementsAttr::get(a.getType(), elementResult); } else if (auto a = dyn_cast_if_present(rawA)) { auto b = dyn_cast_if_present(rawB); @@ -613,8 +630,9 @@ static Attribute constFoldTernaryOp(Attribute rawA, Attribute rawB, for (int64_t i = 0; i < a.getNumElements(); ++i) { resultAttrs[i] = constFoldTernaryOp(*aIt, *bIt, *cIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++aIt; ++bIt; ++cIt; @@ -669,14 +687,16 @@ static OpFoldResult foldAddOp(ADD op, Attribute lhs, Attribute rhs) { if (auto subOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { // t = vm.sub x, y // = vm.add t, z - if (subOp.getRhs() == 
op.getRhs()) // y == z: - return subOp.getLhs(); // (x - y) + y = x + if (subOp.getRhs() == op.getRhs()) { // y == z: + return subOp.getLhs(); // (x - y) + y = x + } } else if (auto subOp = dyn_cast_if_present(op.getRhs().getDefiningOp())) { // t = vm.sub x, y // = vm.add z, t - if (subOp.getRhs() == op.getLhs()) // y == z: - return subOp.getLhs(); // y + (x - y) = x + if (subOp.getRhs() == op.getLhs()) { // y == z: + return subOp.getLhs(); // y + (x - y) = x + } } return constFoldBinaryOp( lhs, rhs, @@ -716,10 +736,12 @@ static OpFoldResult foldSubOp(SUB op, Attribute lhs, Attribute rhs) { if (auto addOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { // t = vm.add x, y // = vm.sub t, z - if (addOp.getLhs() == op.getRhs()) // x == z: - return addOp.getRhs(); // (x + y) - x = y - if (addOp.getRhs() == op.getRhs()) // y == z: - return addOp.getLhs(); // (x + y) - y = x + if (addOp.getLhs() == op.getRhs()) { // x == z: + return addOp.getRhs(); // (x + y) - x = y + } + if (addOp.getRhs() == op.getRhs()) { // y == z: + return addOp.getLhs(); // (x + y) - y = x + } } return constFoldBinaryOp( lhs, rhs, @@ -764,8 +786,9 @@ struct FoldConstantMulOperand : public OpRewritePattern { LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { AttrElementT c1, c2; - if (!matchPattern(op.getRhs(), m_Constant(&c1))) + if (!matchPattern(op.getRhs(), m_Constant(&c1))) { return failure(); + } if (auto mulOp = dyn_cast_if_present(op.getLhs().getDefiningOp())) { if (matchPattern(mulOp.getRhs(), m_Constant(&c2))) { auto c = rewriter.createOrFold( @@ -980,8 +1003,9 @@ OpFoldResult AbsI64Op::fold(FoldAdaptor operands) { } OpFoldResult MinI32SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smin(lhs, rhs); @@ -989,8 +1013,9 @@ OpFoldResult MinI32SOp::fold(FoldAdaptor 
operands) { } OpFoldResult MinI64SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smin(lhs, rhs); @@ -998,8 +1023,9 @@ OpFoldResult MinI64SOp::fold(FoldAdaptor operands) { } OpFoldResult MinI32UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umin(lhs, rhs); @@ -1007,8 +1033,9 @@ OpFoldResult MinI32UOp::fold(FoldAdaptor operands) { } OpFoldResult MinI64UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umin(lhs, rhs); @@ -1016,8 +1043,9 @@ OpFoldResult MinI64UOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI32SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smax(lhs, rhs); @@ -1025,8 +1053,9 @@ OpFoldResult MaxI32SOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI64SOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::smax(lhs, rhs); @@ -1034,8 +1063,9 @@ OpFoldResult MaxI64SOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI32UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const 
APInt &rhs) { return llvm::APIntOps::umax(lhs, rhs); @@ -1043,8 +1073,9 @@ OpFoldResult MaxI32UOp::fold(FoldAdaptor operands) { } OpFoldResult MaxI64UOp::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp(operands.getLhs(), operands.getRhs(), [](const APInt &lhs, const APInt &rhs) { return llvm::APIntOps::umax(lhs, rhs); @@ -1274,16 +1305,18 @@ OpFoldResult MinF64Op::fold(FoldAdaptor operands) { } OpFoldResult MaxF32Op::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp( operands.getLhs(), operands.getRhs(), [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); }); } OpFoldResult MaxF64Op::fold(FoldAdaptor operands) { - if (getLhs() == getRhs()) + if (getLhs() == getRhs()) { return getLhs(); + } return constFoldBinaryOp( operands.getLhs(), operands.getRhs(), [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); }); @@ -1810,8 +1843,9 @@ struct FoldCastRefIntoOpResult : public OpRewritePattern { PatternRewriter &rewriter) const override { auto zeroOp = dyn_cast_if_present( castOp.getOperand().getDefiningOp()); - if (!zeroOp) + if (!zeroOp) { return failure(); + } rewriter.replaceOpWithNewOp(castOp, castOp.getResult().getType()); return success(); @@ -1821,8 +1855,9 @@ struct FoldCastRefIntoOpResult : public OpRewritePattern { } // namespace OpFoldResult CastAnyRefOp::fold(FoldAdaptor operands) { - if (getOperand().getType() == getResult().getType()) + if (getOperand().getType() == getResult().getType()) { return getOperand(); + } if (auto castOp = dyn_cast_if_present(getOperand().getDefiningOp())) { if (castOp.getOperand().getType() == getResult().getType()) { @@ -1838,8 +1873,9 @@ void CastAnyRefOp::getCanonicalizationPatterns(RewritePatternSet &results, } OpFoldResult CastRefAnyOp::fold(FoldAdaptor operands) { - if (getOperand().getType() == getResult().getType()) + if 
(getOperand().getType() == getResult().getType()) { return getOperand(); + } if (auto castOp = dyn_cast_if_present(getOperand().getDefiningOp())) { if (castOp.getOperand().getType() == getResult().getType()) { @@ -1894,8 +1930,9 @@ static Attribute constFoldBinaryCmpOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { return {}; + } auto boolType = IntegerType::get(lhs.getContext(), 32); return AttrElementT::get(boolType, calculate(lhs.getValue(), rhs.getValue())); @@ -2321,35 +2358,40 @@ static TypedAttr constFoldBinaryCmpFOp(Attribute rawLhs, Attribute rawRhs, const CalculationT &calculate) { if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs) + if (!rhs) { return {}; + } return IntegerAttr::get(IntegerType::get(lhs.getContext(), 32), calculate(lhs.getValue(), rhs.getValue())); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { // TODO(benvanik): handle splat/otherwise. 
auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto elementResult = constFoldBinaryCmpFOp( lhs.getSplatValue(), rhs.getSplatValue(), calculate); - if (!elementResult) + if (!elementResult) { return {}; + } auto resultType = lhs.getType().clone({}, IntegerType::get(lhs.getContext(), 32)); return DenseElementsAttr::get(resultType, elementResult); } else if (auto lhs = dyn_cast_if_present(rawLhs)) { auto rhs = dyn_cast_if_present(rawRhs); - if (!rhs || lhs.getType() != rhs.getType()) + if (!rhs || lhs.getType() != rhs.getType()) { return {}; + } auto lhsIt = lhs.getValues().begin(); auto rhsIt = rhs.getValues().begin(); SmallVector resultAttrs(lhs.getNumElements()); for (int64_t i = 0; i < lhs.getNumElements(); ++i) { resultAttrs[i] = constFoldBinaryCmpFOp(*lhsIt, *rhsIt, calculate); - if (!resultAttrs[i]) + if (!resultAttrs[i]) { return {}; + } ++lhsIt; ++rhsIt; } @@ -2979,22 +3021,27 @@ static LogicalResult collapseBranch(Block *&successor, return failure(); } // Check that the successor only contains a unconditional branch. - if (std::next(successor->begin()) != successor->end()) + if (std::next(successor->begin()) != successor->end()) { return failure(); + } // Check that the terminator is an unconditional branch. BranchOp successorBranch = dyn_cast(successor->getTerminator()); - if (!successorBranch) + if (!successorBranch) { return failure(); + } // Check that the arguments are only used within the terminator. for (BlockArgument arg : successor->getArguments()) { - for (Operation *user : arg.getUsers()) - if (user != successorBranch) + for (Operation *user : arg.getUsers()) { + if (user != successorBranch) { return failure(); + } + } } // Don't try to collapse branches to infinite loops. 
Block *successorDest = successorBranch.getDest(); - if (successorDest == successor) + if (successorDest == successor) { return failure(); + } // Update the operands to the successor. If the branch parent has no // arguments, we can use the branch operands directly. @@ -3008,10 +3055,11 @@ static LogicalResult collapseBranch(Block *&successor, // Otherwise, we need to remap any argument operands. for (Value operand : operands) { BlockArgument argOperand = dyn_cast(operand); - if (argOperand && argOperand.getOwner() == successor) + if (argOperand && argOperand.getOwner() == successor) { argStorage.push_back(successorOperands[argOperand.getArgNumber()]); - else + } else { argStorage.push_back(operand); + } } successor = successorDest; successorOperands = argStorage; @@ -3312,8 +3360,9 @@ struct RequiredImportResolver : public OpRewritePattern { PatternRewriter &rewriter) const override { auto importOp = SymbolTable::lookupNearestSymbolFrom( op, op.getImportAttr()); - if (!importOp || importOp.getIsOptional()) + if (!importOp || importOp.getIsOptional()) { return failure(); + } rewriter.replaceOpWithNewOp(op, 1); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index c6798a85fdc5..199d9bc99786 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -65,19 +65,23 @@ void setResultIntegerName(OpAsmSetValueNameFn &setNameFn, Value result, // (type, type, ...) 
ParseResult parseResultTypeList(OpAsmParser &parser, ArrayAttr &resultTypes) { - if (failed(parser.parseLParen())) + if (failed(parser.parseLParen())) { return failure(); + } SmallVector typeAttrs; - if (succeeded(parser.parseOptionalRParen())) + if (succeeded(parser.parseOptionalRParen())) { goto done; // empty list + } do { Type type; - if (failed(parser.parseType(type))) + if (failed(parser.parseType(type))) { return failure(); + } typeAttrs.push_back(TypeAttr::get(type)); } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) + if (failed(parser.parseRParen())) { return failure(); + } done: resultTypes = parser.getBuilder().getArrayAttr(typeAttrs); return success(); @@ -172,9 +176,10 @@ Block *FuncOp::addEntryBlock() { LogicalResult FuncOp::verifyType() { auto type = getFunctionTypeAttr().getValue(); - if (!isa(type)) + if (!isa(type)) { return emitOpError("requires '" + getFunctionTypeAttrName().getValue() + "' attribute of function type"); + } return success(); } @@ -404,9 +409,10 @@ void ImportOp::build(OpBuilder &builder, OperationState &result, StringRef name, LogicalResult ImportOp::verifyType() { auto type = getFunctionTypeAttr().getValue(); - if (!isa(type)) + if (!isa(type)) { return emitOpError("requires '" + getFunctionTypeAttrName().getValue() + "' attribute of function type"); + } return success(); } @@ -609,8 +615,9 @@ static bool isConstFloatBuildableWith(TypedAttr value, Type type) { } else if (auto elementsAttr = dyn_cast(value)) { elementType = elementsAttr.getShapedType().getElementType(); } - if (!elementType) + if (!elementType) { return false; + } return elementType.getIntOrFloatBitWidth() == SZ; } @@ -920,8 +927,9 @@ static std::string makeSafeIdentifier(StringRef unsafeIdentifier) { llvm::raw_string_ostream os(result); bool lastUnderscore = true; for (char c : unsafeIdentifier) { - if (!llvm::isPrint(c)) + if (!llvm::isPrint(c)) { continue; + } if (llvm::isAlnum(c)) { os << llvm::toLower(c); lastUnderscore 
= false; @@ -1410,8 +1418,9 @@ void CallVariadicOp::print(OpAsmPrinter &p) { } p << tupleOperands; p << ')'; - if (i < segmentSize - 1) + if (i < segmentSize - 1) { p << ", "; + } } } else { SmallVector segmentOperands; @@ -1562,32 +1571,39 @@ static ParseResult parseBranchTableCases( SmallVectorImpl> &caseOperands, SmallVectorImpl> &caseOperandTypes) { if (parser.parseKeyword("default") || parser.parseColon() || - parser.parseSuccessor(defaultDestination)) + parser.parseSuccessor(defaultDestination)) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, /*allowResultNumber=*/false) || - parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + parser.parseColonTypeList(defaultOperandTypes) || + parser.parseRParen()) { return failure(); + } } while (succeeded(parser.parseOptionalComma())) { int64_t index = 0; - if (failed(parser.parseInteger(index))) + if (failed(parser.parseInteger(index))) { return failure(); - if (index != caseDestinations.size()) + } + if (index != caseDestinations.size()) { return failure(); + } Block *destination; SmallVector operands; SmallVector operandTypes; if (failed(parser.parseColon()) || - failed(parser.parseSuccessor(destination))) + failed(parser.parseSuccessor(destination))) { return failure(); + } if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None, /*allowResultNumber=*/false)) || failed(parser.parseColonTypeList(operandTypes)) || - failed(parser.parseRParen())) + failed(parser.parseRParen())) { return failure(); + } } caseDestinations.push_back(destination); caseOperands.emplace_back(operands); @@ -1628,8 +1644,9 @@ Block *BranchTableOp::getSuccessorForOperands(ArrayRef operands) { SuccessorRange caseDestinations = getCaseDestinations(); if (auto valueAttr = dyn_cast_if_present(operands.front())) { int64_t value = valueAttr.getValue().getSExtValue(); - if 
(value < 0 || value >= caseDestinations.size()) + if (value < 0 || value >= caseDestinations.size()) { return getDefaultDestination(); + } return caseDestinations[value]; } return nullptr; diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp index ef0f25011e52..02d2b7df847f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMTypes.cpp @@ -152,8 +152,9 @@ Attribute VMDialect::parseAttribute(DialectAsmParser &parser, Type type) const { Attribute genAttr; OptionalParseResult parseResult = generatedAttributeParser(parser, &mnemonic, type, genAttr); - if (parseResult.has_value()) + if (parseResult.has_value()) { return genAttr; + } parser.emitError(parser.getNameLoc()) << "unknown HAL attribute: " << mnemonic; return {}; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp index e8f0bb528770..8784f5296d17 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/ArchiveWriter.cpp @@ -670,8 +670,9 @@ LogicalResult ZIPArchiveWriter::flush(FlatbufferBuilder &fbb) { return success(); }, os); - if (!zipFile.has_value()) + if (!zipFile.has_value()) { return failure(); + } fileRefs.push_back(*zipFile); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index 070329e4c377..882c5fb377b9 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp @@ -222,12 +222,14 @@ class V0BytecodeEncoder : public BytecodeEncoder { LogicalResult encodeBranchTable(SuccessorRange caseSuccessors, OperandRangeRange caseOperands, int baseSuccessorIndex) override { 
- if (failed(writeUint16(caseSuccessors.size()))) + if (failed(writeUint16(caseSuccessors.size()))) { return failure(); + } for (auto [successor, operands] : llvm::zip_equal(caseSuccessors, caseOperands)) { - if (failed(encodeBranch(successor, operands, ++baseSuccessorIndex))) + if (failed(encodeBranch(successor, operands, ++baseSuccessorIndex))) { return failure(); + } } return success(); } @@ -321,11 +323,13 @@ class V0BytecodeEncoder : public BytecodeEncoder { LogicalResult ensureAlignment(size_t alignment) { size_t paddedSize = (bytecode_.size() + (alignment - 1)) & ~(alignment - 1); size_t padding = paddedSize - bytecode_.size(); - if (padding == 0) + if (padding == 0) { return success(); + } static const uint8_t kZeros[32] = {0}; - if (padding > sizeof(kZeros)) + if (padding > sizeof(kZeros)) { return failure(); + } return writeBytes(kZeros, padding); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 8bba90c5146d..0fa7bbfe6eea 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -159,8 +159,9 @@ canonicalizeModule(IREE::VM::BytecodeTargetOptions bytecodeOptions, // empty/null list). static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs, FlatbufferBuilder &fbb) { - if (!attrs || attrs.empty()) + if (!attrs || attrs.empty()) { return 0; + } SmallVector attrRefs; for (auto attr : attrs) { auto key = attr.getName().strref(); @@ -216,8 +217,9 @@ makeImportFunctionSignatureDef(IREE::VM::ImportOp importOp, FlatbufferBuilder &fbb) { // Generate the signature calling convention string based on types. 
auto cconv = makeImportCallingConventionString(importOp); - if (!cconv.has_value()) + if (!cconv.has_value()) { return {}; + } return createFunctionSignatureDef(importOp.getFunctionType(), typeTable, cconv.value(), /*attrsRef=*/0, fbb); } @@ -229,8 +231,9 @@ makeFunctionSignatureDef(IREE::VM::FuncOp funcOp, FlatbufferBuilder &fbb) { // Generate the signature calling convention string based on types. auto cconv = makeCallingConventionString(funcOp); - if (!cconv.has_value()) + if (!cconv.has_value()) { return {}; + } // Encode reflection attributes. iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs( @@ -390,8 +393,9 @@ static LogicalResult buildFlatBufferModule( flatbuffers_uint8_vec_ref_t embeddedRef = serializeEmbeddedData( rodataRef.rodataOp.getLoc(), rodataRef.rodataOp.getValue(), rodataRef.alignment, rodataRef.totalSize, fbb); - if (!embeddedRef) + if (!embeddedRef) { return failure(); + } iree_vm_RodataSegmentDef_start(fbb); iree_vm_RodataSegmentDef_embedded_data_add(fbb, embeddedRef); rodataSegmentRefs.push_back(iree_vm_RodataSegmentDef_end(fbb)); @@ -502,10 +506,12 @@ static LogicalResult buildFlatBufferModule( // so that we can multi-version. For now the moduleRequirements will be the OR // of all functions. iree_vm_FeatureBits_enum_t allowedFeatures = 0; - if (vmOptions.f32Extension) + if (vmOptions.f32Extension) { allowedFeatures |= iree_vm_FeatureBits_EXT_F32; - if (vmOptions.f64Extension) + } + if (vmOptions.f64Extension) { allowedFeatures |= iree_vm_FeatureBits_EXT_F64; + } // Yield/unwind are core VM semantics once supported by the runtime. 
allowedFeatures |= iree_vm_FeatureBits_YIELD; allowedFeatures |= iree_vm_FeatureBits_UNWIND; diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp index 6e1edb581c00..172777efdde6 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/DebugDatabaseBuilder.cpp @@ -34,8 +34,9 @@ struct LocationTable { // Inserts a string into the location table string subtable if needed. flatbuffers_string_ref_t insert(StringRef value) { auto it = strings.find(value); - if (it != strings.end()) + if (it != strings.end()) { return it->second; + } auto stringRef = fbb.createString(value); strings[value] = stringRef; return stringRef; @@ -45,8 +46,9 @@ struct LocationTable { // Returns the ordinal of the location in the table. int32_t insert(Location baseLoc) { auto it = map.find(baseLoc); - if (it != map.end()) + if (it != map.end()) { return it->second; + } auto locationRef = llvm::TypeSwitch(baseLoc) .Case([&](CallSiteLoc loc) { @@ -103,8 +105,9 @@ struct LocationTable { iree_vm_DebugDatabaseDef_ref_t DebugDatabaseBuilder::build(FlatbufferBuilder &fbb) { - if (functionSourceMaps.empty()) + if (functionSourceMaps.empty()) { return 0; + } LocationTable locationTable(fbb); diff --git a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp index 821a8701f894..b89d33c5cdcd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp @@ -36,11 +36,13 @@ bool emitEncodeFnDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { auto defs = recordKeeper.getAllDerivedDefinitions("VM_Op"); for (const auto *def : defs) { - if (def->isValueUnset("encoding")) + if (def->isValueUnset("encoding")) { continue; + } auto 
encodingExprs = def->getValueAsListOfDefs("encoding"); - if (encodingExprs.empty()) + if (encodingExprs.empty()) { continue; + } Operator op(def); tblgen::DialectNamespaceEmitter emitter(os, op.getDialect()); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp index 4b0be6ab54e8..10a5d9a3f223 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/AnnotateFunctions.cpp @@ -50,11 +50,13 @@ static FuncInfo analyzeFunction(IREE::VM::FuncOp funcOp, funcOp.walk([&](Operation *op) { // Collect callees. if (auto callOp = dyn_cast(op)) { - if (auto callee = symbolTable.lookup(callOp.getCallee())) + if (auto callee = symbolTable.lookup(callOp.getCallee())) { info.callees.push_back(callee); + } } else if (auto callOp = dyn_cast(op)) { - if (auto callee = symbolTable.lookup(callOp.getCallee())) + if (auto callee = symbolTable.lookup(callOp.getCallee())) { info.callees.push_back(callee); + } } // Check for yield ops. if (isa(op)) { @@ -127,8 +129,9 @@ class AnnotateFunctionsPass bool sccUnwind = false; for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); if (it != funcInfos.end()) { @@ -139,17 +142,20 @@ class AnnotateFunctionsPass // Propagate from callees (already processed, outside this SCC). 
for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); - if (it == funcInfos.end()) + if (it == funcInfos.end()) { continue; + } for (Operation *calleeOp : it->second.callees) { auto calleeIt = funcInfos.find(calleeOp); - if (calleeIt == funcInfos.end()) + if (calleeIt == funcInfos.end()) { continue; + } // Only propagate from callees outside this SCC (they have final // bits). @@ -170,8 +176,9 @@ class AnnotateFunctionsPass // Apply to all nodes in this SCC. for (CallGraphNode *node : scc) { - if (node->isExternal()) + if (node->isExternal()) { continue; + } Operation *op = node->getCallableRegion()->getParentOp(); auto it = funcInfos.find(op); if (it != funcInfos.end()) { @@ -184,8 +191,9 @@ class AnnotateFunctionsPass // Phase 4: Apply attributes to functions. for (auto funcOp : moduleOp.getOps()) { auto it = funcInfos.find(funcOp); - if (it == funcInfos.end()) + if (it == funcInfos.end()) { continue; + } if (it->second.needsYield && !funcOp->hasAttr("vm.yield")) { funcOp->setAttr("vm.yield", UnitAttr::get(funcOp.getContext())); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 6f08662d394f..50243cf870fd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -73,11 +73,13 @@ gatherUsedDialectInterfaces(mlir::ModuleOp moduleOp) { // Generic dialect lookup. 
dialect = op->getDialect(); } - if (!dialect) + if (!dialect) { return; + } auto *dialectInterface = dialect->getRegisteredInterface(); - if (!dialectInterface) + if (!dialectInterface) { return; + } resultSet.insert(dialectInterface); }); @@ -97,8 +99,9 @@ class ConversionPass : public IREE::VM::impl::ConversionPassBase { using Base::Base; void runOnOperation() override { - if (getOperation().getBody()->empty()) + if (getOperation().getBody()->empty()) { return; + } auto targetOptions = targetOptionsFromConversionPass(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp index fc53769abb33..bebed5bd905a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/ConvertToYieldableCalls.cpp @@ -120,11 +120,13 @@ class ConvertToYieldableCallsPass // Extract segment info. SmallVector segmentSizes; - for (auto val : callOp.getSegmentSizes()) + for (auto val : callOp.getSegmentSizes()) { segmentSizes.push_back(val.getSExtValue()); + } SmallVector segmentTypes; - for (auto typeAttr : callOp.getSegmentTypes()) + for (auto typeAttr : callOp.getSegmentTypes()) { segmentTypes.push_back(cast(typeAttr).getValue()); + } // Create the vm.call.variadic.yieldable op and erase the original call. 
builder.setInsertionPoint(callOp); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp index 451288028435..e06138a84d63 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp @@ -77,8 +77,9 @@ class DeduplicateRodataPass replacer.addReplacement( [&](SymbolRefAttr attr) -> std::pair { auto replacement = replacements.find(attr); - if (replacement != replacements.end()) + if (replacement != replacements.end()) { return {replacement->getSecond(), WalkResult::skip()}; + } return {attr, WalkResult::skip()}; }); moduleOp.walk([&](Operation *op) { replacer.replaceElementsIn(op); }); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp index fd0b149f8c82..0ed40e5dcd39 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropEmptyModuleInitializers.cpp @@ -42,8 +42,9 @@ class DropEmptyModuleInitializersPass auto initFuncOp = symbolTable.lookup("__init"); if (initFuncOp && isFuncEmpty(initFuncOp)) { auto exportOp = exportOps[initFuncOp.getName()]; - if (exportOp) + if (exportOp) { exportOp.erase(); + } initFuncOp.erase(); } @@ -51,8 +52,9 @@ class DropEmptyModuleInitializersPass auto deinitFuncOp = symbolTable.lookup("__deinit"); if (deinitFuncOp && isFuncEmpty(deinitFuncOp)) { auto exportOp = exportOps[deinitFuncOp.getName()]; - if (exportOp) + if (exportOp) { exportOp.erase(); + } deinitFuncOp.erase(); } } diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 98b47b80ad9c..eb944a133605 100644 --- 
a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -96,8 +96,9 @@ static void fixupGlobalMutability(Operation *moduleOp, explorer.initialize(); SmallVector deadOps; explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { - if (globalInfo->uses.empty()) + if (globalInfo->uses.empty()) { return; + } // TODO(benvanik): verify we want this behavior - we likely want to change // this to be mutable only if stores exist outside of initializers. // diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp index 3fdc6f1f1ebd..03f26cb7a95c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/MaterializeRefDiscards.cpp @@ -76,8 +76,9 @@ class MaterializeRefDiscardsPass // to callee). bool isTerminatorMoveOperand(Value value, Operation *terminator) { auto refMoveOp = dyn_cast(terminator); - if (!refMoveOp) + if (!refMoveOp) { return false; + } // Check if value is forwarded to any successor - if so, it's not a "pure" // MOVE to callee, it's a forward to successor block. 
@@ -119,14 +120,16 @@ class MaterializeRefDiscardsPass bool isForwardedOnEdge(Value value, Block *pred, Block *succ) { Operation *terminator = pred->getTerminator(); auto branchOp = dyn_cast(terminator); - if (!branchOp) + if (!branchOp) { return false; + } for (unsigned i = 0; i < terminator->getNumSuccessors(); ++i) { if (terminator->getSuccessor(i) == succ) { auto operands = branchOp.getSuccessorOperands(i); - if (llvm::is_contained(operands.getForwardedOperands(), value)) + if (llvm::is_contained(operands.getForwardedOperands(), value)) { return true; + } } } return false; @@ -223,8 +226,9 @@ class MaterializeRefDiscardsPass LogicalResult processFunction(FuncOp funcOp) { // Skip empty functions. - if (funcOp.getBlocks().empty()) + if (funcOp.getBlocks().empty()) { return success(); + } // Compute liveness information. ValueLiveness liveness; @@ -278,8 +282,9 @@ class MaterializeRefDiscardsPass SmallVector dyingRefs; for (Value ref : allRefs) { - if (escapingRefs.count(ref)) + if (escapingRefs.count(ref)) { continue; + } // Check if ref should be discarded on this edge. bool isInLiveOuts = llvm::is_contained(liveOuts, ref); @@ -294,22 +299,26 @@ class MaterializeRefDiscardsPass } // Skip if ref is neither in liveOuts nor forwarded on any edge. - if (!isInLiveOuts && !isForwardedOnAny) + if (!isInLiveOuts && !isForwardedOnAny) { continue; + } // Skip if ref is live-in to successor. - if (llvm::is_contained(succLiveIns, ref)) + if (llvm::is_contained(succLiveIns, ref)) { continue; + } // Skip if ref is forwarded on this specific edge. - if (isForwardedOnEdge(ref, &block, succ)) + if (isForwardedOnEdge(ref, &block, succ)) { continue; + } // Skip if ref is a MOVE operand of the terminator. // MOVE operands transfer ownership to the callee, so we must NOT // discard them - the callee takes responsibility for the ref. - if (isTerminatorMoveOperand(ref, terminator)) + if (isTerminatorMoveOperand(ref, terminator)) { continue; + } // Ref dies on this edge. 
dyingRefs.push_back(ref); @@ -337,17 +346,20 @@ class MaterializeRefDiscardsPass llvm::DenseMap opToIndex; for (Operation &op : block) { - if (isa(&op)) + if (isa(&op)) { continue; + } for (OpOperand &operand : op.getOpOperands()) { Value value = operand.get(); - if (!isa(value.getType())) + if (!isa(value.getType())) { continue; + } // Skip escaping refs. - if (escapingRefs.count(value)) + if (escapingRefs.count(value)) { continue; + } // Check if this is the last use and value doesn't escape via // live-outs. @@ -408,8 +420,9 @@ class MaterializeRefDiscardsPass // Unused block arguments. SmallVector unusedBlockArgs; for (BlockArgument arg : block.getArguments()) { - if (!isa(arg.getType())) + if (!isa(arg.getType())) { continue; + } if (arg.use_empty() && !escapingRefs.count(arg)) { unusedBlockArgs.push_back(arg); } @@ -425,8 +438,9 @@ class MaterializeRefDiscardsPass llvm::DenseMap opToResultIndex; for (Operation &op : block) { for (Value result : op.getResults()) { - if (!isa(result.getType())) + if (!isa(result.getType())) { continue; + } if (result.use_empty() && !escapingRefs.count(result)) { auto it = opToResultIndex.find(&op); if (it == opToResultIndex.end()) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp index ea072fe8eb85..d539c8209f9b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp @@ -85,8 +85,9 @@ class OrdinalAllocationPass int globalBytes = 0; for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) { size_t storageSize = sizeGlobalOps.index(); - if (sizeGlobalOps.value().empty()) + if (sizeGlobalOps.value().empty()) { continue; + } nextGlobalBytesOrdinal = llvm::alignTo(nextGlobalBytesOrdinal, storageSize); for (auto &globalOp : sizeGlobalOps.value()) { diff --git 
a/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp index 8869fd662c82..3ede87b5b528 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/ResolveRodataLoads.cpp @@ -67,11 +67,13 @@ static void processBufferGlobal(Explorer &explorer, const Explorer::GlobalInfo *globalInfo, DenseSet &deadOps) { // Ignore indirect/unanalyzable globals. - if (globalInfo->isIndirect) + if (globalInfo->isIndirect) { return; + } // Ignore mutable globals, as they could be changed to various values. - if (globalInfo->op.isGlobalMutable()) + if (globalInfo->op.isGlobalMutable()) { return; + } // If there are no stores to the global then it's always null. if (globalInfo->getStores().empty()) { @@ -90,8 +92,9 @@ static void processBufferGlobal(Explorer &explorer, // the program (there may be multiple initializers or control flow that // determines the stored value). auto rodataOp = findUniformlyStoredRodata(explorer, globalInfo); - if (!rodataOp) + if (!rodataOp) { return; + } // All stores to the global are of the same rodata. // Replace all of the loads with direct references to the rodata and then @@ -136,8 +139,9 @@ class ResolveRodataLoadsPass }); // Erase all ops after we're done iterating them. 
- for (auto *deadOp : deadOps) + for (auto *deadOp : deadOps) { deadOp->erase(); + } } }; diff --git a/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp b/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp index a1c56544d366..b92d4f6cbf23 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Utils/TypeTable.cpp @@ -17,8 +17,9 @@ std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { if (auto refPtrType = dyn_cast(type)) { type = refPtrType.getObjectType(); } - if (typeMap.count(type)) + if (typeMap.count(type)) { return; + } std::string str; llvm::raw_string_ostream sstream(str); type.print(sstream); @@ -31,10 +32,12 @@ std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { }; for (auto funcOp : moduleOp.getBlock().getOps()) { funcOp.walk([&](Operation *op) { - for (auto type : op->getOperandTypes()) + for (auto type : op->getOperandTypes()) { tryInsertType(type); - for (auto type : op->getResultTypes()) + } + for (auto type : op->getResultTypes()) { tryInsertType(type); + } }); } diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp index 451168805145..8953a4760746 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.cpp @@ -66,8 +66,9 @@ class VMVXImportOpConversion : public OpConversionPattern { return failure(); } auto results = emitCall(op, adaptor, importOp, rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp index 73e7a69a455c..a51de1e55cc7 100644 --- 
a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp @@ -55,8 +55,9 @@ class MaterializeConstantsPass final } // No constants found; omit the constant block entirely. - if (allLoadOps.empty()) + if (allLoadOps.empty()) { return; + } // Create global ops for each constant and replace the HAL ops so they load // from them. Each global will track what constant key it represents for diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp index 0be2c806c15f..4284b1b5b298 100644 --- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp +++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp @@ -230,8 +230,9 @@ struct FromMemRefSubView : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto subview = op.getSource().template getDefiningOp(); - if (!subview) + if (!subview) { return failure(); + } auto loc = op.getLoc(); IndexSet indexSet(loc, rewriter); @@ -266,8 +267,9 @@ struct FromMemRefSubView : public OpRewritePattern { llvm::SmallBitVector droppedDims = subview.getDroppedDims(); int targetIndex = 0; for (int i = 0; i < sourceRank; ++i) { - if (droppedDims.test(i)) + if (droppedDims.test(i)) { continue; + } rewriter.replaceAllUsesWith( op.getSizes()[targetIndex], getValueOrCreateConstantIndexOp(rewriter, loc, @@ -297,8 +299,9 @@ struct FromHalInterfaceBindingSubspan auto binding = op.getSource() .template getDefiningOp(); - if (!binding) + if (!binding) { return failure(); + } auto loc = op.getLoc(); FailureOr resultDescriptor = @@ -379,8 +382,9 @@ struct FromAllocation : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto alloca = 
op.getSource().template getDefiningOp(); - if (!alloca) + if (!alloca) { return failure(); + } auto memRefType = cast(alloca.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure(op, "not identity allocation"); @@ -413,8 +417,9 @@ struct FromGlobal : public OpRewritePattern { LogicalResult matchAndRewrite(GetBufferDescriptorOp op, PatternRewriter &rewriter) const override { auto global = op.getSource().template getDefiningOp(); - if (!global) + if (!global) { return failure(); + } auto memRefType = cast(global.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure(op, "not identity allocation"); From 4da52876b277d8a0177d99c408c4fc85d742dcee Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:13:48 -0500 Subject: [PATCH 50/71] Add braces in API, Bindings, and infrastructure code. NFC. 6/n (#23148) --- .../compiler/API/Internal/CompilerDriver.cpp | 24 +++-- .../compiler/API/Internal/Diagnostics.cpp | 23 +++-- .../API/Internal/IREEOptToolEntryPoint.cpp | 9 +- .../API/Internal/IREEReduceToolEntryPoint.cpp | 3 +- .../API/Internal/LLDToolEntryPoint.cpp | 6 +- .../Transforms/ConvertStreamableOps.cpp | 6 +- .../Native/Transforms/WrapEntryPoints.cpp | 6 +- .../TFLite/Transforms/WrapEntryPoints.cpp | 6 +- .../iree/compiler/ConstEval/JitGlobals.cpp | 59 +++++++++---- .../src/iree/compiler/ConstEval/Runtime.cpp | 26 ++++-- .../DispatchCreation/BubbleUpExpandShapes.cpp | 3 +- .../CloneProducersIntoDispatchRegions.cpp | 6 +- .../DispatchCreation/CollapseDimensions.cpp | 3 +- .../DispatchCreation/ConvertTensorToFlow.cpp | 28 ++++-- .../DispatchCreation/ElementwiseOpFusion.cpp | 3 +- .../DispatchCreation/FoldUnitExtentDims.cpp | 3 +- .../DispatchCreation/FormDispatchRegions.cpp | 33 ++++--- .../FuseMultiUseElementwiseProducer.cpp | 3 +- .../DispatchCreation/FusionPreprocessing.cpp | 3 +- .../compiler/DispatchCreation/FusionUtils.cpp | 12 ++- 
...MaterializeDefaultWorkgroupCountRegion.cpp | 6 +- .../compiler/DispatchCreation/SetEncoding.cpp | 3 +- .../TensorPadToTensorInsertSlice.cpp | 3 +- .../DispatchCreation/TransposeGenericOps.cpp | 18 ++-- .../ExternalInterfaces/UtilExternalModels.cpp | 6 +- .../Convert1X1FilterConv2DToMatmul.cpp | 6 +- ...ConvertStridedContractionToContraction.cpp | 24 +++-- .../GlobalOptimization/DecomposeConcat.cpp | 3 +- .../DetachElementwiseFromNamedOps.cpp | 18 ++-- .../GlobalOptimization/ExpandTensorShapes.cpp | 37 +++++--- .../GlobalLoopInvariantCodeMotion.cpp | 18 ++-- .../MaterializeHomogeneousEncodings.cpp | 3 +- .../GlobalOptimization/OptimizeNumerics.cpp | 9 +- .../PropagateLinalgTranspose.cpp | 3 +- .../QuantizedConvToConv.cpp | 6 +- .../GlobalOptimization/RaiseSpecialOps.cpp | 9 +- .../compiler/GlobalOptimization/Utils.cpp | 3 +- .../Common/AutoInputConversionPipeline.cpp | 6 +- .../Common/ConvertPrimitiveType.cpp | 21 +++-- .../Common/ImportMLProgram.cpp | 31 ++++--- .../Common/SanitizeModuleNames.cpp | 3 +- .../Check/Conversion/ConversionPatterns.cpp | 3 +- .../Conversion/HALToHALInline/Patterns.cpp | 6 +- .../Conversion/StreamToHALInline/Patterns.cpp | 6 +- .../Modules/HAL/Inline/IR/HALInlineOps.cpp | 6 +- .../Conversion/HALLoaderToVM/Patterns.cpp | 3 +- .../Modules/HAL/Loader/IR/HALLoaderOps.cpp | 3 +- .../IO/Parameters/Transforms/ArchiveUtils.cpp | 6 +- .../IO/Parameters/Transforms/ArchiveUtils.h | 5 +- .../Transforms/ExportParameters.cpp | 21 +++-- .../GenerateSplatParameterArchive.cpp | 12 ++- .../Transforms/ImportParameters.cpp | 39 ++++++--- .../src/iree/compiler/Pipelines/Pipelines.cpp | 87 ++++++++++++------- .../iree/compiler/PluginAPI/PluginManager.cpp | 9 +- .../Preprocessing/Common/ApplyPDLPatterns.cpp | 9 +- .../Common/ConvertConv2DToImg2Col.cpp | 12 ++- .../ConvertConvFilterToChannelsLast.cpp | 3 +- .../Common/ConvertConvToChannelsLast.cpp | 12 ++- .../Preprocessing/Common/InterpreterPass.cpp | 3 +- .../Preprocessing/Common/PadLinalgOps.cpp | 15 ++-- 
.../Preprocessing/Common/PadToIntrinsics.cpp | 35 +++++--- .../iree/compiler/Reducer/Framework/Delta.cpp | 3 +- .../compiler/Reducer/Framework/WorkItem.h | 3 +- .../Strategies/ReduceLinalgOnTensorsDelta.cpp | 18 ++-- .../iree/compiler/Utils/ConversionUtils.cpp | 9 +- .../iree/compiler/Utils/EquivalenceUtils.cpp | 57 ++++++++---- .../src/iree/compiler/Utils/FlatbufferUtils.h | 12 ++- compiler/src/iree/compiler/Utils/Indexing.cpp | 6 +- .../src/iree/compiler/Utils/ModuleUtils.cpp | 18 ++-- .../src/iree/compiler/Utils/OptionUtils.cpp | 6 +- .../src/iree/compiler/Utils/OptionUtils.h | 12 ++- .../src/iree/compiler/Utils/ToolUtils.cpp | 21 +++-- 72 files changed, 641 insertions(+), 320 deletions(-) diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index 23770bcccdf0..1a22b8a05779 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -533,8 +533,9 @@ Error *Source::split(void (*callback)(iree_compiler_source_t *source, SmallVector rawSubBuffers; // Split dropping the last checkLen chars to enable flagging near misses. 
origMemBuffer->getBuffer().split(rawSubBuffers, splitMarker); - if (rawSubBuffers.empty()) + if (rawSubBuffers.empty()) { return nullptr; + } for (StringRef subBuffer : rawSubBuffers) { auto splitLoc = SMLoc::getFromPointer(subBuffer.data()); @@ -696,8 +697,9 @@ Error *Output::openMembuffer() { } void Output::keep() { - if (outputFile) + if (outputFile) { outputFile->keep(); + } } // Invocation corresponds to iree_compiler_invocation_t @@ -915,8 +917,9 @@ bool Invocation::importModule(Operation *inputModule, bool steal) { } Operation *Invocation::exportModule() { - if (!parsedModuleIsOwned) + if (!parsedModuleIsOwned) { return nullptr; + } parsedModuleIsOwned = false; return parsedModule; } @@ -960,14 +963,16 @@ bool Invocation::getCompilationPhase(IREEVMPipelinePhase &compileFrom, void Invocation::dumpCompilationPhase(IREEVMPipelinePhase phase, OpPassManager &passManager) { - if (!parsedModule || dumpCompilationPhasesTo.empty()) + if (!parsedModule || dumpCompilationPhasesTo.empty()) { return; + } std::string phaseName; enumerateIREEVMPipelinePhases( [&](IREEVMPipelinePhase enumeratedPhase, StringRef name, StringRef desc) { - if (enumeratedPhase == phase) + if (enumeratedPhase == phase) { phaseName = name; + } }); std::string fileName = @@ -1081,8 +1086,9 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { bool Invocation::runTextualPassPipeline(const char *textPassPipeline) { auto passManager = createPassManager(); if (failed(mlir::parsePassPipeline(textPassPipeline, *passManager, - llvm::errs()))) + llvm::errs()))) { return false; + } if (failed(passManager->run(parsedModule))) { return false; } @@ -1096,8 +1102,9 @@ Error *Invocation::outputIR(Output &output) { Error *Invocation::outputIRBytecode(Output &output, int bytecodeVersion) { mlir::BytecodeWriterConfig config; - if (bytecodeVersion >= 0) + if (bytecodeVersion >= 0) { config.setDesiredBytecodeVersion(bytecodeVersion); + } if (failed(mlir::writeBytecodeToFile(parsedModule, 
*output.outputStream, config))) { return new Error("illegal bytecode version requested"); @@ -1202,8 +1209,9 @@ void llvmVersionPrinter(llvm::raw_ostream &os) { #endif #if LLVM_VERSION_PRINTER_SHOW_HOST_TARGET_INFO std::string CPU = std::string(llvm::sys::getHostCPUName()); - if (CPU == "generic") + if (CPU == "generic") { CPU = "(unknown)"; + } os << ".\n" << " Default target: " << llvm::sys::getDefaultTargetTriple() << '\n' << " Host CPU: " << CPU; diff --git a/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp b/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp index cd279d9fe7eb..2a4aa1d7ad35 100644 --- a/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp +++ b/compiler/src/iree/compiler/API/Internal/Diagnostics.cpp @@ -22,10 +22,12 @@ namespace mlir::iree_compiler::embed { namespace { /// Return a processable CallSiteLoc from the given location. std::optional getCallSiteLoc(Location loc) { - if (auto callLoc = dyn_cast(loc)) + if (auto callLoc = dyn_cast(loc)) { return callLoc; - if (auto nameLoc = dyn_cast(loc)) + } + if (auto nameLoc = dyn_cast(loc)) { return getCallSiteLoc(cast(loc).getChildLoc()); + } if (auto fusedLoc = dyn_cast(loc)) { for (auto subLoc : cast(loc).getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { @@ -49,9 +51,11 @@ std::optional findLocToShow(Location loc) { .Case([&](FusedLoc fusedLoc) -> std::optional { // Fused location is unique in that we try to find a sub-location to // show, rather than the top-level location itself. - for (Location childLoc : fusedLoc.getLocations()) - if (std::optional showableLoc = findLocToShow(childLoc)) + for (Location childLoc : fusedLoc.getLocations()) { + if (std::optional showableLoc = findLocToShow(childLoc)) { return showableLoc; + } + } return std::nullopt; }) .Case([&](NameLoc nameLoc) -> std::optional { @@ -105,8 +109,9 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { // Assemble location fragments. 
SmallVector> locationStack; auto addLocToStack = [&](Location loc, StringRef locContext) { - if (std::optional showableLoc = findLocToShow(loc)) + if (std::optional showableLoc = findLocToShow(loc)) { locationStack.emplace_back(*showableLoc, locContext); + } }; // Add locations to display for this diagnostic. @@ -121,10 +126,11 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { const unsigned callStackLimit = 50; for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) { addLocToStack(loc, "called from"); - if ((callLoc = getCallSiteLoc(loc))) + if ((callLoc = getCallSiteLoc(loc))) { loc = callLoc->getCaller(); - else + } else { break; + } } } @@ -134,8 +140,9 @@ LogicalResult FormattingDiagnosticHandler::emit(Diagnostic &diag) { appendDiag(diag.getLocation(), diag.str(), diag.getSeverity()); } else { appendDiag(locationStack.front().first, diag.str(), diag.getSeverity()); - for (auto &it : llvm::drop_begin(locationStack)) + for (auto &it : llvm::drop_begin(locationStack)) { appendDiag(it.first, it.second, DiagnosticSeverity::Note); + } } // Append each of the notes. 
diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp index 8c981a2b0e29..001680d9d52d 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp @@ -93,8 +93,9 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, auto localBinder = mlir::iree_compiler::OptionsBinder::local(); mlir::iree_compiler::PluginManagerSession pluginSession( pluginManager, localBinder, pluginManagerOptions); - if (failed(pluginSession.initializePlugins())) + if (failed(pluginSession.initializePlugins())) { return failure(); + } pluginSession.registerDialects(registry); // In the normal compiler flow, activated plugins maintain a scoped registry @@ -127,9 +128,10 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, // and the process "appears to be stuck". Print a message to let the user know // about it! if (inputFilename == "-" && - sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) + sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) { llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " "interrupt)\n"; + } // Set up the input file. std::string errorMessage; @@ -144,8 +146,9 @@ static LogicalResult ireeOptMainFromCL(int argc, char **argv, llvm::errs() << errorMessage << "\n"; return failure(); } - if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) + if (failed(MlirOptMain(output->os(), std::move(file), registry, config))) { return failure(); + } // Keep the output file if the invocation of MlirOptMain was successful. 
output->keep(); diff --git a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp index 4cb61ec3c421..3b0e4f017f04 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEReduceToolEntryPoint.cpp @@ -84,9 +84,10 @@ static LogicalResult ireeReduceMainFromCL(int argc, char **argv, // and the process "appears to be stuck". Print a message to let the user know // about it! if (inputFilename == "-" && - sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) + sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) { llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " "interrupt)\n"; + } OwningOpRef module = loadModule(registry, inputFilename); diff --git a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp index b61f6f6cba0f..8c492f4d24b6 100644 --- a/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/LLDToolEntryPoint.cpp @@ -71,11 +71,13 @@ static Flavor getFlavor(StringRef s) { static Flavor parseFlavor(std::vector &v) { // Parse -flavor option. 
if (v.size() > 1 && v[1] == StringRef("-flavor")) { - if (v.size() <= 2) + if (v.size() <= 2) { die("missing arg value for '-flavor'"); + } Flavor f = getFlavor(v[2]); - if (f == Invalid) + if (f == Invalid) { die("Unknown flavor: " + StringRef(v[2])); + } v.erase(v.begin() + 1, v.begin() + 3); return f; } diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp index d404dba88dba..0d258f8c204a 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/ConvertStreamableOps.cpp @@ -280,8 +280,9 @@ static LogicalResult convertStreamableCall(StreamableFunc &streamableFunc, for (auto [i, resultType] : llvm::enumerate(callOp.getResultTypes())) { if (auto shapedType = dyn_cast(resultType)) { const auto &resultDimArgs = streamableFunc.resultDimArgs[i]; - if (resultDimArgs.empty()) + if (resultDimArgs.empty()) { continue; + } if (resultDimArgs.front() == kTiedDim) { // Source from a tied operand. Types must match exactly. 
assert(streamableFunc.tiedOperands[i] != @@ -360,8 +361,9 @@ class ConvertStreamableOpsPass for (auto originalFuncOp : originalFuncOps) { auto streamableFuncOr = convertStreamableFunc(moduleOp, originalFuncOp, symbolTable); - if (!streamableFuncOr.has_value()) + if (!streamableFuncOr.has_value()) { return signalPassFailure(); + } auto streamableFunc = std::move(streamableFuncOr).value(); streamableFuncs[streamableFunc.funcOp.getName()] = std::move(streamableFunc); diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index 9a61a5419a90..0fb8ba7e0375 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -400,14 +400,16 @@ static StringAttr inferResultName(MLIRContext *context, int index, } static DictionaryAttr getIOAttr(ArrayAttr allAttrs, unsigned i) { - if (!allAttrs) + if (!allAttrs) { return nullptr; + } return cast_or_null(allAttrs.getValue()[i]); } static void formatIOAttr(DictionaryAttr attrs, llvm::raw_ostream &os) { - if (!attrs || attrs.empty()) + if (!attrs || attrs.empty()) { return; + } auto shouldIncludeAttr = [](const NamedAttribute &attr) { return attr.getName().getValue() != "iree.abi.name"; }; diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index ec0c6bc606c4..1d755871f224 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -335,8 +335,9 @@ class WrapEntryPointsPass auto shapeType = dynamicDims.tensorType; unsigned dynamicDimIdx = 0; for (unsigned i = 0; i < shapeType.getRank(); ++i) { - if (!shapeType.isDynamicDim(i)) + if (!shapeType.isDynamicDim(i)) { continue; + } auto dimValue = 
IREE::Util::ListGetOp::create( builder, loc, builder.getIndexType(), listValue, builder.createOrFold(loc, i)) @@ -492,8 +493,9 @@ class WrapEntryPointsPass wrapperFuncOp.setAllResultAttrs(resultAttrDict); populateReflectionAttrs(entryFuncOp, wrapperFuncOp); - if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity")) + if (auto affinityAttr = entryFuncOp->getAttr("stream.affinity")) { wrapperFuncOp->setAttr("stream.affinity", affinityAttr); + } // Call the entryFuncOp and return the results. // If we wanted to perform additional work here to invalidate cached shapes diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 882a5451f7cf..4c6cbda20509 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -49,10 +49,12 @@ static llvm::cl::opt clEnableDebug( namespace { static bool isDebugEnabled() { - if (clEnableDebug) + if (clEnableDebug) { return true; - if (std::getenv("IREE_COMPILER_DEBUG_CONSTEVAL")) + } + if (std::getenv("IREE_COMPILER_DEBUG_CONSTEVAL")) { return true; + } return false; } @@ -83,8 +85,9 @@ struct CompileOptions { }; static inline bool isAttrParameterized(Attribute attr) { - if (!attr) + if (!attr) { return false; + } return !isa(attr) && !isa(attr) && !isa(attr); } @@ -94,8 +97,9 @@ static inline bool isAccessorParameterized(const SymbolTable &moduleSymbols, AccessorTy op) { auto global = moduleSymbols.lookup(op.getGlobalName()); - if (!global) + if (!global) { return true; + } return isAttrParameterized(global.getGlobalInitialValue()); } @@ -118,8 +122,9 @@ static bool isParameterized(const SymbolTable &moduleSymbols, return isAttrParameterized(accessor.getValueAttr()); }) .Default([=](auto) { return false; }); - if (parameterized) + if (parameterized) { return WalkResult::interrupt(); + } return WalkResult::advance(); }); return res.wasInterrupted(); @@ -157,8 +162,9 @@ class InitializationAnalysis { 
Availability getInitializerAvailability(IREE::Util::InitializerOpInterface initializerOp) { auto it = initializerAvailability.find(initializerOp); - if (it == initializerAvailability.end()) + if (it == initializerAvailability.end()) { return Availability::Unknown; + } return it->second; } @@ -193,11 +199,13 @@ class InitializationAnalysis { Availability queryGlobalInitializationStatus(StringRef globalName, unsigned opOrdinal) { auto &timeline = globalTimelines[globalName]; - if (timeline.empty()) + if (timeline.empty()) { return Availability::Unknown; + } for (auto &timepoint : timeline) { - if (timepoint.first > opOrdinal) + if (timepoint.first > opOrdinal) { return timepoint.second; + } } return timeline.back().second; } @@ -222,10 +230,11 @@ class InitializationAnalysis { availability = static_cast( std::min(static_cast(availability), static_cast(newAvailability))); - if (previousAvailability != availability) + if (previousAvailability != availability) { emitDebugWarning( initializerOp.getLoc(), [&](InFlightDiagnostic &diagnostic) { diagnostic << reason; }); + } }; if (initializerOp->getRegions().size() != 1 || @@ -404,8 +413,9 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp, OpBuilder &moduleBuilder) { // Gather all symbol uses within the function. auto uses = SymbolTable::getSymbolUses(funcOp); - if (!uses.has_value()) + if (!uses.has_value()) { return success(); + } // Verify that all uses are to object-like types we can clone. for (auto use : uses.value()) { @@ -416,14 +426,16 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp, return use.getUser()->emitOpError() << "references undefined symbol " << use.getSymbolRef(); } - if (!objectOp->hasTrait()) + if (!objectOp->hasTrait()) { continue; + } // Check if the object exists in the target yet. Since we create the // target we know there should be no conflicts: the only symbols with the // same name will be already cloned copies of the same source. 
- if (targetSymbolTable.lookup(objectNameAttr)) + if (targetSymbolTable.lookup(objectNameAttr)) { continue; + } // Clone the object. It's isolated and safe to copy wholesale. auto *clonedOp = moduleBuilder.clone(*objectOp); @@ -464,16 +476,18 @@ class ProgramBuilder { // compile dynamic initializers. auto availability = initializationAnalysis.getInitializerAvailability(initializerOp); - if (availability != InitializationAnalysis::Availability::Compiler) + if (availability != InitializationAnalysis::Availability::Compiler) { return failure(); + } OpBuilder moduleBuilder = OpBuilder::atBlockEnd(targetModuleOp.getBody()); // Find any object-like symbol references used by the initializer and // clone them. if (failed(cloneUsedObjects(initializerOp, sourceSymbolTable, - targetSymbolTable, moduleBuilder))) + targetSymbolTable, moduleBuilder))) { return failure(); + } auto funcOp = IREE::Util::FuncOp::create( moduleBuilder, initializerOp.getLoc(), "jit_eval", @@ -536,8 +550,9 @@ class ProgramBuilder { for (auto constantOp : funcOp.getOps()) { auto tensorType = dyn_cast(constantOp.getResult().getType()); auto elementsAttr = dyn_cast(constantOp.getValue()); - if (!tensorType || !elementsAttr) + if (!tensorType || !elementsAttr) { continue; + } if (!supportedTypes.supportsType(tensorType)) { emitDebugWarning(funcOp.getLoc(), [&](InFlightDiagnostic &diagnostic) { diagnostic << "skipping consteval initializer: unsupported type for " @@ -668,15 +683,18 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { FunctionCall call(binary, jitFunction.argumentBindings.size(), jitFunction.resultBindings.size()); - if (failed(call.initialize(jitFunction.loc))) + if (failed(call.initialize(jitFunction.loc))) { return failure(); + } // Convert arguments. 
for (ArgumentBinding &arg : jitFunction.argumentBindings) { switch (arg.getType()) { case ArgumentBinding::Type::ElementsAttr: { - if (failed(call.addArgument(jitFunction.loc, arg.getElementsAttr()))) + if (failed( + call.addArgument(jitFunction.loc, arg.getElementsAttr()))) { return failure(); + } break; } case ArgumentBinding::Type::GlobalOp: { @@ -687,8 +705,10 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { "invalid: global " << arg.getGlobalOp().getGlobalName() << " has no value"; } - if (failed(call.addArgument(arg.getGlobalOp().getLoc(), globalValue))) + if (failed( + call.addArgument(arg.getGlobalOp().getLoc(), globalValue))) { return failure(); + } break; } } @@ -706,8 +726,9 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { TypedAttr attr; if (failed(call.getResultAsAttr( resultBinding.getGlobalOp().getLoc(), it.index(), - resultBinding.getGlobalOp().getGlobalType(), attr))) + resultBinding.getGlobalOp().getGlobalType(), attr))) { return failure(); + } resultBinding.getGlobalOp().setGlobalInitialValue(attr); break; } diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp index a7c48f011580..db59da95d0cd 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp +++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp @@ -22,8 +22,9 @@ namespace { LogicalResult handleRuntimeError(Location loc, iree_status_t status, bool freeStatus = true) { - if (iree_status_is_ok(status)) + if (iree_status_is_ok(status)) { return success(); + } std::string statusString = iree::Status::ToString(status); if (freeStatus) { iree_status_ignore(status); @@ -213,8 +214,9 @@ FunctionCall::importSerializableAttr( LogicalResult FunctionCall::addBufferArgumentAttr( Location loc, IREE::Util::SerializableAttrInterface serializableAttr) { auto buffer = importSerializableAttr(loc, serializableAttr); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } return handleRuntimeError( 
loc, iree_vm_list_push_ref_move(inputs.get(), std::move(*buffer))); } @@ -230,14 +232,16 @@ LogicalResult FunctionCall::addBufferViewArgumentAttr( shape[i] = shapedType.getDimSize(i); } iree_hal_element_type_t elementType = IREE_HAL_ELEMENT_TYPE_NONE; - if (failed( - convertToElementType(loc, shapedType.getElementType(), &elementType))) + if (failed(convertToElementType(loc, shapedType.getElementType(), + &elementType))) { return failure(); + } // Import buffer contents. auto buffer = importSerializableAttr(loc, serializableAttr); - if (failed(buffer)) + if (failed(buffer)) { return failure(); + } // Construct buffer view. iree::vm::ref bufferView; @@ -245,8 +249,9 @@ LogicalResult FunctionCall::addBufferViewArgumentAttr( loc, iree_hal_buffer_view_create(buffer->get(), rank, shape, elementType, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - iree_allocator_system(), &bufferView)))) + iree_allocator_system(), &bufferView)))) { return failure(); + } return handleRuntimeError( loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bufferView))); @@ -351,12 +356,14 @@ LogicalResult FunctionCall::getResultAsAttr(Location loc, size_t index, Type mlirType, TypedAttr &outAttr) { iree_vm_variant_t variant = iree_vm_variant_empty(); if (failed(handleRuntimeError(loc, iree_vm_list_get_variant_assign( - outputs.get(), index, &variant)))) + outputs.get(), index, &variant)))) { return failure(); + } outAttr = binary.convertVariantToAttribute(loc, variant, mlirType); - if (!outAttr) + if (!outAttr) { return failure(); + } return success(); } @@ -400,8 +407,9 @@ TypedAttr CompiledBinary::convertVariantToAttribute(Location loc, iree_hal_element_type_t halElementType = iree_hal_buffer_view_element_type(bufferView); Type elementType = mapElementType(loc, halElementType); - if (!elementType) + if (!elementType) { return {}; + } auto tensorType = RankedTensorType::get(shape, elementType); diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp 
b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 656cc8984f70..3df39eb8ecb2 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -146,8 +146,9 @@ struct SwapExtractSliceOfFill final LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, PatternRewriter &rewriter) const override { auto fillOp = extractOp.getSource().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } auto newExtractOp = tensor::ExtractSliceOp::create( rewriter, extractOp.getLoc(), extractOp.getType(), diff --git a/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp index 25b7c1915188..352f0dd5bcc5 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CloneProducersIntoDispatchRegions.cpp @@ -35,8 +35,9 @@ struct CloneProducersIntoDispatchRegionsPass final IREE::Flow::ClonableIntoDispatchOptions options; options.aggressive = aggressive; funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) { - if (failed(cloneProducersToRegion(rewriter, regionOp, options))) + if (failed(cloneProducersToRegion(rewriter, regionOp, options))) { return signalPassFailure(); + } }); funcOp->walk([&](Operation *op) { @@ -58,8 +59,9 @@ struct CloneProducersIntoDispatchRegionsPass final // Rerun the cloning again to move still clonable operations into // dispatches. 
funcOp->walk([&](IREE::Flow::DispatchRegionOp regionOp) { - if (failed(cloneProducersToRegion(rewriter, regionOp, options))) + if (failed(cloneProducersToRegion(rewriter, regionOp, options))) { return signalPassFailure(); + } }); } }; diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp index 28ae827d99a3..a2bfeb53bff6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp @@ -291,8 +291,9 @@ populateReassocAndMaps(tensor::ExtractSliceOp sliceOp, auto isZeroOffsetAndFullSize = [&](OpFoldResult offset, OpFoldResult sliceSize, int64_t inputDim) { - if (!isZeroInteger(offset)) + if (!isZeroInteger(offset)) { return false; + } ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim); FailureOr maybeEqual = diff --git a/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp index 375ffe06cff2..56eea555be8e 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ConvertTensorToFlow.cpp @@ -44,20 +44,25 @@ static FailureOr wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, Operation *op) { SmallVector dimOps = rewriter.getTensorDimOps(); - if (failed(IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps()))) + if (failed( + IREE::Flow::simplifyDimOps(rewriter, rewriter.getTensorDimOps()))) { return failure(); + } // Wrap operation. 
auto regionOp = IREE::Flow::wrapOpInDispatchRegion(rewriter, op); - if (failed(regionOp)) + if (failed(regionOp)) { return failure(); - if (failed(cloneProducersToRegion(rewriter, *regionOp))) + } + if (failed(cloneProducersToRegion(rewriter, *regionOp))) { return failure(); + } auto workgroupsOp = IREE::Flow::rewriteFlowDispatchRegionToFlowDispatchWorkgroups(*regionOp, rewriter); - if (failed(workgroupsOp)) + if (failed(workgroupsOp)) { return failure(); + } return *workgroupsOp; } @@ -68,8 +73,9 @@ wrapInWorkgroupsOp(mlir::TensorDimTrackingRewriter &rewriter, SmallVector result; for (Operation *rootOp : rootOps) { auto workgroupsOp = wrapInWorkgroupsOp(rewriter, rootOp); - if (failed(workgroupsOp)) + if (failed(workgroupsOp)) { return failure(); + } result.push_back(*workgroupsOp); } return result; @@ -84,8 +90,9 @@ static FailureOr convertInsertSliceOps( // Find eligible InsertSliceOps. SmallVector insertSliceOps; funcOp.walk([&](tensor::InsertSliceOp op) { - if (!isInDispatchRegion(op)) + if (!isInDispatchRegion(op)) { insertSliceOps.push_back(op); + } }); // Rewrite InsertSliceOps to FlowUpdateOps. @@ -102,8 +109,9 @@ static FailureOr convertInsertSliceOps( // Create a DispatchWorkgroupsOp for every remaining InsertSliceOp. FailureOr> newWorkgroupsOps = wrapInWorkgroupsOp(rewriter, remainingInsertSliceOps); - if (failed(newWorkgroupsOps)) + if (failed(newWorkgroupsOps)) { return failure(); + } workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); return numRemainingInsertSliceOps; @@ -118,8 +126,9 @@ static FailureOr convertExtractSliceOps( // Find eligible ExtractSliceOps. SmallVector extractSliceOps; funcOp.walk([&](tensor::ExtractSliceOp op) { - if (!isInDispatchRegion(op)) + if (!isInDispatchRegion(op)) { extractSliceOps.push_back(op); + } }); // Rewrite ExtractSliceOps to FlowSliceOps. @@ -137,8 +146,9 @@ static FailureOr convertExtractSliceOps( // Create a DispatchWorkgroupsOp for every remaining ExtractSliceOp. 
FailureOr> newWorkgroupsOps = wrapInWorkgroupsOp(rewriter, remainingExtractSliceOps); - if (failed(newWorkgroupsOps)) + if (failed(newWorkgroupsOps)) { return failure(); + } workgroupsOps.append(newWorkgroupsOps->begin(), newWorkgroupsOps->end()); return numRemainingExtractSliceOps; diff --git a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp index e463015c33f7..c0bff40d7084 100644 --- a/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp @@ -162,8 +162,9 @@ void ElementwiseOpFusionPass::runOnOperation() { operands.insert(std::next(consumer->operand_begin(), fusedOperand->getOperandNumber() + 1), consumer->operand_end()); - if (operands.size() >= kIreeMaxOperandCount) + if (operands.size() >= kIreeMaxOperandCount) { return false; + } ElementwiseOpsFusabilityOptions options; options.fuseMultiReduction = fuseMultiReduction; diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index 5c5e08f924bd..6a7870bce3a8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -390,8 +390,9 @@ foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global, } auto newGlobalType = globalType.clone(newShape); auto initialValue = global.getGlobalInitialValue(); - if (!initialValue) + if (!initialValue) { return success(); + } // TODO: Handle other cases auto newInitialValue = llvm::TypeSwitch(initialValue) diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 1695cda731fc..9df5a62a2d96 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ 
b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -73,8 +73,9 @@ static llvm::SmallBitVector getOuterParallelLoops(Operation *op) { interfaceOp.getLoopIteratorTypes(); llvm::SmallBitVector parallelLoops(loopIteratorTypes.size()); for (auto iteratorType : llvm::enumerate(loopIteratorTypes)) { - if (iteratorType.value() != utils::IteratorType::parallel) + if (iteratorType.value() != utils::IteratorType::parallel) { break; + } parallelLoops.set(iteratorType.index()); } return parallelLoops; @@ -565,8 +566,9 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, // Check that the owner is a `generic` op. auto genericOp = dyn_cast(inOperand->getOwner()); - if (!genericOp) + if (!genericOp) { return false; + } // All loops to be parallel. if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) { @@ -574,13 +576,15 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, } /// The input operand cannot be an init operand already. - if (genericOp.isDpsInit(inOperand)) + if (genericOp.isDpsInit(inOperand)) { return false; + } // If the init operand value is used it cannot be reused for the input // operand. - if (genericOp.payloadUsesValueFromOperand(initOperand)) + if (genericOp.payloadUsesValueFromOperand(initOperand)) { return false; + } // Indexing map used to access the input and init have to match. 
if (genericOp.getMatchingIndexingMap(inOperand) != @@ -590,8 +594,9 @@ static bool canUseInOperandAsInitOperand(OpOperand *inOperand, // Types have to match for the input operand to reuse the buffer from the init // operand - if (inOperand->get().getType() != initOperand->get().getType()) + if (inOperand->get().getType() != initOperand->get().getType()) { return false; + } return true; } @@ -676,8 +681,9 @@ isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, dyn_cast(producer); auto consumerFusionOp = dyn_cast(consumer); - if (!producerFusionOp || !consumerFusionOp) + if (!producerFusionOp || !consumerFusionOp) { return false; + } // Check that the consumer is all parallel. if (consumerFusionOp.getNumLoops() != @@ -727,8 +733,9 @@ isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, } for (OpOperand *inputOperand : consumerDstOp.getDpsInputOperands()) { - if (inputOperand->get().getDefiningOp() != producer) + if (inputOperand->get().getDefiningOp() != producer) { continue; + } if (isa(producer) && !llvm::any_of( consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) { @@ -876,8 +883,9 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, Operation *candidate = worklist.pop_back_val(); for (OpOperand &operand : candidate->getOpOperands()) { Operation *producer = operand.get().getDefiningOp(); - if (!producer) + if (!producer) { continue; + } if (IREE::Flow::isClonableIntoDispatchOp(producer, clonableOptions) || tracker.isFusedOp(producer) || tracker.isRootOp(producer)) { continue; @@ -890,8 +898,9 @@ fuseRootsWithProducers(MLIRContext *context, Operation *root, SmallVector fusableUses = getFusableUses(context, producer, dominanceInfo, /*aggressiveFusion=*/options.aggressiveFusion); - if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate) + if (fusableUses.empty() || fusableUses.front()->getOwner() != candidate) { continue; + } tracker.appendToFusionGroup(producer, fusionGroup); 
worklist.push_back(producer); @@ -926,8 +935,9 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, } // Start with a root operation and fuse its producers. - if (tracker.isFusedOp(&op) || !isRootLikeOp(&op)) + if (tracker.isFusedOp(&op) || !isRootLikeOp(&op)) { continue; + } FusionGroup &newGroup = tracker.createFusionGroup(context, &op); fuseRootsWithProducers(context, &op, newGroup, dominanceInfo, options, tracker, @@ -950,8 +960,9 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, SmallVector roots; for (Operation &op : llvm::reverse(block)) { // If it is part of a fusion group or root op, ignore it. - if (tracker.isFusedOp(&op) || tracker.isRootOp(&op)) + if (tracker.isFusedOp(&op) || tracker.isRootOp(&op)) { continue; + } // Only look for Linalg ops here. Avoid moving `linalg.fill` that aren't // fused with anything else into their own dispatches since it is better // to convert them to splats. Also avoid moving dequantization-like ops diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index ae1ef9271693..3510b6209d4b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -287,8 +287,9 @@ void FuseMultiUseElementwiseProducerPass::runOnOperation() { funcOp->emitError("failed to fuse multi-use producers"); return signalPassFailure(); } - if (numOfFusableCandidates.value() == 0) + if (numOfFusableCandidates.value() == 0) { break; + } } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp index 8d404bdc6074..0136587fbed6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionPreprocessing.cpp 
@@ -50,8 +50,9 @@ struct ElementwiseOpInterchangePattern final LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 || - genericOp.getNumDpsInputs() == 0) + genericOp.getNumDpsInputs() == 0) { return failure(); + } // All input maps must be equal and non-identity. All maps, including // output, must be be permutations. Permutation maps are checked by diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 1dd745fd964a..7f226201e986 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -20,16 +20,19 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, ElementwiseOpsFusabilityOptions options) { Operation *producerOp = fusedOperand->get().getDefiningOp(); Operation *consumerOp = fusedOperand->getOwner(); - if (!producerOp) + if (!producerOp) { return false; + } // Check for i1 return types, if so aggressively fuse to avoid `i1` buffers. if (llvm::all_of(producerOp->getResultTypes(), [](Type t) { - if (t.isInteger(1)) + if (t.isInteger(1)) { return true; + } if (auto shapedType = dyn_cast(t)) { - if (shapedType.getElementType().isInteger(1)) + if (shapedType.getElementType().isInteger(1)) { return true; + } } return false; })) { @@ -38,8 +41,9 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, // If the generic op is "just" copy, then fuse always. 
Block &body = producerOp->getRegion(0).front(); - if (std::begin(body)->hasTrait()) + if (std::begin(body)->hasTrait()) { return true; + } auto linalgConsumerOp = dyn_cast(consumerOp); if (!linalgConsumerOp) { diff --git a/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp index 4d7ee1d7c58d..1936ba45bc68 100644 --- a/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/MaterializeDefaultWorkgroupCountRegion.cpp @@ -57,8 +57,9 @@ static LogicalResult createDefaultWorkgroupCountRegion( SmallVector workloadLocs; for (auto argument : workgroupsOp.getArguments()) { Type argumentType = argument.getType(); - if (!isa(argumentType)) + if (!isa(argumentType)) { continue; + } workload.push_back(argument); workloadTypes.push_back(argumentType); workloadLocs.push_back(argument.getLoc()); @@ -114,8 +115,9 @@ static LogicalResult createDefaultWorkgroupCountRegion( rewriter.setInsertionPointToStart(&body.front()); int ordinalNumber = 0; for (auto [index, operand] : llvm::enumerate(workgroupsOp.getArguments())) { - if (!isa(operand.getType())) + if (!isa(operand.getType())) { continue; + } BlockArgument arg = workgroupsOp.getInputBlockArgument(index); auto ordinalOp = IREE::TensorExt::DispatchWorkloadOrdinalOp::create( rewriter, loc, arg, rewriter.getIndexAttr(ordinalNumber++)); diff --git a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp index 4961db09ed1f..721dbed57803 100644 --- a/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp @@ -132,8 +132,9 @@ struct FoldFillWithSetEncoding final LogicalResult matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, PatternRewriter &rewriter) const override { auto fillOp = 
encodingOp.getSource().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } // Create a new fill op, with outs being defined by a new `tensor.empty` op. RankedTensorType encodingType = encodingOp.getResultType(); diff --git a/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp index d31ccb781d9c..dceb76824250 100644 --- a/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/TensorPadToTensorInsertSlice.cpp @@ -49,8 +49,9 @@ struct TensorPadOpConversion : public OpRewritePattern { // scalar that is not one of the arguments of the linalg operation. Region ®ion = padTensorOp.getRegion(); Block &block = region.front(); - if (!llvm::hasSingleElement(block)) + if (!llvm::hasSingleElement(block)) { return failure(); + } auto yieldOp = cast(block.getTerminator()); Value yieldVal = yieldOp.getValue(); if (llvm::any_of(block.getArguments(), diff --git a/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp index b2f5ad28dc4a..edd0c233d00a 100644 --- a/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/TransposeGenericOps.cpp @@ -37,18 +37,21 @@ struct MakeReductionInnermostPattern final SmallVector interchange; bool needInterchange = false; unsigned numParallelLoop = genericOp.getNumParallelLoops(); - if (numParallelLoop == 0) + if (numParallelLoop == 0) { return failure(); + } for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) { if (linalg::isParallelIterator(iter.value())) { interchange.push_back(iter.index()); - if (iter.index() >= numParallelLoop) + if (iter.index() >= numParallelLoop) { needInterchange = true; + } } } // If all the parallel loops are outter loops skip the pattern. 
- if (!needInterchange) + if (!needInterchange) { return failure(); + } for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) { if (linalg::isReductionIterator(iter.value())) { interchange.push_back(iter.index()); @@ -83,8 +86,9 @@ struct TransposeGenericOpPattern final // elementwise op) with a single use. auto producer = operand->get().getDefiningOp(); if (!producer || !llvm::hasSingleElement(producer->getUsers()) || - linalg::isElementwise(producer)) + linalg::isElementwise(producer)) { continue; + } // check if the generic op has a non-identity map for the operand. auto indexingMap = genericOp.getMatchingIndexingMap(operand); @@ -93,11 +97,13 @@ struct TransposeGenericOpPattern final return rewriter.notifyMatchFailure(genericOp, "already normalized"); } // The map must be a permutation. If not, then look for other operand. - if (!indexingMap.isPermutation()) + if (!indexingMap.isPermutation()) { continue; + } - if (!mapForInterchange) + if (!mapForInterchange) { mapForInterchange = indexingMap; + } } if (!mapForInterchange) { diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index f39469f502d5..d7542c4d322f 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -175,8 +175,9 @@ struct GlobalOpInterfaceExternalModel IREE::Util::InliningPolicyAttrInterface getGlobalInliningPolicy(Operation *op) const { - if (op->hasAttr("noinline")) + if (op->hasAttr("noinline")) { return IREE::Util::InlineNeverAttr::get(op->getContext()); + } return {}; } void @@ -283,8 +284,9 @@ struct LinalgOpTiedOpInterface SmallVector getTiedResultOperandIndices(Operation *op) const { SmallVector result; - for (unsigned i = 0; i < op->getNumResults(); ++i) + for (unsigned i = 0; i < op->getNumResults(); ++i) { result.push_back(*getTiedResultOperandIndex(op, i)); + } return 
result; } }; diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp index 475acc17a19c..1445e0dce06d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp @@ -29,8 +29,9 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { PatternRewriter &rewriter) const override { auto filterShapeType = dyn_cast( convOp.getDpsInputOperand(1)->get().getType()); - if (!filterShapeType) + if (!filterShapeType) { return failure(); + } constexpr bool isNCHW = std::is_same_v; @@ -48,8 +49,9 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { constexpr int khLoopIndex = isNHWC ? 4 : 5; constexpr int kwLoopIndex = isNHWC ? 5 : 6; - if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) + if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) { return failure(); + } SmallVector dimReplacements; for (int i = 0; i < numLoops; i++) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp b/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp index aec860cd5a67..a7eb49b6111a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ConvertStridedContractionToContraction.cpp @@ -27,15 +27,18 @@ class ConvertStridedContractionToContraction PatternRewriter &rewriter) const override { // Check if the generic op satisfies all other conditions for being a // contraction. 
- if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1) + if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1) { return failure(); - if (op.getNumReductionLoops() == 0) + } + if (op.getNumReductionLoops() == 0) { return failure(); + } if (!mlir::linalg::detail::isContractionBody( *op.getBlock(), [](Operation *first, Operation *second) { if ((isa(first) && isa(second)) || - (isa(first) && isa(second))) + (isa(first) && isa(second))) { return true; + } return false; })) { return failure(); @@ -54,16 +57,18 @@ class ConvertStridedContractionToContraction !resultMap.isProjectedPermutation()) { return failure(); } - if (inputMap.isProjectedPermutation()) + if (inputMap.isProjectedPermutation()) { return failure(); + } SmallVector staticShape = op.getStaticLoopRanges(); llvm::SmallDenseMap strides; SmallVector replacementExprs; Value input = op.getDpsInputs()[0]; auto inputTy = dyn_cast(input.getType()); - if (!inputTy) + if (!inputTy) { return failure(); + } SmallVector inputShape(inputTy.getShape()); replacementExprs.reserve(inputMap.getNumResults()); // Walk through input map and look for expressions of the form `dim * cst`. @@ -76,8 +81,9 @@ class ConvertStridedContractionToContraction // Look at binary op expressions. auto binexpr = dyn_cast(expr); // Fail if we see some unexpected kind of expression. - if (!binexpr) + if (!binexpr) { return failure(); + } auto rhs = dyn_cast(binexpr.getRHS()); auto lhs = dyn_cast(binexpr.getLHS()); // Binary expressions must be of the form `dim * cst`. @@ -87,15 +93,17 @@ class ConvertStridedContractionToContraction } strides.insert(std::pair(pos, rhs.getValue())); int64_t newSize = staticShape[lhs.getPosition()]; - if (newSize == ShapedType::kDynamic || newSize == 0) + if (newSize == ShapedType::kDynamic || newSize == 0) { return failure(); + } inputShape[pos] = newSize; replacementExprs.push_back(lhs); } // Fail if we don't have any work to do. 
- if (strides.empty()) + if (strides.empty()) { return failure(); + } mapRange[inputPos] = AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), diff --git a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp index 6ad30ce8e87c..cfef1f765698 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp @@ -54,8 +54,9 @@ struct TransposeInnerConcatenation : public OpRewritePattern { ArrayRef concatShape = concatType.getShape(); int64_t outerMostNonUnitDim = 0; while (outerMostNonUnitDim < concatOp.getRank()) { - if (concatShape[outerMostNonUnitDim] != 1) + if (concatShape[outerMostNonUnitDim] != 1) { break; + } outerMostNonUnitDim++; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp index 9b6cdffa8a64..f38ed8f8a312 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp @@ -40,8 +40,9 @@ struct DetachElementwisePattern !isa(*linalgOp)) { return failure(); } - if (!linalgOp.hasPureTensorSemantics()) + if (!linalgOp.hasPureTensorSemantics()) { return failure(); + } // Nothing to do if the output tensor operand is already a fill op. SmallVector outputOperands; @@ -52,8 +53,9 @@ struct DetachElementwisePattern } // Right now all the cases we see have one output. This can be relaxed once // we see multiple output ops. 
- if (outputOperands.size() != 1) + if (outputOperands.size() != 1) { return failure(); + } Value outputOperand = outputOperands.front()->get(); auto outsDefiningOp = outputOperand.getDefiningOp(); @@ -62,8 +64,9 @@ struct DetachElementwisePattern return failure(); } auto outputType = cast(outputOperand.getType()); - if (!outputType.getElementType().isIntOrFloat()) + if (!outputType.getElementType().isIntOrFloat()) { return failure(); + } auto elementType = outputType.getElementType(); Location loc = linalgOp.getLoc(); @@ -139,17 +142,20 @@ struct DetachSplatConstantOutsOperands for (auto outOperand : llvm::enumerate(dpsInterfaceOp.getDpsInits())) { auto constOp = outOperand.value().template getDefiningOp(); - if (!constOp) + if (!constOp) { continue; + } auto resultType = dyn_cast(constOp.getResult().getType()); - if (!resultType || !resultType.getElementType().isIntOrFloat()) + if (!resultType || !resultType.getElementType().isIntOrFloat()) { continue; + } auto attr = dyn_cast(constOp.getValue()); - if (!attr || !attr.isSplat()) + if (!attr || !attr.isSplat()) { continue; + } Location loc = constOp.getLoc(); Type elementType = resultType.getElementType(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp index 8fd07bef521a..d52968c0b53a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp @@ -121,8 +121,9 @@ static void expandType(Type type, SmallVectorImpl &newTypes) { // Expands tensors in the given |types| list to (tensor, dynamic dims...). // This could be changed to some iterator magic to avoid the alloc. 
static SmallVector expandTypes(TypeRange types) { - if (types.empty()) + if (types.empty()) { return {}; + } SmallVector newTypes; newTypes.reserve(types.size() * 2); for (auto type : types) { @@ -205,22 +206,25 @@ static void expandTensorDims(Operation *op, SymbolTable &symbolTable, static void expandRegion(Region ®ion, SymbolTable &symbolTable, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap tensorDimMap) { - if (region.empty()) + if (region.empty()) { return; + } // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isDynamicTensor)) + if (!llvm::any_of(block.getArgumentTypes(), isDynamicTensor)) { continue; + } // Insert and build a list of expanded (tensor, dynamic dims...) tuples. SmallVector expansions; for (int i = block.getNumArguments() - 1; i >= 0; --i) { auto arg = block.getArgument(i); auto tensorType = dyn_cast(arg.getType()); - if (!tensorType || tensorType.hasStaticShape()) + if (!tensorType || tensorType.hasStaticShape()) { continue; + } ExpandedValue expandedValue; expandedValue.tensor = arg; for (unsigned j = 0; j < tensorType.getNumDynamicDims(); ++j) { @@ -302,8 +306,9 @@ static void retieResults(Operation *op, Operation *newOp, static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto &expandedGlobal = globalMap[op.getGlobalName()]; @@ -335,8 +340,9 @@ static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op, ExpandedGlobalMap &globalMap, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); builder.setInsertionPointAfter(op); auto expandedValue = consumeExpandedValue( @@ -395,13 
+401,15 @@ static void expandFuncOp(IREE::Util::FuncOp op, SymbolTable &symbolTable, // %2 = flow.tensor.tie_shape %r : tensor{%rd} static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } // Ignore calls to public/external functions. auto calleeOp = symbolTable.lookup(op.getCallee()); - if (IREE::Util::isPublicOrExternal(calleeOp)) + if (IREE::Util::isPublicOrExternal(calleeOp)) { return; + } // Build the new call op with expanded operands and results. OpBuilder builder(op); @@ -429,10 +437,13 @@ static void expandCallOp(IREE::Util::CallOp op, SymbolTable &symbolTable, // util.return %0, %d static void expandReturnOp(IREE::Util::ReturnOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; - if (IREE::Util::isPublicOrExternal(op->getParentOfType())) + } + if (IREE::Util::isPublicOrExternal( + op->getParentOfType())) { return; + } OpBuilder builder(op); auto operands = expandOperands(op.getLoc(), op.getOperands(), tensorDimMap, indexSet, builder); @@ -462,8 +473,9 @@ static void expandBranchOp(mlir::cf::BranchOp op, IndexSet &indexSet, static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); mlir::cf::CondBranchOp::create( builder, op.getLoc(), op.getCondition(), op.getTrueDest(), @@ -487,8 +499,9 @@ static void expandCondBranchOp(mlir::cf::CondBranchOp op, IndexSet &indexSet, // %4 = flow.tensor.tie_shape %2 : tensor{%3} static void expandSelectOp(mlir::arith::SelectOp op, IndexSet &indexSet, TensorDimMap &tensorDimMap) { - if (!usesDynamicTensors(op)) + if (!usesDynamicTensors(op)) { return; + } OpBuilder builder(op); auto trueValue = consumeExpandedValue(op.getLoc(), op.getTrueValue(), diff --git 
a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp index cbbbe5f4880c..4f8f7258bf94 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp @@ -39,8 +39,9 @@ static bool isHoistableOp(LoopLikeOpInterface loopOp, Operation *op, for (OpOperand &operand : op->getOpOperands()) { Value value = operand.get(); // Ignore values defined outside the loop. - if (loopOp.isDefinedOutsideOfLoop(value)) + if (loopOp.isDefinedOutsideOfLoop(value)) { continue; + } Operation *producer = value.getDefiningOp(); // If the producer is not an operation, can't hoist it. @@ -61,8 +62,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, llvm::SetVector hoistableOps; for (Region *region : loopOp.getLoopRegions()) { // Skip loops with multi-block regions to simplify op's dependency. - if (!region->hasOneBlock()) + if (!region->hasOneBlock()) { return failure(); + } // Consider only the top-level ops in the region. The forward visiting in a // single block ensures we are check and add ops in topological order. @@ -73,8 +75,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, } } } - if (hoistableOps.empty()) + if (hoistableOps.empty()) { return success(); + } // Wrap the loop in zero-trip-check so the hoisted ops will only run when the // loop condition is ever satisfied. @@ -87,8 +90,9 @@ static LogicalResult hoistLoopInvariants(LoopLikeOpInterface loopOp, return scf::wrapWhileLoopInZeroTripCheck(op, rewriter); }) .Default([&](Operation *op) { return failure(); }); - if (failed(wrappedLoop)) + if (failed(wrappedLoop)) { return failure(); + } // Hoist ops out of the loop in topological order. 
for (Operation *op : hoistableOps) { @@ -118,15 +122,17 @@ struct GlobalLoopInvariantCodeMotionPass // to move across multiple loop levels. funcOp.walk([&](LoopLikeOpInterface op) { // Check if the loop type is supported. - if (isa(op)) + if (isa(op)) { candidateLoops.push_back(op); + } return; }); IRRewriter rewriter(context); for (auto loopOp : candidateLoops) { - if (failed(hoistLoopInvariants(loopOp, rewriter))) + if (failed(hoistLoopInvariants(loopOp, rewriter))) { return signalPassFailure(); + } } } }; diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index 4cc86170c52e..d7f507b868d5 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -38,8 +38,9 @@ struct MaterializeHomogeneousEncodingsPass final void runOnOperation() override { mlir::ModuleOp moduleOp = getOperation(); IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp); - if (failed(deviceAnalysis.run())) + if (failed(deviceAnalysis.run())) { return signalPassFailure(); + } SetVector executableTargets; deviceAnalysis.gatherAllExecutableTargets(executableTargets); diff --git a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp index c8f38fa2c4d8..6f7369dd0a14 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp @@ -22,8 +22,9 @@ namespace { int getNextPotBitWidth(int bitWidth, int minBitWidth = 8) { for (int i = minBitWidth;; i *= 2) { - if (i >= bitWidth) + if (i >= bitWidth) { return i; + } } } @@ -108,8 +109,9 @@ struct TensorEmptyCast LogicalResult matchAndRewrite(IREE::Util::NumericCastOpInterface castOp, PatternRewriter &rewriter) const override { auto emptyOp = 
castOp.getInput().getDefiningOp(); - if (!emptyOp) + if (!emptyOp) { return failure(); + } Type resultType = castOp.getCasted().getType(); rewriter.replaceOpWithNewOp(castOp, resultType, @@ -127,8 +129,9 @@ struct LinalgFillCast PatternRewriter &rewriter) const override { auto loc = castOp.getLoc(); auto fillOp = castOp.getInput().getDefiningOp(); - if (!fillOp) + if (!fillOp) { return failure(); + } Type toElementType = getElementTypeOrSelf(castOp.getCastedType()); Value fillInput = fillOp.value(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index a1223d8bb83b..d7b80f988238 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -285,8 +285,9 @@ class FuseTransposeWithProducerLinalgOp rewriter.replaceOp(transposeOp, newGenericOp->getResult(resultIndex)); for (auto [oldRes, newRes] : llvm::zip_equal(genericOp.getResults(), newGenericOp->getResults())) { - if (oldRes.getResultNumber() == resultIndex) + if (oldRes.getResultNumber() == resultIndex) { continue; + } rewriter.replaceAllUsesWith(oldRes, newRes); } return success(); diff --git a/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp b/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp index 4614f7b56c8b..e516cbf559c7 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/QuantizedConvToConv.cpp @@ -103,8 +103,9 @@ void GetDynamicDym(ImplicitLocOpBuilder &builder, int64_t dim) { ShapedType ty = cast(value.getType()); dims.push_back(ty.getDimSize(dim)); - if (ty && ty.isDynamicDim(dim)) + if (ty && ty.isDynamicDim(dim)) { dynDims.push_back(tensor::DimOp::create(builder, value, dim)); + } } Value multiplyDims(ImplicitLocOpBuilder &builder, Value value, @@ -178,8 
+179,9 @@ struct QuantizedConvToConv // Materialize a length-1 dimension at the end of the summation. SmallVector reassociationMap(3); - for (int i = 0; i < 3; i++) + for (int i = 0; i < 3; i++) { reassociationMap[i].push_back(builder.getAffineDimExpr(i)); + } reassociationMap.back().push_back(builder.getAffineDimExpr(3)); auto expandTy = diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index 638ca1f64f05..462a6caf3417 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -83,8 +83,9 @@ raiseTensorExtractToInput(linalg::GenericOp linalgOp, RewriterBase &rewriter) { // Restrict to cases where the constant is 0. This is because handling // constants other than 0 in indexing map, may cause problems in the // lowering pipeline later. - if (constantIndex.getLimitedValue() != 0) + if (constantIndex.getLimitedValue() != 0) { return failure(); + } exprs.push_back(getAffineConstantExpr(0, rewriter.getContext())); continue; } @@ -306,8 +307,9 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { } if (!llvm::all_of(producer.getIndexingMapsArray(), - [](AffineMap map) { return map.isIdentity(); })) + [](AffineMap map) { return map.isIdentity(); })) { return false; + } std::optional castOp = getDefiningNonI1ExtendingCastOp(operand.get()); @@ -319,8 +321,9 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { // preferred to fuse those with producers (and the consumer fusion is // arguably the less canonical form). 
auto canFoldCast = [&]() { - if (isa(*castOp)) + if (isa(*castOp)) { return true; + } // Signed operations can only be folded with (implicitly) signed // linalg named ops if (isa(*castOp)) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp index d46faef5f3b4..4b025f71c87d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Utils.cpp @@ -112,8 +112,9 @@ Value sumReduceDimensionSubset(ImplicitLocOpBuilder &rewriter, Value val, llvm::SmallVector staticSizes; SmallVector dynSizes; for (int i = 0, s = is_reduction.size(); i < s; i++) { - if (is_reduction[i]) + if (is_reduction[i]) { continue; + } staticSizes.push_back(ty.getDimSize(i)); if (ty.isDynamicDim(i)) { diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp index 08cc77b31eee..13a75494d08a 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp @@ -32,15 +32,17 @@ class AutoInputConversionPipelinePass final }; void AutoInputConversionPipelinePass::runOnOperation() { - if (!pipelineExtensions) + if (!pipelineExtensions) { return; + } mlir::ModuleOp moduleOp = getOperation(); llvm::StringSet<> detectedTypeMnemonics; pipelineExtensions->populateDetectedCustomInputConversionTypes( moduleOp, detectedTypeMnemonics); - if (detectedTypeMnemonics.empty()) + if (detectedTypeMnemonics.empty()) { return; + } if (detectedTypeMnemonics.getNumItems() > 1) { // TODO(scotttodd): handle multiple typeMnemonics (use all?) 
diff --git a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp index 48080d1c0fea..b6480b0b797e 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp @@ -43,8 +43,9 @@ Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } if (inputETy.getIntOrFloatBitWidth() > eTy.getIntOrFloatBitWidth()) { return arith::TruncFOp::create(builder, loc, type, inputs[0]); @@ -61,8 +62,9 @@ Value convertRankedInteger(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { Type eTy = getElementTypeOrSelf(type); Type inputETy = getElementTypeOrSelf(inputs[0].getType()); - if (!isa(getElementTypeOrSelf(type))) + if (!isa(getElementTypeOrSelf(type))) { return nullptr; + } bool isUnsigned = eTy.isUnsignedInteger(); int64_t inBitwidth = inputETy.getIntOrFloatBitWidth(); @@ -89,8 +91,9 @@ struct PrimitiveTypeConverter : public TypeConverter { explicit PrimitiveTypeConverter() { addConversion([](Type type) { return type; }); addConversion([&](SourceType type) -> Type { - if (!isSourceType(type)) + if (!isSourceType(type)) { return type; + } return getTargetType(type); }); addConversion([&](ComplexType type) { @@ -302,21 +305,25 @@ struct ConvertTypesPass : public Base { return typeConverter.isLegal(globalOp.getGlobalType()); } else if (auto funcOp = dyn_cast(op)) { for (Type type : funcOp.getArgumentTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : funcOp.getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } 
} for (Type type : op->getResultTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } for (Type type : op->getOperandTypes()) { - if (!typeConverter.isLegal(type)) + if (!typeConverter.isLegal(type)) { return false; + } } return true; }); diff --git a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp index a4c591fa945d..06eb0bf203f6 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp @@ -95,8 +95,9 @@ class MLProgramGlobalOpPattern matchAndRewrite(ml_program::GlobalOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newType = typeConverter->convertType(srcOp.getType()); - if (!newType) + if (!newType) { return failure(); + } std::map externs; @@ -115,12 +116,14 @@ class MLProgramGlobalOpPattern globalOp.setVisibility(SymbolTable::Visibility::Private); globalOp->setDialectAttrs(srcOp->getDialectAttrs()); - if (isExtern) + if (isExtern) { externGlobals.emplace_back(srcOp.getName(), newType); + } // No more work needed if not public global. - if (visibility != SymbolTable::Visibility::Public) + if (visibility != SymbolTable::Visibility::Public) { return success(); + } ModuleOp module = srcOp->getParentOfType(); @@ -140,12 +143,15 @@ class MLProgramGlobalOpPattern StringRef s = format; // Verify only single replacement of 0th index. s = s.drop_until([](char c) { return c == '{'; }); - if (s.empty() || !s.consume_front("{")) + if (s.empty() || !s.consume_front("{")) { return failure(); - if (!s.consume_front("0")) + } + if (!s.consume_front("0")) { return failure(); - if (!s.consume_front("}")) + } + if (!s.consume_front("}")) { return failure(); + } s = s.drop_until([](char c) { return c == '{'; }); return success(s.empty()); }; @@ -157,15 +163,17 @@ class MLProgramGlobalOpPattern v ? 
dyn_cast_if_present(v.get("get")) : nullptr; { const std::string getFormat = get ? get.str() : "global${0}$get"; - if (failed(verifyFormat(getFormat))) + if (failed(verifyFormat(getFormat))) { return failure(); + } getterName = llvm::formatv(getFormat.c_str(), globalOp.getSymName()); } auto set = v ? dyn_cast_if_present(v.get("set")) : nullptr; { const std::string setFormat = set ? set.str() : "global${0}$set"; - if (failed(verifyFormat(setFormat))) + if (failed(verifyFormat(setFormat))) { return failure(); + } setterName = llvm::formatv(setFormat.c_str(), globalOp.getSymName()); } @@ -258,12 +266,15 @@ void ImportMLProgramPass::runOnOperation() { ONE_TO_ONE(ml_program::GlobalLoadConstOp, IREE::Util::GlobalLoadOp); ONE_TO_ONE(ml_program::GlobalStoreOp, IREE::Util::GlobalStoreOp); - if (failed(applyFullConversion(getOperation(), target, std::move(patterns)))) + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); + } if (!externGlobals.empty() && - failed(createExternInitFunction(getOperation(), externGlobals))) + failed(createExternInitFunction(getOperation(), externGlobals))) { signalPassFailure(); + } } } // namespace mlir::iree_compiler::InputConversion diff --git a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp index e5c966c9ef75..b9ab257650c4 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/SanitizeModuleNames.cpp @@ -30,8 +30,9 @@ class SanitizeModuleNamesPass final mlir::ModuleOp moduleOp = getOperation(); auto optionalName = moduleOp.getName(); - if (!optionalName.has_value()) + if (!optionalName.has_value()) { return; + } auto name = optionalName.value(); moduleOp.setName(sanitizeSymbolName(name)); diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp 
b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index ae76182a3526..07f8f02a5789 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp @@ -37,8 +37,9 @@ struct OptionalCheckImportConversion : public VMImportOpConversion { rewriter.setInsertionPointToStart(callBlock); auto results = rewriteToCall(op, adaptor, this->importOp, *this->getTypeConverter(), rewriter); - if (!results.has_value()) + if (!results.has_value()) { return failure(); + } rewriter.replaceOp(op, results.value()); IREE::VM::BranchOp::create(rewriter, op.getLoc(), followingBlock); return success(); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp index ac0f7b91b1ee..b8e9e556bb6e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/HALToHALInline/Patterns.cpp @@ -27,9 +27,10 @@ struct ElementTypeOpConversion ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::ElementTypeOp::getTypeValue(op.getTypeAttr().getValue()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported element type"); + } rewriter.replaceOpWithNewOp(op, value.value(), 32); return success(); } @@ -42,9 +43,10 @@ struct EncodingTypeOpConversion matchAndRewrite(IREE::HAL::EncodingTypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto value = IREE::HAL::EncodingTypeOp::getTypeValue(op.getEncodingAttr()); - if (!value.has_value()) + if (!value.has_value()) { return rewriter.notifyMatchFailure(op.getLoc(), "unsupported encoding type"); + } rewriter.replaceOpWithNewOp(op, value.value(), 32); return success(); } diff --git 
a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp index fe756c89031f..af7cd6099fe4 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp @@ -664,10 +664,12 @@ struct GlobalTimepointConversionPattern matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto initialValue = op.getInitialValue(); - if (!initialValue.has_value()) + if (!initialValue.has_value()) { return failure(); - if (!isa(*initialValue)) + } + if (!isa(*initialValue)) { return failure(); + } rewriter.modifyOpInPlace( op, [&]() { op.setInitialValueAttr(rewriter.getI64IntegerAttr(0)); }); return success(); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp index aae94b66615c..90a4b1b5cf9a 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.cpp @@ -103,8 +103,9 @@ void BufferStorageOp::getAsmResultNames( OpFoldResult BufferStorageOp::fold(FoldAdaptor operands) { auto *definingOp = getBuffer().getDefiningOp(); - if (!definingOp) + if (!definingOp) { return {}; + } if (auto sourceOp = dyn_cast_if_present( definingOp)) { return sourceOp.getStorage(); @@ -168,8 +169,9 @@ struct FoldBufferViewCreateSubspan needsUpdate = true; } rewriter.restoreInsertionPoint(ip); - if (!needsUpdate) + if (!needsUpdate) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getSourceBufferMutable().assign(newSourceBuffer); op.getSourceOffsetMutable().assign(newSourceOffset); diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp 
b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp index 31f4008b6a7d..e420f5a2f27b 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Conversion/HALLoaderToVM/Patterns.cpp @@ -26,8 +26,9 @@ namespace { // Casts |value| to i32 if it is not already. static Value castToI32(Value value, OpBuilder &builder) { - if (value.getType().isInteger(32)) + if (value.getType().isInteger(32)) { return value; + } return builder.createOrFold( value.getLoc(), builder.getI32Type(), value); } diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp index 9bb81dee2e91..5c5b0c43642e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.cpp @@ -187,8 +187,9 @@ struct FoldBindingSubspansIntoDispatchOp bindingBuffers.push_back(subspanOp.getSource()); bindingOffsets.push_back(newOffset); } - if (!didChangeAny) + if (!didChangeAny) { return failure(); + } rewriter.modifyOpInPlace(op, [&]() { op.getBindingBuffersMutable().assign(bindingBuffers); op.getBindingOffsetsMutable().assign(bindingOffsets); diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp index 7922adde3439..53417ae46986 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.cpp @@ -14,12 +14,14 @@ namespace mlir::iree_compiler::IREE::IO::Parameters { LogicalResult handleRuntimeError(Operation *op, iree_status_t status, StringRef failureMessage) { - if (iree_status_is_ok(status)) + if (iree_status_is_ok(status)) { return success(); + } iree_host_size_t buffer_length = 0; if 
(!iree_status_format(status, /*buffer_capacity=*/0, - /*buffer=*/nullptr, &buffer_length)) + /*buffer=*/nullptr, &buffer_length)) { return op->emitError() << failureMessage; + } std::string message; message.reserve(buffer_length); message.resize(buffer_length - 1); diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h index 254bc58aa57a..a199759263f7 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ArchiveUtils.h @@ -52,10 +52,11 @@ using ScopePath = std::pair; // If no `scope=` was specified the resulting scope string will be empty. static inline ScopePath splitScopePath(StringRef scopePath) { size_t i = scopePath.find_first_of('='); - if (i == StringRef::npos) + if (i == StringRef::npos) { return ScopePath("", scopePath); - else + } else { return ScopePath(scopePath.substr(0, i), scopePath.substr(i + 1)); + } } // Helper to interpret iree status messages and print the error message. diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp index 539a850ec8af..58680214a2ac 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp @@ -91,15 +91,17 @@ struct ExportParametersPass MLIRContext *context = &getContext(); // Nothing to do if no path specified. - if (scopePath.empty()) + if (scopePath.empty()) { return; + } auto [scope, path] = splitScopePath(scopePath); // Create a builder used to accumulate the parameters. 
ModuleOp moduleOp = getOperation(); auto builder = createArchiveBuilder(moduleOp); - if (failed(builder)) + if (failed(builder)) { return signalPassFailure(); + } // Accumulate globals that match the pass options and add them to the index. SmallVector constantGlobalOps; @@ -109,31 +111,36 @@ struct ExportParametersPass auto serializableAttr = dyn_cast_if_present( globalOp.getGlobalInitialValue()); - if (!serializableAttr) + if (!serializableAttr) { continue; + } // Check that the serialized size of the attribute is at least as big as // the pass configured minimum storage size. int64_t storageSize = serializableAttr.getStorageSize(); - if (storageSize < minimumSize) + if (storageSize < minimumSize) { continue; + } // Add the entry with a type based on its contents. - if (failed(addEntry(globalOp, serializableAttr, builder->get()))) + if (failed(addEntry(globalOp, serializableAttr, builder->get()))) { return signalPassFailure(); + } constantGlobalOps.push_back(globalOp); } // Early exit if no parameterizable globals are present. - if (constantGlobalOps.empty()) + if (constantGlobalOps.empty()) { return; + } // Create the parameter archive file opened for writing. auto fileStreamIndexOr = createParameterIndex(moduleOp, std::move(builder.value()), path); - if (failed(fileStreamIndexOr)) + if (failed(fileStreamIndexOr)) { return signalPassFailure(); + } auto [file, stream, index] = *std::move(fileStreamIndexOr); // Serialize parameters to the file. 
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp index a319589404b2..e7e2b16de58c 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/GenerateSplatParameterArchive.cpp @@ -70,14 +70,16 @@ struct GenerateSplatParameterArchivePass void runOnOperation() override { // Nothing to do if no path specified. - if (filePath.empty()) + if (filePath.empty()) { return; + } // Create a builder used to accumulate the parameters. ModuleOp moduleOp = getOperation(); auto builder = createArchiveBuilder(moduleOp); - if (failed(builder)) + if (failed(builder)) { return signalPassFailure(); + } // Find all parameters in the module and add them to the builder. // NOTE: there may be no parameters but we still will create the archive @@ -86,8 +88,9 @@ struct GenerateSplatParameterArchivePass for (auto [loc, parameterAttr] : parameterAttrs) { // Only support types we can meaningfully generate splats for. auto shapedType = dyn_cast(parameterAttr.getType()); - if (!shapedType) + if (!shapedType) { continue; + } // TODO: support other patterns/generators. auto elementAttr = getDefaultSplatAttr(shapedType.getElementType()); @@ -122,8 +125,9 @@ struct GenerateSplatParameterArchivePass // Create the parameter archive file. auto fileStreamIndexOr = createParameterIndex(moduleOp, std::move(builder.value()), filePath); - if (failed(fileStreamIndexOr)) + if (failed(fileStreamIndexOr)) { return signalPassFailure(); + } auto [file, stream, index] = *std::move(fileStreamIndexOr); // Commit the written file. 
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp index 91b6e690819b..c05ba32a614f 100644 --- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp +++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ImportParameters.cpp @@ -76,8 +76,9 @@ loadParameterIndex(ModuleOp moduleOp, StringRef path, iree_io_parameter_index_t *parameterIndex) { // Open the archive file (hopefully mapping it). auto fileHandle = openArchiveFile(moduleOp, path); - if (failed(fileHandle)) + if (failed(fileHandle)) { return failure(); + } // Parse the archive as a particular format. iree_allocator_t hostAllocator = iree_allocator_system(); @@ -103,8 +104,9 @@ class ParameterIndices { iree_io_parameter_index_t *lookupOrCreate(ModuleOp moduleOp, StringRef scope) { iree_allocator_t hostAllocator = iree_allocator_system(); - if (iree_io_parameter_index_t *existing = lookup(scope)) + if (iree_io_parameter_index_t *existing = lookup(scope)) { return existing; + } iree_io_parameter_index_t *parameterIndexPtr = nullptr; if (failed(handleRuntimeError( moduleOp, @@ -133,8 +135,9 @@ loadParameterArchives(ModuleOp moduleOp, ArrayRef scopePaths) { for (auto &scopePath : scopePaths) { auto [scope, path] = splitScopePath(scopePath); auto *parameterIndex = parameterIndices.lookupOrCreate(moduleOp, scope); - if (failed(loadParameterIndex(moduleOp, path, parameterIndex))) + if (failed(loadParameterIndex(moduleOp, path, parameterIndex))) { return failure(); + } } return parameterIndices; } @@ -143,12 +146,14 @@ loadParameterArchives(ModuleOp moduleOp, ArrayRef scopePaths) { // data as stored in the file. static bool isTypeSupported(Type type) { auto shapedType = dyn_cast(type); - if (!shapedType) + if (!shapedType) { return false; + } auto elementType = shapedType.getElementType(); // NOTE: packed types not yet supported. 
- if (!elementType.isIntOrFloat()) + if (!elementType.isIntOrFloat()) { return false; + } const unsigned logicalBitWidth = elementType.getIntOrFloatBitWidth(); switch (logicalBitWidth) { case 8: @@ -280,29 +285,34 @@ struct ImportParametersPass void runOnOperation() override { // Nothing to do if no path specified. - if (scopePaths.empty()) + if (scopePaths.empty()) { return; + } // Open the archive file (hopefully mapping it) and parse the index. ModuleOp moduleOp = getOperation(); auto parameterIndices = loadParameterArchives(moduleOp, scopePaths); - if (failed(parameterIndices)) + if (failed(parameterIndices)) { return signalPassFailure(); + } // Decide whether to import a particular parameter. DenseSet importKeys; - for (auto &key : keys) + for (auto &key : keys) { importKeys.insert(key); + } auto shouldImportParameter = [&](IREE::Flow::NamedParameterAttr parameterAttr) -> bool { // Always try to import explicitly named parameters. - if (importKeys.contains(parameterAttr.getKey().getValue())) + if (importKeys.contains(parameterAttr.getKey().getValue())) { return true; // key match + } // If a maximum size is specified use that to limit what we import // (users may want to bring in small parameters but leave the big ones // out). - if (maximumSize && parameterAttr.getStorageSize() <= maximumSize) + if (maximumSize && parameterAttr.getStorageSize() <= maximumSize) { return true; // <= max size + } // Default to not importing. return false; }; @@ -312,14 +322,16 @@ struct ImportParametersPass // Only inspect parameter globals. auto parameterAttr = dyn_cast_if_present( globalOp.getGlobalInitialValue()); - if (!parameterAttr) + if (!parameterAttr) { continue; + } // Lookup the parameter index for the scope. 
auto scope = parameterAttr.getScope().getValue(); auto *parameterIndex = parameterIndices->lookup(scope); - if (!parameterIndex) + if (!parameterIndex) { continue; + } // See if the parameter is present in the scope (we may have only been // provided as partial index). @@ -351,8 +363,9 @@ struct ImportParametersPass auto valueOr = importParameter( fullName, cast(globalOp.getGlobalType()), parameterAttr, entry); - if (failed(valueOr)) + if (failed(valueOr)) { return signalPassFailure(); + } // Replace the initial value with the constant. globalOp.setGlobalInitialValue(*valueOr); diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 9ae20d73a6b5..c8b1cace4003 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -58,15 +58,17 @@ IREEVMPipelineHooks::operator IREE::HAL::PipelineHooks() const { auto beforePhase = this->beforePhase; halHooks.beforePhase = [beforePhase](IREE::HAL::PipelinePhase phase, OpPassManager &passManager) { - if (beforePhase) + if (beforePhase) { beforePhase(getIREEVMPipelinePhase(phase), passManager); + } }; auto afterPhase = this->afterPhase; halHooks.afterPhase = [afterPhase](IREE::HAL::PipelinePhase phase, OpPassManager &passManager) { - if (afterPhase) + if (afterPhase) { afterPhase(getIREEVMPipelinePhase(phase), passManager); + } }; return halHooks; @@ -90,8 +92,9 @@ void buildIREEPrecompileTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::Input) { // late-entry auto inputType = inputOptions.parseInputTypeMnemonic(); IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Input"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Input, passManager); + } if (hooks.pipelineExtensions) { hooks.pipelineExtensions->extendInputConversionPreprocessingPassPipeline( passManager, inputType); @@ -132,18 +135,21 @@ void buildIREEPrecompileTransformPassPipeline( 
InputConversion::buildCommonInputConversionPassPipeline( passManager, inputTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Input, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Input"); } - if (compileTo == IREEVMPipelinePhase::Input) + if (compileTo == IREEVMPipelinePhase::Input) { return; // early-exit + } // Now that inputs are legalized, generate wrapper for entry functions. if (compileFrom < IREEVMPipelinePhase::ABI) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "ABI"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::ABI, passManager); + } IREE::ABI::InvocationOptions invocationOptions; invocationOptions.invocationModel = schedulingOptions.executionModel == @@ -156,12 +162,14 @@ void buildIREEPrecompileTransformPassPipeline( if (bindingOptions.tflite) { IREE::TFLite::buildTransformPassPipeline(passManager); } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::ABI, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "ABI"); } - if (compileTo == IREEVMPipelinePhase::ABI) + if (compileTo == IREEVMPipelinePhase::ABI) { return; // early-exit + } // If the user specified a set of target devices we attach them to the module // IR so that they are available for all passes that may want to use this @@ -228,16 +236,19 @@ void buildIREEPrecompileTransformPassPipeline( default: if (compileFrom < IREEVMPipelinePhase::Preprocessing) { // late-entry. 
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Preprocessing"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Preprocessing, passManager); + } Preprocessing::buildPreprocessingPassPipeline( passManager, preprocessingOptions, hooks.pipelineExtensions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Preprocessing, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Preprocessing"); } - if (compileTo == IREEVMPipelinePhase::Preprocessing) + if (compileTo == IREEVMPipelinePhase::Preprocessing) { return; // early-exit + } if (compileFrom < IREEVMPipelinePhase::GlobalOptimization) { // late-entry // This pass pipeline recursively invokes the compiler if constEval is @@ -255,20 +266,23 @@ void buildIREEPrecompileTransformPassPipeline( } else { IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "GlobalOptimization"); } - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::GlobalOptimization, passManager); + } GlobalOptimization::buildGlobalOptimizationPassPipeline( passManager, globalTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::GlobalOptimization, passManager); + } if (globalOptimizationOptions.constEval) { IREE_TRACE_ADD_END_FRAME_PASS(passManager, "GlobalOptimizationConst"); } else { IREE_TRACE_ADD_END_FRAME_PASS(passManager, "GlobalOptimization"); } } - if (compileTo == IREEVMPipelinePhase::GlobalOptimization) + if (compileTo == IREEVMPipelinePhase::GlobalOptimization) { return; // early-exit + } break; } @@ -292,8 +306,9 @@ void buildIREEVMTransformPassPipeline( dispatchCreationOptions, schedulingOptions, halTargetOptions, hooks, passManager, compileFrom, compileTo); - if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) + if (compileTo <= IREEVMPipelinePhase::GlobalOptimization) { return; // early-exit + } IREE::Stream::TransformOptions streamOptions; // TODO(benvanik): find a 
way to share the enums w/o circular deps. @@ -354,42 +369,51 @@ void buildIREEVMTransformPassPipeline( pipelineOptions.constExprHoisting; if (compileFrom < IREEVMPipelinePhase::DispatchCreation) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::DispatchCreation, passManager); + } DispatchCreation::buildDispatchCreationPassPipeline( passManager, dispatchTransformOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::DispatchCreation, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "DispatchCreation"); } - if (compileTo == IREEVMPipelinePhase::DispatchCreation) + if (compileTo == IREEVMPipelinePhase::DispatchCreation) { return; // early-exit + } IREE::Flow::TransformOptions flowOptions; if (compileFrom < IREEVMPipelinePhase::Flow) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Flow"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Flow, passManager); + } IREE::Flow::buildFlowTransformPassPipeline(passManager, flowOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Flow, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Flow"); } - if (compileTo == IREEVMPipelinePhase::Flow) + if (compileTo == IREEVMPipelinePhase::Flow) { return; // early-exit + } if (compileFrom < IREEVMPipelinePhase::Stream) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "Stream"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::Stream, passManager); + } IREE::Stream::buildStreamTransformPassPipeline(passManager, streamOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::Stream, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "Stream"); } - if (compileTo == IREEVMPipelinePhase::Stream) + if 
(compileTo == IREEVMPipelinePhase::Stream) { return; // early-exit + } break; } @@ -400,8 +424,9 @@ void buildIREEVMTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::HAL) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "HAL"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::HAL, passManager); + } switch (schedulingOptions.executionModel) { case SchedulingOptions::ExecutionModel::HostOnly: // No HAL required. @@ -422,8 +447,9 @@ void buildIREEVMTransformPassPipeline( passManager, targetRegistry, halTargetOptions); break; } - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::HAL, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "HAL"); } if (compileTo == IREEVMPipelinePhase::HAL || @@ -433,15 +459,18 @@ void buildIREEVMTransformPassPipeline( if (compileFrom < IREEVMPipelinePhase::VM) { // late-entry IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "VM"); - if (hooks.beforePhase) + if (hooks.beforePhase) { hooks.beforePhase(IREEVMPipelinePhase::VM, passManager); + } IREE::VM::buildVMTransformPassPipeline(passManager, vmTargetOptions); - if (hooks.afterPhase) + if (hooks.afterPhase) { hooks.afterPhase(IREEVMPipelinePhase::VM, passManager); + } IREE_TRACE_ADD_END_FRAME_PASS(passManager, "VM"); } - if (compileTo == IREEVMPipelinePhase::VM) + if (compileTo == IREEVMPipelinePhase::VM) { return; // early-exit + } } void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager) { diff --git a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp index 3f37b9faeb72..3ac9025a2029 100644 --- a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp +++ b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp @@ -136,8 +136,9 @@ LogicalResult PluginManagerSession::initializePlugins() { } // Skip if already initialized. 
- if (!initializedIds.insert(it.first()).second) + if (!initializedIds.insert(it.first()).second) { continue; + } if (options.printPluginInfo) { llvm::errs() << "[IREE plugins]: Initializing default '" << it.first() @@ -156,8 +157,9 @@ LogicalResult PluginManagerSession::initializePlugins() { } // Skip if already initialized. - if (!initializedIds.insert(pluginId).second) + if (!initializedIds.insert(pluginId).second) { continue; + } if (options.printPluginInfo) { llvm::errs() << "[IREE plugins]: Initializing plugin '" << pluginId @@ -187,8 +189,9 @@ void PluginManagerSession::registerDialects(DialectRegistry ®istry) { LogicalResult PluginManagerSession::activatePlugins(MLIRContext *context) { for (auto *s : initializedSessions) { - if (failed(s->activate(context))) + if (failed(s->activate(context))) { return failure(); + } } return success(); } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp index 627e1a1af948..f5db486146b3 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp @@ -320,12 +320,14 @@ createFlowDispatchOp(PatternRewriter &rewriter, SymbolRefAttr exportOp, // Get the dynamic dims for the operands. 
for (auto operand : operands) { auto tensorType = dyn_cast(operand.getType()); - if (!tensorType) + if (!tensorType) { continue; + } for (auto [index, shape] : llvm::enumerate(tensorType.getShape())) { - if (ShapedType::isStatic(shape)) + if (ShapedType::isStatic(shape)) { continue; + } Value dim = tensor::DimOp::create(rewriter, loc, operand, index); operandDynamicDims.push_back(dim); @@ -352,8 +354,9 @@ getDynamicResultDims(PatternRewriter &rewriter, ValueRange givenResultDims) { SmallVector mixedValues = getAsOpFoldResult(givenResultDims); for (auto ofr : mixedValues) { auto value = dyn_cast(ofr); - if (!value) + if (!value) { continue; + } dynamicResultDims.push_back(value); } return dynamicResultDims; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp index 267e19e605d3..86f407760831 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp @@ -30,15 +30,17 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { static Value createAdd(Location loc, Value x, Value y, bool isInt, OpBuilder &builder) { - if (isInt) + if (isInt) { return arith::AddIOp::create(builder, loc, x, y); + } return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, bool isInt, OpBuilder &builder) { - if (isInt) + if (isInt) { return arith::MulIOp::create(builder, loc, x, y); + } return arith::MulFOp::create(builder, loc, x, y); } @@ -255,11 +257,12 @@ class ConvertDepthwiseConv2DNhwcHwc final } // TODO: Support dilation. 
- if (!hasAllOneValues(convOp.getDilations())) + if (!hasAllOneValues(convOp.getDilations())) { return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { diag << "[unimplemented] " << "expected no dilations (expected dilations to all be one)."; }); + } auto loc = convOp.getLoc(); @@ -415,11 +418,12 @@ class ConvertConv2DNchwFchw final } // TODO: Support dilation. - if (!hasAllOneValues(convOp.getDilations())) + if (!hasAllOneValues(convOp.getDilations())) { return rewriter.notifyMatchFailure(convOp, [](Diagnostic &diag) { diag << "[unimplemented] " << "expected no dilations (expected dilations to all be one)."; }); + } Value input = convOp.getInputs()[0]; Value filter = convOp.getInputs()[1]; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp index fec1c95531b2..d1eda2597e1c 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvFilterToChannelsLast.cpp @@ -268,8 +268,9 @@ struct ConvertGenericFilterToFhwc : public OpRewritePattern { FailureOr reorderOp = linalg::interchangeGenericOp(rewriter, genericOp, interchange); - if (failed(reorderOp)) + if (failed(reorderOp)) { return failure(); + } rewriter.replaceOp(linalgOp, reorderOp->getResults()); return success(); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp index ee49524e2942..de71cfd01c34 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp @@ -522,10 +522,12 @@ class GeneralizeOuterUnitDimsPackOp final LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { - if 
(!packOp.getOuterDimsPerm().empty()) + if (!packOp.getOuterDimsPerm().empty()) { return failure(); - if (packOp.getPaddingValue()) + } + if (packOp.getPaddingValue()) { return failure(); + } RankedTensorType destType = cast(packOp.getDest().getType()); @@ -572,8 +574,9 @@ class GeneralizeOuterUnitDimsPackOp final int64_t nTiled = 0; for (int64_t srcIdx = 0; srcIdx < srcRank; srcIdx++) { reassocationIndices.push_back({srcIdx + nTiled}); - while (innerDims.contains(srcIdx + nTiled)) + while (innerDims.contains(srcIdx + nTiled)) { reassocationIndices.back().push_back(srcIdx + ++nTiled); + } } rewriter.replaceOpWithNewOp( @@ -603,8 +606,9 @@ class GeneralizeOuterUnitDimsUnPackOp final LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override { - if (!unpackOp.getOuterDimsPerm().empty()) + if (!unpackOp.getOuterDimsPerm().empty()) { return failure(); + } RankedTensorType srcType = cast(unpackOp.getSource().getType()); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp index eaf16845b335..cd2de6158f9e 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp @@ -30,8 +30,9 @@ class InterpreterPass // pass finishes. 
OwningOpRef transformModule; if (failed(transform::detail::assembleTransformLibraryFromPaths( - context, transformSpecPath, transformModule))) + context, transformSpecPath, transformModule))) { return signalPassFailure(); + } Operation *payloadRoot = getOperation(); Operation *transformEntryPoint = transform::detail::findTransformEntryPoint( getOperation(), *transformModule, "__transform_main"); diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp index 93fcd02cedfd..95ac26cbde7f 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp @@ -30,8 +30,9 @@ class PadMatmulOp : public OpInterfaceRewritePattern { Operation *op = linalgOp.getOperation(); const bool isBatchMatmul = isa(op); const bool isMatmul = isa(op); - if (!isBatchMatmul && !isMatmul) + if (!isBatchMatmul && !isMatmul) { return failure(); + } Location loc = linalgOp.getLoc(); Value lhs = linalgOp.getDpsInputOperand(0)->get(); @@ -42,11 +43,13 @@ class PadMatmulOp : public OpInterfaceRewritePattern { auto rhsType = dyn_cast(rhs.getType()); auto resultType = dyn_cast(result.getType()); - if (!lhsType || !rhsType) + if (!lhsType || !rhsType) { return failure(); + } - if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) + if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) { return failure(); + } auto lhsShape = lhsType.getShape(); auto rhsShape = rhsType.getShape(); @@ -63,13 +66,15 @@ class PadMatmulOp : public OpInterfaceRewritePattern { int paddingForN = newNSize - N; int paddingForK = newKSize - K; - if (paddingForM == 0 && paddingForN == 0 && paddingForK == 0) + if (paddingForM == 0 && paddingForN == 0 && paddingForK == 0) { return failure(); + } auto getFullShape = [&](ArrayRef dims) { SmallVector shape; - if (isBatchMatmul) + if (isBatchMatmul) { shape.push_back(B); + } llvm::append_range(shape, dims); 
return shape; }; diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index ef293c9dd97b..68dd88fb8f38 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -138,8 +138,9 @@ expandMapsAndIterators(SmallVector &expandedMaps, map = map.shiftDims(1, expandDstDim); std::optional maybeDim = map.getResultPosition( getAffineDimExpr(expandSrcDim, map.getContext())); - if (!maybeDim) + if (!maybeDim) { continue; + } map = map.insertResult(getAffineDimExpr(expandDstDim, map.getContext()), maybeDim.value() + 1); } @@ -158,8 +159,9 @@ getIntrinsics(linalg::LinalgOp linalgOp, // For LIT testing, also directly search TargetAttr around the op. target = getGPUTargetAttr(linalgOp); } - if (!target) + if (!target) { return {}; + } IREE::GPU::MMAOpsArrayAttr mmaKinds = target.getWgp().getMma(); @@ -176,8 +178,9 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, // Early exit if cannot find intrinsics or if multiple executable targets. SmallVector intrinsics = getIntrinsics(linalgOp, executableTargets); - if (intrinsics.empty()) + if (intrinsics.empty()) { return; + } // Check that conv has met conditions to go down mfma. SmallVector bounds = linalgOp.getStaticLoopRanges(); @@ -348,8 +351,9 @@ static void padContractionLikeOp( // Early exit if cannot find intrinsics or if multiple executable targets. 
SmallVector intrinsics = getIntrinsics(linalgOp, executableTargets); - if (intrinsics.empty()) + if (intrinsics.empty()) { return; + } Location loc = linalgOp.getLoc(); @@ -377,8 +381,9 @@ static void padContractionLikeOp( auto operandMap = linalgOp.getMatchingIndexingMap(operand); std::optional maybeDim = operandMap.getResultPosition( getAffineDimExpr(targetDim, operandMap.getContext())); - if (maybeDim) + if (maybeDim) { return std::pair{operand->get(), maybeDim.value()}; + } } return std::nullopt; }; @@ -405,8 +410,9 @@ static void padContractionLikeOp( OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize); if (ShapedType::isDynamic(mSize)) { auto mOperandDimPair = getSrcOperandAndDim(mDim); - if (!mOperandDimPair) + if (!mOperandDimPair) { return; + } auto [mOperand, mOperandDim] = mOperandDimPair.value(); mSizeExpr = tensor::DimOp::create(rewriter, loc, mOperand, mOperandDim) .getResult(); @@ -419,8 +425,9 @@ static void padContractionLikeOp( OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize); if (ShapedType::isDynamic(nSize)) { auto nOperandDimPair = getSrcOperandAndDim(nDim); - if (!nOperandDimPair) + if (!nOperandDimPair) { return; + } auto [nOperand, nOperandDim] = nOperandDimPair.value(); nSizeExpr = tensor::DimOp::create(rewriter, loc, nOperand, nOperandDim) .getResult(); @@ -433,8 +440,9 @@ static void padContractionLikeOp( OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize); if (ShapedType::isDynamic(kSize)) { auto kOperandDimPair = getSrcOperandAndDim(kDim); - if (!kOperandDimPair) + if (!kOperandDimPair) { return; + } auto [kOperand, kOperandDim] = kOperandDimPair.value(); kSizeExpr = tensor::DimOp::create(rewriter, loc, kOperand, kOperandDim) .getResult(); @@ -474,14 +482,16 @@ static void padContractionLikeOp( auto getOperandPadding = [&](AffineMap operandMap) -> SmallVector { auto operandRank = operandMap.getNumResults(); - if (operandRank == 0) + if (operandRank == 0) { return {}; + } SmallVector operandPadding(operandRank, zero); for (auto 
[targetDim, targetPad] : llvm::zip(mnkDim, mnkPadding)) { std::optional maybeDim = operandMap.getResultPosition( getAffineDimExpr(targetDim, operandMap.getContext())); - if (!maybeDim) + if (!maybeDim) { continue; + } operandPadding[maybeDim.value()] = targetPad; } return operandPadding; @@ -541,13 +551,14 @@ static void padContractionLikeOp( SmallVector offsets(resultRank, zero), strides(resultRank, one), sizes; for (auto [dimIdx, dimSize] : llvm::enumerate(resultShape)) { - if (ShapedType::isDynamic(dimSize)) + if (ShapedType::isDynamic(dimSize)) { sizes.push_back( tensor::DimOp::create(rewriter, loc, linalgOp.getDpsInitOperand(0)->get(), dimIdx) .getResult()); - else + } else { sizes.push_back(rewriter.getIndexAttr(dimSize)); + } } rewriter.replaceOpWithNewOp(linalgOp, paddedCompute, offsets, sizes, strides); diff --git a/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp b/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp index 7bab516d8e80..18b1ff3f5581 100644 --- a/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp +++ b/compiler/src/iree/compiler/Reducer/Framework/Delta.cpp @@ -102,8 +102,9 @@ void Delta::runDeltaPass(DeltaFunc deltaFunc, StringRef message) { for (Chunk chunk : maybeInteresting) { FailureOr result = checkChunk(chunk, deltaFunc, maybeInteresting, uninterestingChunks); - if (failed(result)) + if (failed(result)) { continue; + } // Removing this chunk is still interesting. Mark this chunk as // uninteresting. diff --git a/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h b/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h index 01e7e07757a2..39476e66be55 100644 --- a/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h +++ b/compiler/src/iree/compiler/Reducer/Framework/WorkItem.h @@ -22,8 +22,9 @@ class WorkItem { /// TODO(Groverkss): Ownership of module should be conveyed here via /// mlir::OwningOpReference. 
void replaceModule(ModuleOp newModule) { - if (root) + if (root) { root->erase(); + } root = newModule; } diff --git a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp index c2eef672c2cc..4d66462d37c5 100644 --- a/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp +++ b/compiler/src/iree/compiler/Reducer/Strategies/ReduceLinalgOnTensorsDelta.cpp @@ -28,8 +28,9 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( SmallVector linalgOps; SmallVector keepOps; module.walk([&](linalg::LinalgOp op) { - if (!op.hasPureTensorSemantics()) + if (!op.hasPureTensorSemantics()) { return; + } // Op should have at least one tensor input, otherwise the operation is // already a fill-like operation. // TODO(Groverkss): Explore if we can remove in this case too. @@ -41,14 +42,17 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( } } - if (!hasAtleastOneTensorInput) + if (!hasAtleastOneTensorInput) { return; + } // There should be only 1 tensor output. - if (op.getNumDpsInits() != 1) + if (op.getNumDpsInits() != 1) { return; - if (!isa(op.getDpsInitOperand(0)->get().getType())) + } + if (!isa(op.getDpsInitOperand(0)->get().getType())) { return; + } if (!chunker.shouldFeatureBeKept()) { linalgOps.push_back(op); @@ -84,8 +88,9 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( if (outType.hasStaticShape()) { for (auto *input : linalgOp.getDpsInputOperands()) { auto inType = dyn_cast(input->get().getType()); - if (!inType) + if (!inType) { continue; + } // Check if we can replace an input directly with the output. if (inType == outType) { @@ -124,6 +129,7 @@ void mlir::iree_compiler::Reducer::reduceLinalgOnTensorsDelta( pm.addPass(createCanonicalizerPass()); // Remove dead globals. 
pm.addPass(createSymbolDCEPass()); - if (failed(pm.run(module))) + if (failed(pm.run(module))) { return; + } } diff --git a/compiler/src/iree/compiler/Utils/ConversionUtils.cpp b/compiler/src/iree/compiler/Utils/ConversionUtils.cpp index 78fd1a93960d..3966e160830d 100644 --- a/compiler/src/iree/compiler/Utils/ConversionUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ConversionUtils.cpp @@ -45,8 +45,9 @@ LogicalResult verifyAllOperationsAreLegal(Operation *op, illegalOps.insert(op); } }); - if (illegalOps.empty()) + if (illegalOps.empty()) { return success(); + } emitLegalizationErrors(op->getLoc(), illegalOps); return failure(); } @@ -60,14 +61,16 @@ Attribute convertAttribute(Location loc, Attribute oldAttr, // Return the same attribute if it doesn't have a type. auto typedOldAttr = dyn_cast(oldAttr); - if (!typedOldAttr) + if (!typedOldAttr) { return oldAttr; + } // Convert the attribute type - if it's the same then it's already legal. auto oldType = typedOldAttr.getType(); auto newType = typeConverter.convertType(oldType); - if (oldType == newType) + if (oldType == newType) { return typedOldAttr; + } if (auto intAttr = dyn_cast(typedOldAttr)) { APInt value = intAttr.getValue(); diff --git a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp index debfe478d4f6..ea0c9bbc8681 100644 --- a/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp +++ b/compiler/src/iree/compiler/Utils/EquivalenceUtils.cpp @@ -21,14 +21,18 @@ OperationEquivalenceCache::OperationEquivalenceCache(MLIRContext *context) StringAttr::get(context, SymbolTable::getSymbolAttrName())) {} OperationEquivalenceCache::~OperationEquivalenceCache() { - for (auto *mapping : mappingFreeList) + for (auto *mapping : mappingFreeList) { delete mapping; - for (auto region : regions) + } + for (auto region : regions) { delete region.second; - for (auto block : blocks) + } + for (auto block : blocks) { delete block.second; - for (auto op : ops) + } + 
for (auto op : ops) { delete op.second; + } } bool OperationEquivalenceCache::isSymbolAttrName(StringAttr name) const { @@ -52,8 +56,9 @@ OperationEquivalenceCache::acquireMapping() { OperationEquivalenceCache::RegionEntry & OperationEquivalenceCache::getRegion(Region *region) { auto it = regions.find(region); - if (it != regions.end()) + if (it != regions.end()) { return *it->second; + } RegionEntry *entry = new RegionEntry(); for (Block &block : region->getBlocks()) { llvm::ReversePostOrderTraversal traversal(&block); @@ -66,8 +71,9 @@ OperationEquivalenceCache::getRegion(Region *region) { OperationEquivalenceCache::BlockEntry & OperationEquivalenceCache::getBlock(Block *block) { auto it = blocks.find(block); - if (it != blocks.end()) + if (it != blocks.end()) { return *it->second; + } BlockEntry *entry = new BlockEntry(); entry->count = block->getOperations().size(); blocks[block] = entry; @@ -77,8 +83,9 @@ OperationEquivalenceCache::getBlock(Block *block) { OperationEquivalenceCache::OperationEntry & OperationEquivalenceCache::getOp(Operation *op) { auto it = ops.find(op); - if (it != ops.end()) + if (it != ops.end()) { return *it->second; + } OperationEntry *entry = new OperationEntry(); entry->attrs.append(op->getRawDictionaryAttrs().getValue()); if (op->getPropertiesStorageSize()) { @@ -95,8 +102,9 @@ bool compare_ranges(Range &&lhs, Range &&rhs, Pred pred) { auto lhsEnd = lhs.end(); auto rhsEnd = rhs.end(); while (lhsIt != lhsEnd && rhsIt != rhsEnd) { - if (!pred(*lhsIt++, *rhsIt++)) + if (!pred(*lhsIt++, *rhsIt++)) { return false; + } } if ((lhsIt == lhsEnd) != (rhsIt == rhsEnd)) { // Block count mismatch. 
We do this here so that we avoid the O(n) scan @@ -157,18 +165,21 @@ bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs, Region &rhs, IRMapping &mapping) { auto &lhsRegionEntry = cache.getRegion(&lhs); auto &rhsRegionEntry = cache.getRegion(&rhs); - if (lhsRegionEntry.blocks.size() != rhsRegionEntry.blocks.size()) + if (lhsRegionEntry.blocks.size() != rhsRegionEntry.blocks.size()) { return false; + } // Map blocks and their arguments so that we can compare their use by ops. for (auto [lhsBlock, rhsBlock] : llvm::zip_equal(lhsRegionEntry.blocks, rhsRegionEntry.blocks)) { - if (lhsBlock->getNumArguments() != rhsBlock->getNumArguments()) + if (lhsBlock->getNumArguments() != rhsBlock->getNumArguments()) { return false; + } for (auto [lhsArg, rhsArg] : llvm::zip_equal(lhsBlock->getArguments(), rhsBlock->getArguments())) { - if (lhsArg.getType() != rhsArg.getType()) + if (lhsArg.getType() != rhsArg.getType()) { return false; + } mapping.map(lhsArg, rhsArg); } mapping.map(lhsBlock, rhsBlock); @@ -180,13 +191,15 @@ bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, Region &lhs, llvm::zip_equal(lhsRegionEntry.blocks, rhsRegionEntry.blocks)) { const auto &lhsBlockEntry = cache.getBlock(lhsBlock); const auto &rhsBlockEntry = cache.getBlock(rhsBlock); - if (lhsBlockEntry.count != rhsBlockEntry.count) + if (lhsBlockEntry.count != rhsBlockEntry.count) { return false; + } for (auto [lhsOp, rhsOp] : llvm::zip_equal(lhsBlock->getOperations(), rhsBlock->getOperations())) { - if (!isStructurallyEquivalentTo(cache, lhsOp, rhsOp, mapping)) + if (!isStructurallyEquivalentTo(cache, lhsOp, rhsOp, mapping)) { return false; + } } } @@ -210,13 +223,15 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, auto &rhsEntry = cache.getOp(&rhs); // TODO(#3996): symbol mapping; for now allow them to differ unconditionally. 
- if (lhsEntry.attrs.getAttrs().size() != rhsEntry.attrs.getAttrs().size()) + if (lhsEntry.attrs.getAttrs().size() != rhsEntry.attrs.getAttrs().size()) { return false; + } for (auto [lhsAttr, rhsAttr] : llvm::zip_equal(lhsEntry.attrs, rhsEntry.attrs)) { if (!cache.isSymbolAttrName(lhsAttr.getName())) { - if (lhsAttr != rhsAttr) + if (lhsAttr != rhsAttr) { return false; + } } } @@ -224,8 +239,9 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // in the mapping already from the parent region to do the lhs->rhs mapping. for (auto [lhsSuccessor, rhsSuccessor] : llvm::zip_equal(lhs.getSuccessors(), rhs.getSuccessors())) { - if (rhsSuccessor != parentMapping.lookup(lhsSuccessor)) + if (rhsSuccessor != parentMapping.lookup(lhsSuccessor)) { return false; + } } // Ensure result types match first and add to the block and value mapping. @@ -234,8 +250,9 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // exit prior to the full traversal. for (auto [lhsValue, rhsValue] : llvm::zip_equal(lhs.getResults(), rhs.getResults())) { - if (lhsValue.getType() != rhsValue.getType()) + if (lhsValue.getType() != rhsValue.getType()) { return false; + } parentMapping.map(lhsValue, rhsValue); } @@ -243,10 +260,12 @@ static bool isStructurallyEquivalentTo(OperationEquivalenceCache &cache, // these values they should already be defined in the mapping. for (auto [lhsValue, rhsValue] : llvm::zip_equal(lhs.getOperands(), rhs.getOperands())) { - if (lhsValue.getType() != rhsValue.getType()) + if (lhsValue.getType() != rhsValue.getType()) { return false; - if (rhsValue != parentMapping.lookup(lhsValue)) + } + if (rhsValue != parentMapping.lookup(lhsValue)) { return false; + } } // Recurse into regions. 
diff --git a/compiler/src/iree/compiler/Utils/FlatbufferUtils.h b/compiler/src/iree/compiler/Utils/FlatbufferUtils.h index 697556e16df0..eddd13b92274 100644 --- a/compiler/src/iree/compiler/Utils/FlatbufferUtils.h +++ b/compiler/src/iree/compiler/Utils/FlatbufferUtils.h @@ -61,8 +61,9 @@ class FlatbufferBuilder { auto stringRefs = llvm::map_to_vector<8>(Range, [&](StringRef value) { return flatbuffers_string_create(*this, value.data(), value.size()); }); - if (stringRefs.empty()) + if (stringRefs.empty()) { return 0; + } return flatbuffers_string_vec_create(*this, stringRefs.data(), stringRefs.size()); } @@ -70,8 +71,9 @@ class FlatbufferBuilder { // Creates an offset vector with the given values. The source values will not // be modified. flatbuffers_vec_ref_t createOffsetVec(ArrayRef values) { - if (values.empty()) + if (values.empty()) { return 0; + } return flatcc_builder_create_offset_vector(*this, values.data(), values.size()); } @@ -81,8 +83,9 @@ class FlatbufferBuilder { // serialization but be much faster. flatbuffers_vec_ref_t createOffsetVecDestructive(SmallVectorImpl &values) { - if (values.empty()) + if (values.empty()) { return 0; + } return flatcc_builder_create_offset_vector_direct(*this, values.data(), values.size()); } @@ -90,8 +93,9 @@ class FlatbufferBuilder { // Creates an [int32] vec with the contents of the given range. 
template flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy &&Range) { - if (Range.empty()) + if (Range.empty()) { return 0; + } flatbuffers_int32_vec_start(*this); for (int32_t v : Range) { flatbuffers_int32_vec_push_create(*this, v); diff --git a/compiler/src/iree/compiler/Utils/Indexing.cpp b/compiler/src/iree/compiler/Utils/Indexing.cpp index c321c20cc3a2..b27ff4cfd55d 100644 --- a/compiler/src/iree/compiler/Utils/Indexing.cpp +++ b/compiler/src/iree/compiler/Utils/Indexing.cpp @@ -33,8 +33,9 @@ LogicalResult basisFromSizesStrides(ArrayRef sizes, stride = 1; size = 1; } - if (stride % previousSizes != 0) + if (stride % previousSizes != 0) { return failure(); + } // Handle casis like threads = {4, 8}, strides = {1, 16}, which need an // extra basis element. @@ -56,8 +57,9 @@ LogicalResult basisFromSizesStrides(ArrayRef sizes, size_t basisLength = basis.size(); dimToResult.assign(numDims, ~0); for (auto [reverseBasisPos, dimPos] : llvm::enumerate(basisEntryToDim)) { - if (!dimPos) + if (!dimPos) { continue; + } // There's an extra overflow term at the front of the delineraize results, // so this subtraction lands in the [1, basisLength] range we need it // to be in. diff --git a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp index 2006a93dd87b..c57dc3ba5420 100644 --- a/compiler/src/iree/compiler/Utils/ModuleUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ModuleUtils.cpp @@ -31,22 +31,26 @@ std::optional findFirstFileLoc(Location baseLoc) { // Recurse through fused locations. for (auto &childLoc : loc.getLocations()) { auto childResult = findFirstFileLoc(childLoc); - if (childResult) + if (childResult) { return childResult; + } } } else if (auto loc = dyn_cast(baseLoc)) { // First check caller... auto callerResult = findFirstFileLoc(loc.getCaller()); - if (callerResult) + if (callerResult) { return callerResult; + } // Then check callee... 
auto calleeResult = findFirstFileLoc(loc.getCallee()); - if (calleeResult) + if (calleeResult) { return calleeResult; + } } else if (auto loc = dyn_cast(baseLoc)) { auto childResult = findFirstFileLoc(loc.getChildLoc()); - if (childResult) + if (childResult) { return childResult; + } } else if (auto loc = dyn_cast(baseLoc)) { // TODO(scotttodd): Use loc.fallbackLocation()? } else if (auto loc = dyn_cast(baseLoc)) { @@ -58,8 +62,9 @@ std::optional findFirstFileLoc(Location baseLoc) { std::string guessModuleName(mlir::ModuleOp moduleOp, StringRef defaultName) { std::string moduleName = moduleOp.getName().value_or("").str(); - if (!moduleName.empty()) + if (!moduleName.empty()) { return moduleName; + } auto loc = findFirstFileLoc(moduleOp.getLoc()); if (loc.has_value()) { return sanitizeSymbolName( @@ -152,8 +157,9 @@ LogicalResult mergeModuleInto(Operation *sourceModuleOp, // Resolve conflicts and move the op. for (auto &sourceOp : sourceOps) { - if (sourceOp->hasTrait()) + if (sourceOp->hasTrait()) { continue; + } if (auto symbolOp = dyn_cast(sourceOp)) { auto symbolName = symbolOp.getName(); diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.cpp b/compiler/src/iree/compiler/Utils/OptionUtils.cpp index 9d0678f81b67..3ed11228ef47 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.cpp +++ b/compiler/src/iree/compiler/Utils/OptionUtils.cpp @@ -78,10 +78,12 @@ llvm::SmallVector OptionsBinder::printArguments(bool nonDefaultOnly) { llvm::SmallVector values; for (auto &[flag, info] : getOptionsStorage()) { - if (!info.print) + if (!info.print) { continue; - if (nonDefaultOnly && !info.isDefault()) + } + if (nonDefaultOnly && !info.isDefault()) { continue; + } std::string s; llvm::raw_string_ostream os(s); diff --git a/compiler/src/iree/compiler/Utils/OptionUtils.h b/compiler/src/iree/compiler/Utils/OptionUtils.h index 6e6e722b0131..bb707fca7c6b 100644 --- a/compiler/src/iree/compiler/Utils/OptionUtils.h +++ b/compiler/src/iree/compiler/Utils/OptionUtils.h @@ 
-36,8 +36,9 @@ struct opt_initializer { : parentName(parentName), init(val), optLevel(opt) {} void apply(const llvm::OptimizationLevel inLevel, Ty &val) const { assert(inLevel.getSizeLevel() == 0 && "size level not implemented"); - if (inLevel.getSpeedupLevel() >= optLevel.getSpeedupLevel()) + if (inLevel.getSpeedupLevel() >= optLevel.getSpeedupLevel()) { val = init; + } } /// Append to the description string of the flag. @@ -201,8 +202,9 @@ class OptionsBinder { void restoreOptimizationDefaults() { for (auto &[_, info] : getOptionsStorage()) { - if (info.optOverrides) + if (info.optOverrides) { info.optOverrides->restoreBackup(); + } } } @@ -463,8 +465,9 @@ class OptionsBinder { return [optionName, values](llvm::raw_ostream &os) { os << "--" << optionName << "="; for (auto it : llvm::enumerate(*values)) { - if (it.index() > 0) + if (it.index() > 0) { os << ","; + } os << it.value(); } }; @@ -478,8 +481,9 @@ class OptionsBinder { [&] { if constexpr (std::is_same_v, llvm::cl::desc>) { assert(!result && "Multiple llvm::cl::desc in args"); - if (!result) + if (!result) { result = &args; + } } }(), ...); diff --git a/compiler/src/iree/compiler/Utils/ToolUtils.cpp b/compiler/src/iree/compiler/Utils/ToolUtils.cpp index 6a8524074359..67860cf3d7c6 100644 --- a/compiler/src/iree/compiler/Utils/ToolUtils.cpp +++ b/compiler/src/iree/compiler/Utils/ToolUtils.cpp @@ -109,8 +109,9 @@ std::string findToolFromExecutableDir(SmallVector toolNames) { static std::string getCurrentDylibPath() { #if __linux__ || __APPLE__ Dl_info dlInfo; - if (dladdr((void *)getCurrentDylibPath, &dlInfo) == 0) + if (dladdr((void *)getCurrentDylibPath, &dlInfo) == 0) { return {}; + } return (dlInfo.dli_fname); #elif defined(WIN32) HMODULE hm = NULL; @@ -145,8 +146,9 @@ static std::string getCurrentDylibPath() { std::string findToolFromDylibDir(SmallVector toolNames) { const auto &normalizedToolNames = normalizeToolNames(toolNames); std::string dylibPath = getCurrentDylibPath(); - if (dylibPath.empty()) + 
if (dylibPath.empty()) { return {}; + } SmallString<256> dylibDir(dylibPath); llvm::sys::path::remove_filename(dylibDir); @@ -240,18 +242,21 @@ std::string findTool(SmallVector toolNames) { // TODO(benvanik): add a test for IREE_[toolName]_PATH. std::string dylibDirPath = findToolFromDylibDir(toolNames); - if (!dylibDirPath.empty()) + if (!dylibDirPath.empty()) { return dylibDirPath; + } // Search the install or build dir. std::string executableDirPath = findToolFromExecutableDir(toolNames); - if (!executableDirPath.empty()) + if (!executableDirPath.empty()) { return executableDirPath; + } // Currently fall back on searching the environment. std::string environmentPath = findToolInEnvironment(toolNames); - if (!environmentPath.empty()) + if (!environmentPath.empty()) { return environmentPath; + } return ""; } @@ -263,14 +268,16 @@ std::string findTool(std::string toolName) { std::string findPlatformLibDirectory(StringRef platformName) { std::string dylibPath = getCurrentDylibPath(); - if (dylibPath.empty()) + if (dylibPath.empty()) { return {}; + } SmallString<256> path(dylibPath); llvm::sys::path::remove_filename(path); llvm::sys::path::append(path, "iree_platform_libs", platformName); - if (!llvm::sys::fs::is_directory(path)) + if (!llvm::sys::fs::is_directory(path)) { return {}; + } llvm::sys::fs::make_absolute(path); (void)llvm::sys::path::remove_dots(path, /*remove_dot_dot=*/true); From 5287d0ddda8b4faad43a8f1a12b4dabb12ebb5fa Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 16:57:15 -0500 Subject: [PATCH 51/71] Enforce braces around control flow statements in clang-format. 7/7. (#23150) Braces has been mandatory for a long time and we call it out in our style guide: https://iree.dev/developers/general/contributing/#coding-style-guidelines. Add them to clang-format so that I don't have to point it out in code reviews. I cleaned up the codebase in the following PRs: 1. https://github.com/iree-org/iree/pull/23143 2. 
https://github.com/iree-org/iree/pull/23144 3. https://github.com/iree-org/iree/pull/23145 4. https://github.com/iree-org/iree/pull/23146 5. https://github.com/iree-org/iree/pull/23147 6. https://github.com/iree-org/iree/pull/23148 --- compiler/.clang-format | 1 + .../builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c | 3 ++- .../builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c | 3 ++- .../ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c | 3 ++- .../ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c | 3 ++- .../ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c | 3 ++- .../ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c | 3 ++- compiler/src/iree/compiler/Tools/iree_compile_lib.cc | 9 ++++++--- 8 files changed, 19 insertions(+), 9 deletions(-) diff --git a/compiler/.clang-format b/compiler/.clang-format index f50fe3d2d350..3e499f9a3782 100644 --- a/compiler/.clang-format +++ b/compiler/.clang-format @@ -10,6 +10,7 @@ # ordering. BasedOnStyle: LLVM AlwaysBreakTemplateDeclarations: Yes +InsertBraces: Yes IncludeCategories: - Regex: '^<.*\.h>' Priority: 1 diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c index aeebb08a26b6..cb542ed9d0a5 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i32.c @@ -31,8 +31,9 @@ float newIn = idx >= reductionSize ? -FLT_MAX : (float)(inputBuffer[input_offset + idx]); - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? 
idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c index 50388dac062b..d68f27693e2d 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_bf16i64.c @@ -31,8 +31,9 @@ float newIn = idx >= reductionSize ? -FLT_MAX : (float)(inputBuffer[input_offset + idx]); - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c index 2d8f51add345..d2ee3cff419a 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i32.c @@ -25,8 +25,9 @@ int32_t idx = warpSize * i + laneID; _Float16 newIn = idx >= reductionSize ? NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf16(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c index 2232d5f3887a..fd00bc8c1c33 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f16i64.c @@ -25,8 +25,9 @@ int32_t idx = warpSize * i + laneID; _Float16 newIn = idx >= reductionSize ? 
NEG_F16_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf16(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c index ad5d5088e054..7819020f9f0f 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i32.c @@ -24,8 +24,9 @@ int32_t idx = warpSize * i + laneID; float newIn = idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? idx : laneResult; } diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c index 5438c79cc182..d608d2368e69 100644 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c +++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_argmax_f32i64.c @@ -24,8 +24,9 @@ int32_t idx = warpSize * i + laneID; float newIn = idx >= reductionSize ? -FLT_MAX : inputBuffer[input_offset + idx]; - if (newIn == laneMax) + if (newIn == laneMax) { continue; + } laneMax = __builtin_fmaxf(newIn, laneMax); laneResult = newIn == laneMax ? 
idx : laneResult; } diff --git a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc index 0f4586f53528..1ebfc8d7747e 100644 --- a/compiler/src/iree/compiler/Tools/iree_compile_lib.cc +++ b/compiler/src/iree/compiler/Tools/iree_compile_lib.cc @@ -52,9 +52,10 @@ struct BytecodeVersionParser : public llvm::cl::parser> { bool parse(llvm::cl::Option &O, StringRef /*argName*/, StringRef arg, std::optional &v) { long long w; - if (llvm::getAsSignedInteger(arg, 10, w)) + if (llvm::getAsSignedInteger(arg, 10, w)) { return O.error("Invalid argument '" + arg + "', only integer is supported."); + } v = w; return false; } @@ -264,8 +265,9 @@ int mlir::iree_compiler::runIreecMain(int argc, char **argv) { remarksOutputFile.c_str()); } - if (!ireeCompilerInvocationParseSource(r.inv, source)) + if (!ireeCompilerInvocationParseSource(r.inv, source)) { return false; + } // Switch on compileMode to choose a pipeline to run. switch (compileMode) { @@ -377,8 +379,9 @@ int mlir::iree_compiler::runIreecMain(int argc, char **argv) { return 1; } } else { - if (!processBuffer(s.source)) + if (!processBuffer(s.source)) { return 1; + } } ireeCompilerOutputKeep(s.output); From 1e3ee626df1d65a2f9700ee2bf8fa03b047caf4a Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:07:51 -0500 Subject: [PATCH 52/71] Reapply "[Util] Implement InferIntDivisibilityOpInterface for affine ops (#22860)" (#23137) Reapply https://github.com/iree-org/iree/pull/22723 now that the torch model failures were fixed by https://github.com/iree-org/iree/pull/23118. 
ci-extra: test_torch Signed-off-by: Max Dawkins --- .../Dialect/Util/Transforms/BUILD.bazel | 1 + .../Dialect/Util/Transforms/CMakeLists.txt | 1 + .../compiler/Dialect/Util/Transforms/Passes.h | 1 + .../Dialect/Util/Transforms/Passes.td | 10 + .../TestIntegerDivisibilityAnalysis.cpp | 68 +++++ .../Dialect/Util/Transforms/test/BUILD.bazel | 1 + .../Util/Transforms/test/CMakeLists.txt | 1 + .../test_integer_divisibility_analysis.mlir | 188 +++++++++++++ .../compiler/ExternalInterfaces/BUILD.bazel | 1 + .../ExternalInterfaces/CMakeLists.txt | 1 + .../ExternalInterfaces/UtilExternalModels.cpp | 248 +++++++++++++++++- 11 files changed, 519 insertions(+), 2 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel index f8bedfcc836c..997dfe0f5742 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel @@ -43,6 +43,7 @@ iree_compiler_cc_library( "StripDebugOps.cpp", "TestConversion.cpp", "TestFloatRangeAnalysis.cpp", + "TestIntegerDivisibilityAnalysis.cpp", "VerifyInitializationOrder.cpp", "VerifyStructuredControlFlow.cpp", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index f73fa2a9bca9..e64d47b6ab6d 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt @@ -41,6 +41,7 @@ iree_cc_library( "StripDebugOps.cpp" "TestConversion.cpp" "TestFloatRangeAnalysis.cpp" + "TestIntegerDivisibilityAnalysis.cpp" "VerifyInitializationOrder.cpp" "VerifyStructuredControlFlow.cpp" DEPS diff --git 
a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h index 5445337ddf1a..188c5457b7d0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h @@ -73,6 +73,7 @@ createHoistIntoGlobalsPass(const ExprHoistingOptions &options); #define GEN_PASS_DECL_STRIPDEBUGOPSPASS #define GEN_PASS_DECL_TESTCONVERSIONPASS #define GEN_PASS_DECL_TESTFLOATRANGEANALYSISPASS +#define GEN_PASS_DECL_TESTINTEGERDIVISIBILITYANALYSISPASS #define GEN_PASS_DECL_VERIFYINITIALIZATIONORDERPASS #define GEN_PASS_DECL_VERIFYSTRUCTUREDCONTROLFLOWPASS #include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" // IWYU pragma: keep diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td index 7093bed69d38..b3f46f78add6 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td @@ -346,4 +346,14 @@ def TestFloatRangeAnalysisPass : Pass<"iree-util-test-float-range-analysis", ""> }]; } +def TestIntegerDivisibilityAnalysisPass : + Pass<"iree-util-test-integer-divisibility-analysis", ""> { + let summary = "Tests integer divisibility analysis."; + let description = [{ + Tests integer divisibility analysis by evaluating any + 'iree_unregistered.test_int_divisibility' op and setting the results on an + attribute. 
+ }]; +} + #endif // IREE_DIALECT_UTIL_PASSES diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp new file mode 100644 index 000000000000..21954d4ec0dc --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp @@ -0,0 +1,68 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" +#include "iree/compiler/Dialect/Util/Transforms/Passes.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +namespace mlir::iree_compiler::IREE::Util { + +#define GEN_PASS_DEF_TESTINTEGERDIVISIBILITYANALYSISPASS +#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" + +namespace { + +class TestIntegerDivisibilityAnalysisPass + : public impl::TestIntegerDivisibilityAnalysisPassBase< + TestIntegerDivisibilityAnalysisPass> { +public: + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = &getContext(); + + // The pass is rooted on `iree_unregistered.test_int_divisibility` ops, + // which are expected to have a single operand for which to annotate + // divisibility information. + SmallVector> queryOps; + rootOp->walk([&](Operation *op) { + if (op->getName().getStringRef() == + "iree_unregistered.test_int_divisibility" && + op->getNumOperands() == 1) { + queryOps.emplace_back(op, op->getOperand(0)); + } + }); + + DataFlowSolver solver; + // DeadCodeAnalysis is the base analysis that allows the solver to traverse + // control flow. We include it to make the divisibility analysis more + // powerful. 
+ solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(rootOp))) { + return signalPassFailure(); + } + + for (auto &[op, value] : queryOps) { + auto *lattice = solver.lookupState(value); + if (!lattice || lattice->getValue().isUninitialized()) { + op->setAttr("divisibility", StringAttr::get(context, "uninitialized")); + continue; + } + + // Format for the divisibility information is "udiv = X, sdiv = Y". + const auto &div = lattice->getValue().getValue(); + std::string result; + llvm::raw_string_ostream os(result); + os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv(); + op->setAttr("divisibility", StringAttr::get(context, os.str())); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler::IREE::Util diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel index f4b05a1ef1fe..7c52ce11c6b3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel @@ -42,6 +42,7 @@ iree_lit_test_suite( "strip_debug_ops.mlir", "test_float_range_analysis.mlir", "test_float_range_analysis_linalg.mlir", + "test_integer_divisibility_analysis.mlir", "verify_initialization_order.mlir", "verify_structured_control_flow.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt index abca38549966..658f9a9582f3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt @@ -39,6 +39,7 @@ iree_lit_test_suite( "strip_debug_ops.mlir" "test_float_range_analysis.mlir" "test_float_range_analysis_linalg.mlir" + "test_integer_divisibility_analysis.mlir" "verify_initialization_order.mlir" "verify_structured_control_flow.mlir" TOOLS diff --git 
a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir new file mode 100644 index 000000000000..998b6f9a5592 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir @@ -0,0 +1,188 @@ +// RUN: iree-opt --split-input-file --iree-util-test-integer-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s + +// CHECK-LABEL: @affine_apply_mul_divisibility +util.func @affine_apply_mul_divisibility(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * 4)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mul_negative +util.func @affine_apply_mul_negative(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * -4)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_add_gcd +util.func @affine_apply_add_gcd(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_floordiv_exact +util.func @affine_apply_floordiv_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : 
index +} + +// ----- + +// CHECK-LABEL: @affine_apply_ceildiv_exact +util.func @affine_apply_ceildiv_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_floordiv_non_exact +util.func @affine_apply_floordiv_non_exact(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%0) + // CHECK: divisibility = "udiv = 1, sdiv = 1" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_mod +util.func @affine_apply_mod(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0) + // CHECK: divisibility = "udiv = 1, sdiv = 1" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_composition +util.func @affine_apply_composition(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.apply affine_map<(d0) -> (d0 * 4 + 16)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_with_symbol +util.func @affine_apply_with_symbol(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%0#0)[%0#1] + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_min_uniform_divisibility +util.func @affine_min_uniform_divisibility(%arg0 : index) -> 
index { + %0 = util.assume.int %arg0 : index + %1 = affine.min affine_map<(d0) -> (d0, d0 + 64)>(%0) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_min_different_divisibilities +util.func @affine_min_different_divisibilities(%arg0 : index, %arg1 : index) -> index { + %0:2 = util.assume.int %arg0, + %arg1 : index, index + %1 = affine.min affine_map<(d0, d1) -> (d0, d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 8, sdiv = 8" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_max_uniform_divisibility +util.func @affine_max_uniform_divisibility(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = affine.max affine_map<(d0) -> (d0, d0 - 64)>(%0) + // CHECK: divisibility = "udiv = 32, sdiv = 32" + %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + util.return %2 : index +} + +// ----- + +// CHECK-LABEL: @affine_max_different_divisibilities +util.func @affine_max_different_divisibilities(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %0:3 = util.assume.int %arg0, + %arg1, + %arg2 : index, index, index + %3 = affine.max affine_map<(d0, d1, d2) -> (d0, d1, d2)>(%0#0, %0#1, %0#2) + // CHECK: divisibility = "udiv = 6, sdiv = 6" + %4 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index + util.return %4 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_constant +util.func @affine_apply_constant() -> index { + %0 = affine.apply affine_map<() -> (64)>() + // CHECK: divisibility = "udiv = 64, sdiv = 64" + %1 = "iree_unregistered.test_int_divisibility"(%0) : (index) -> index + util.return %1 : index +} + +// ----- + +// CHECK-LABEL: @affine_apply_chained_operations +util.func @affine_apply_chained_operations(%arg0 : index) -> index { + %0 = util.assume.int %arg0 : index + %1 = 
affine.apply affine_map<(d0) -> (d0 * 8)>(%0) + %2 = affine.apply affine_map<(d0) -> (d0 + 16)>(%1) + // CHECK: divisibility = "udiv = 16, sdiv = 16" + %3 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index + util.return %3 : index +} + +// ----- + +// CHECK-LABEL: @complex_chained_affine_ops +util.func @complex_chained_affine_ops(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %0:3 = util.assume.int %arg0, + %arg1, + %arg2 : index, index, index + %1 = affine.apply affine_map<(d0, d1) -> (d0 + 2 * d1)>(%0#0, %0#1) + // CHECK: divisibility = "udiv = 14, sdiv = 14" + %div_1 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index + %2 = affine.max affine_map<(d0, d1) -> (d0 floordiv 6, d1 * 3)>(%0#0, %0#2) + // CHECK: divisibility = "udiv = 5, sdiv = 5" + %div_2 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index + %3 = affine.min affine_map<(d0)[s0] -> (2 * (s0 * d0 - 14) ceildiv 7, d0 floordiv 3 * 2)>(%2)[%1] + // CHECK: divisibility = "udiv = 2, sdiv = 2" + %div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index + util.return %div_3 : index +} diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel index 57b34e43879d..06ef7e993587 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel @@ -42,6 +42,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/TensorExt/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt index a52be8559525..25c6eee54999 100644 --- 
a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt @@ -31,6 +31,7 @@ iree_cc_library( "UtilExternalModels.cpp" DEPS LLVMSupport + MLIRAffineDialect MLIRArithDialect MLIRControlFlowInterfaces MLIRGPUDialect diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index d7542c4d322f..e1ff0e5c6d2b 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -16,12 +16,14 @@ #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -50,6 +52,232 @@ getDivisibilityOfOperand(Value v, return IREE::Util::ConstantIntDivisibility(1, 1); } +/// Visits affine expressions and recursively calculates the divisibilities of +/// each subexpression. The final divisibilities of the expression and its +/// subexpressions will be stored in the map for which a reference is provided +/// to the AffineExprDivisibilityFinder (i.e., `divisibilityMap`). 
+class AffineExprDivisibilityFinder + : public AffineExprVisitor { +public: + using ExprDivisibilityMap = + llvm::DenseMap; + AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap) + : divisibilityMap(divisibilityMap) {} + + IREE::Util::ConstantIntDivisibility + visitConstantExpr(AffineConstantExpr expr) { + // Constant expressions are trivial, since they are always static. + uint64_t constValue = std::abs(expr.getValue()); + return IREE::Util::ConstantIntDivisibility(constValue, constValue); + } + + IREE::Util::ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) { + // Dim expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + IREE::Util::ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) { + // Symbol expressions cannot be analyzed further, so return the divisibility + // in `divisibilityMap` if it has been populated by the caller, or fallback + // to the minimum divisibility. + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Infer the divisibility of an addition or subtraction expression by + /// recursively visiting the LHS and RHS, and then unioning the results. + IREE::Util::ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + // The divisibility of an addition is the GCD of its constituents' + // divisibilities. 
+ IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return lhsDiv.getUnion(rhsDiv); + } + + /// Infer the divisibility of a multiplication expression by recursively + /// visiting the LHS and RHS, and then multiplying the results. + IREE::Util::ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + // The divisibility of a multiplication is the product of its constituents' + // divisibilities. + IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS()); + return IREE::Util::ConstantIntDivisibility(lhsDiv.udiv() * rhsDiv.udiv(), + lhsDiv.sdiv() * rhsDiv.sdiv()); + } + + IREE::Util::ConstantIntDivisibility + visitFloorDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + IREE::Util::ConstantIntDivisibility + visitCeilDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr); + } + + /// Mod expressions could be inferred to be zero in some cases, but for now + /// just return the minimum divisibility. + /// TODO(Max191): Handle evenly divisible cases, and ensure that the zero + /// divisibility propagates properly through parent expressions. + IREE::Util::ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) { + return visitInvalidExpr(expr); + } + +private: + IREE::Util::ConstantIntDivisibility + visitInvalidExpr(AffineBinaryOpExpr expr) { + return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue(); + } + + /// Helper shared by ceildiv and floordiv implementations. Returns the minimum + /// divisibility as a fallback if the divisor is not a constant, because the + /// divisibility cannot be inferred in this case. 
If the divisor is a + /// constant, then this function recursively visits the dividend, and returns + /// the quotient of the dividend's divisibility with the divisor. + IREE::Util::ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) { + if (divisibilityMap.contains(expr)) { + return divisibilityMap[expr]; + } + auto constRhs = dyn_cast(expr.getRHS()); + // Division by zero is undefined, so return the minimum divisibility. + if (!constRhs || constRhs.getValue() == 0) { + return IREE::Util::ConstantIntDivisibility(1, 1); + } + auto constValue = static_cast(std::abs(constRhs.getValue())); + IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS()); + uint64_t divUDiv = + lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1; + uint64_t divSDiv = + lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1; + return IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv); + } + + ExprDivisibilityMap &divisibilityMap; +}; + +/// Returns the divisibilities of each AffineMap result based on the +/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities` +/// should contain the divisibilities of the dims, followed by the +/// divisibilities of the symbols in ascending order by their positions. +static SmallVector getResultDivisibilities( + AffineMap map, + ArrayRef dimAndSymbolDivisibilities) { + // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities. 
+ llvm::DenseMap + exprDivisibilityMap; + SmallVector inputExprs; + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumDims()), + [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); })); + inputExprs.append(llvm::map_to_vector( + llvm::seq(map.getNumSymbols()), + [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); })); + for (auto [expr, divisibility] : + llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) { + exprDivisibilityMap[expr] = divisibility; + } + AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap); + + // Walk each result expression and compute their divisibilities. + SmallVector resultDivisibilities; + for (AffineExpr resultExpr : map.getResults()) { + resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr)); + } + return resultDivisibilities; +} + +struct AffineApplyInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineApplyInferIntDivisibilityOpInterface, affine::AffineApplyOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineApplyOp = cast(op); + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : + llvm::zip(affineApplyOp.getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(affineApplyOp.getMap(), operandDivisibilities); + for (auto [result, divisibility] : + llvm::zip_equal(affineApplyOp->getResults(), resultDivisibilities)) { + setResultDivs(result, divisibility); + } + } +}; + +/// Infer the result divisibility of an affine.min or affine.max operation +/// based on its operand divisibilities. The result divisibility is the GCD +/// of the divisibilities of each of the affine map results, because the result +/// of the affine.min/max op could be any of these results. 
+template +static void inferAffineMinOrMaxResultDivisibility( + MinOrMaxTy minOrMaxOp, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) { + static_assert( + llvm::is_one_of::value, + "MinOrMaxTy must be affine::AffineMinOp or affine::AffineMaxOp"); + SmallVector operandDivisibilities; + for (auto [operand, divisibility] : + llvm::zip(minOrMaxOp.getOperands(), argDivs)) { + operandDivisibilities.push_back( + getDivisibilityOfOperand(operand, divisibility)); + } + + SmallVector resultDivisibilities = + getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities); + + IREE::Util::ConstantIntDivisibility resultDivisibility = + resultDivisibilities.pop_back_val(); + for (auto divisibility : resultDivisibilities) { + resultDivisibility = resultDivisibility.getUnion(divisibility); + } + setResultDivs(minOrMaxOp.getResult(), resultDivisibility); +} + +struct AffineMinInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineMinInferIntDivisibilityOpInterface, affine::AffineMinOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineMinOp = cast(op); + inferAffineMinOrMaxResultDivisibility(affineMinOp, argDivs, setResultDivs); + } +}; + +struct AffineMaxInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + AffineMaxInferIntDivisibilityOpInterface, affine::AffineMaxOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto affineMaxOp = cast(op); + inferAffineMinOrMaxResultDivisibility(affineMaxOp, argDivs, setResultDivs); + } +}; + struct ArithConstantInferIntDivisibilityOpInterface : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> { @@ -105,8 +333,13 @@ struct 
ArithDivUIInferIntDivisibilityOpInterface auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]); - uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue(); - uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()); + uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0 + ? lhsDivisibility.udiv() / intVal.getZExtValue() + : 1; + uint64_t divSDiv = + lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0 + ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()) + : 1; setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv)); } @@ -958,6 +1191,7 @@ struct RegionControlFlowHoistableOpInterfaceHelper { void registerUtilExternalModels(DialectRegistry ®istry) { // Must ensure that any dependent dialects are registered. + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -984,6 +1218,16 @@ void registerUtilExternalModels(DialectRegistry ®istry) { *context); }); + registry.addExtension( + +[](MLIRContext *context, affine::AffineDialect *dialect) { + affine::AffineApplyOp::attachInterface< + AffineApplyInferIntDivisibilityOpInterface>(*context); + affine::AffineMinOp::attachInterface< + AffineMinInferIntDivisibilityOpInterface>(*context); + affine::AffineMaxOp::attachInterface< + AffineMaxInferIntDivisibilityOpInterface>(*context); + }); + registry.addExtension( +[](MLIRContext *context, tensor::TensorDialect *dialect) { tensor::InsertSliceOp::attachInterface( From edb0411d96599c161a929f831849e2841b6dac67 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 15 Jan 2026 14:09:58 -0800 Subject: [PATCH 53/71] [Samples] Fix hal.executable.export syntax in custom dispatch samples. (#23154) The `hal.executable.export` assembly format requires `count` to come before `attributes`, but the HIP and CUDA custom dispatch samples had them in the wrong order. 
This was a regression from commit 6dcf0864ef which restructured the workgroup count region but incorrectly updated the sample files. Fixes #22877. Signed-off-by: hanhanW --- samples/custom_dispatch/cuda/kernels/README.md | 4 ++-- .../custom_dispatch/cuda/kernels/example.mlir | 16 ++++++++-------- samples/custom_dispatch/hip/kernels/example.mlir | 16 ++++++++-------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/samples/custom_dispatch/cuda/kernels/README.md b/samples/custom_dispatch/cuda/kernels/README.md index d4ebc0bf6e6f..3b784617a5b4 100644 --- a/samples/custom_dispatch/cuda/kernels/README.md +++ b/samples/custom_dispatch/cuda/kernels/README.md @@ -75,11 +75,11 @@ nvcc ... (TODO, see CMakeLists.txt) -o kernels_sm_80.ptx #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes {workgroup_size = [64 : index, 1 : index, 1 : index]} count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index - } + } attributes {workgroup_size = [64 : index, 1 : index, 1 : index]} } ``` diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir index 69f66e6008ae..f332e7302c7e 100644 --- a/samples/custom_dispatch/cuda/kernels/example.mlir +++ b/samples/custom_dispatch/cuda/kernels/example.mlir @@ -79,11 +79,7 @@ module @example attributes {hal.device.targets = [#cuda_target]} { #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes { - // Certain backends (like CUDA) require a workgroup size (aka block - // size) to be defined ahead of time. 
- workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { // This host function is used to compute the XYZ workgroup count // dispatched at runtime. It can query the %device for capabilities // and limits (shared memory size, etc). The other arguments are the @@ -92,6 +88,10 @@ module @example attributes {hal.device.targets = [#cuda_target]} { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + // Certain backends (like CUDA) require a workgroup size (aka block + // size) to be defined ahead of time. + workgroup_size = [64 : index, 1 : index, 1 : index] } // Similar to the above but in-place by using a read/write binding. @@ -99,12 +99,12 @@ module @example attributes {hal.device.targets = [#cuda_target]} { layout(#hal.pipeline.layout, #hal.pipeline.binding - ]>) attributes { - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] } } // hal.executable.source diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir index ed44046cb3c8..aa83fd5490f2 100644 --- a/samples/custom_dispatch/hip/kernels/example.mlir +++ b/samples/custom_dispatch/hip/kernels/example.mlir @@ -70,11 +70,7 @@ module @example attributes {hal.device.targets = [#rocm_target]} { #hal.pipeline.binding, #hal.pipeline.binding, #hal.pipeline.binding - ]>) attributes { - // Certain backends (like ROCM) require a workgroup 
size (aka block - // size) to be defined ahead of time. - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { // This host function is used to compute the XYZ workgroup count // dispatched at runtime. It can query the %device for capabilities // and limits (shared memory size, etc). The other arguments are the @@ -83,6 +79,10 @@ module @example attributes {hal.device.targets = [#rocm_target]} { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + // Certain backends (like ROCM) require a workgroup size (aka block + // size) to be defined ahead of time. + workgroup_size = [64 : index, 1 : index, 1 : index] } // Similar to the above but in-place by using a read/write binding. @@ -90,12 +90,12 @@ module @example attributes {hal.device.targets = [#rocm_target]} { layout(#hal.pipeline.layout, #hal.pipeline.binding - ]>) attributes { - workgroup_size = [64 : index, 1 : index, 1 : index] - } count(%device: !hal.device, %workload: index) -> (index, index, index) { + ]>) count(%device: !hal.device, %workload: index) -> (index, index, index) { %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload] %c1 = arith.constant 1 : index hal.return %x, %c1, %c1 : index, index, index + } attributes { + workgroup_size = [64 : index, 1 : index, 1 : index] } } // hal.executable.source From f650b08445facccb1a3b4bb970bb73a7d54ac838 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Thu, 15 Jan 2026 17:13:15 -0500 Subject: [PATCH 54/71] [Codegen][MXFP4] Add attribute to lowering_config for XOR swizzle support. (#23085) This is the third of a series of PRs that together implement support in IREE for XOR swizzling through the SwizzleHintOp. 
There are four PRs that need to be merged: 1) Allow rank > 1 swizzle hint op operands and add a pass to flatten swizzle hint allocs. 2) Add patterns which can fold reshapes and `extract_slice` ops into empty ops through swizzle hint ops. 3) Add swizzle hint attribute to be set in `lowering_config` and consumed in `GPUPromoteMatmulOperandsPass`. 4) Update `LLVMGPUSelectLoweringStrategy` Pass to set xor swizzles for MXFP4 GEMMs. This is PR 3, which does two things: - Adds a new attribute called `SwizzleOperand` which can be set in `lowering_config` at config selection. - Adds a operand promotion implementation which consumes the `SwizzleOperand` attribute to generate the right sequence of tensor.empty, swizzle hint and linalg.copy ops at `GPUPromoteMatmulOperands`. --------- Signed-off-by: Muzammiluddin Syed --- .../GPU/test/gpu_promote_matmul_operands.mlir | 87 +++++++++++++ .../Codegen/IR/IREECodegenInterfaces.td | 5 + .../Dialect/Codegen/IR/IREECodegenOps.td | 4 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 9 ++ .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.h | 4 + .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 45 +++++++ .../Codegen/Dialect/GPU/IR/PromotionImpls.cpp | 118 ++++++++++++++---- 7 files changed, 247 insertions(+), 25 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir index cb209b3db860..28600bdd4130 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir @@ -305,3 +305,90 @@ func.func @promote_with_cache_swizzle_f4_no_stride(%a: tensor<2x34x34x129xf4E2M1 // CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma // CHECK-SAME: ins(%[[SWIZZLE_B]] // CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]] + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands 
= [0, 1], + promotion_types = [ + #iree_gpu.swizzle_operand>, + #iree_gpu.swizzle_operand>]}> + +func.func @promote_with_swizzle_operand(%a: tensor<32x64xf32>, %b: tensor<64x128xf32>) -> tensor<32x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%a, %b : tensor<32x64xf32>, tensor<64x128xf32>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32> + return %mm : tensor<32x128xf32> +} + +// SwizzleOperand attribute creates swizzle_hint op with xor_shuffle +// and flattens/expands the tensor for shared memory swizzling. +// CHECK-LABEL: func.func @promote_with_swizzle_operand +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x64xf32> +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf32> +// CHECK: %[[EMPTY_A:.+]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[SWIZZLE_A:.+]] = iree_codegen.swizzle_hint %[[EMPTY_A]][#iree_codegen.xor_shuffle<128, 16>] : tensor<2048xf32> +// CHECK: %[[EXPAND_A:.+]] = tensor.expand_shape %[[SWIZZLE_A]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : tensor<2048xf32> into tensor<32x64xf32> +// CHECK: %[[COPY_A:.+]] = linalg.copy +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma +// CHECK-SAME: ins(%[[A]] : tensor<32x64xf32>) outs(%[[EXPAND_A]] : tensor<32x64xf32>) +// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf32> +// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<256, 32>] : tensor<8192xf32> +// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf32> into tensor<64x128xf32> +// CHECK: %[[COPY_B:.+]] = linalg.copy +// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config +// CHECK-SAME: ins(%[[B]] : tensor<64x128xf32>) outs(%[[EXPAND_B]] : tensor<64x128xf32>) +// CHECK: linalg.matmul {{.*}} 
ins(%[[COPY_A]], %[[COPY_B]] : tensor<32x64xf32>, tensor<64x128xf32>) + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands = [1], + promotion_types = [ + #iree_gpu.swizzle_operand>]}> + +func.func @promote_with_swizzle_operand_f16(%a: tensor<32x64xf16>, %b: tensor<64x128xf16>) -> tensor<32x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%a, %b : tensor<32x64xf16>, tensor<64x128xf16>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32> + return %mm : tensor<32x128xf32> +} + +// SwizzleOperand with f16 element type. +// CHECK-LABEL: func.func @promote_with_swizzle_operand_f16 +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x64xf16> +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<64x128xf16> +// CHECK: %[[EMPTY_B:.+]] = tensor.empty() : tensor<8192xf16> +// CHECK: %[[SWIZZLE_B:.+]] = iree_codegen.swizzle_hint %[[EMPTY_B]][#iree_codegen.xor_shuffle<64, 8>] : tensor<8192xf16> +// CHECK: %[[EXPAND_B:.+]] = tensor.expand_shape %[[SWIZZLE_B]] {{\[\[}}0, 1{{\]\]}} output_shape [64, 128] : tensor<8192xf16> into tensor<64x128xf16> +// CHECK: %[[COPY_B:.+]] = linalg.copy +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma +// CHECK-SAME: ins(%[[B]] : tensor<64x128xf16>) outs(%[[EXPAND_B]] : tensor<64x128xf16>) +// CHECK: linalg.matmul {{.*}} ins(%[[A]], %[[COPY_B]] : tensor<32x64xf16>, tensor<64x128xf16>) + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{ + promote_operands = [0], + promotion_types = [ + #iree_gpu.swizzle_operand>]}> + +func.func @swizzle_operand_no_promote_fill(%b: tensor<128x128xf32>) -> tensor<4x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32> + %mm = 
linalg.matmul {lowering_config = #lowering_config} + ins(%fill, %b : tensor<4x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32> + return %mm : tensor<4x128xf32> +} + +// Verify that fills are not promoted even with swizzle_operand. +// CHECK-LABEL: func.func @swizzle_operand_no_promote_fill +// CHECK-NOT: iree_codegen.swizzle_hint +// CHECK-NOT: tensor.expand_shape +// CHECK: linalg.matmul +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td index bf63c8850e45..5fa6512c6433 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td @@ -619,6 +619,11 @@ def IREECodegen_AnySwizzleAttr : Attr; + def IREECodegen_UKernelProviderInterface : AttrInterface<"UKernelProviderInterface"> { let cppNamespace = "::mlir::iree_compiler::IREE::Codegen"; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td index 8c2bdab9cff7..a59e341f9657 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td @@ -173,9 +173,9 @@ def IREECodegen_SwizzleHintOp : Op:$operand, + let arguments = (ins AnyRankedTensorOrMemRef:$operand, IREECodegen_AnySwizzleAttr:$swizzle); - let results = (outs RankedTensorOrMemRefOf<[AnyType], [1]>:$result); + let results = (outs AnyRankedTensorOrMemRef:$result); let assemblyFormat = [{ $operand `[` $swizzle attr-dict `]` `:` type($result) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 45595fcc1fbb..75379573633b 100644 --- 
a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -2281,6 +2281,15 @@ Value PromoteWithCacheSwizzleAttr::promoteOperand( return cacheSwizzlePromotionImpl(builder, operand, getCopyConfig()); } +//===----------------------------------------------------------------------===// +// SwizzleOperandAttr +//===----------------------------------------------------------------------===// + +Value SwizzleOperandAttr::promoteOperand(mlir::OpBuilder &builder, + mlir::OpOperand &operand) const { + return swizzlePromotionImpl(builder, operand, getCopyConfig(), getSwizzle()); +} + //===----------------------------------------------------------------------===// // LaneIdAttr //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h index 8a05e5b7ea64..920081181beb 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h @@ -295,6 +295,10 @@ StringRef getTilingLevelName(GPU::TilingLevel level); Value cacheSwizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, Attribute attr); +Value swizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle); + } // namespace mlir::iree_compiler::IREE::GPU // clang-format off diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index 9c93c3d43c4d..9766a468ee59 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -131,6 +131,51 @@ def IREEGPU_PromoteWithCacheSwizzle : ); } +def IREEGPU_SwizzleOperand : + AttrDef + ]> { + let mnemonic = 
"swizzle_operand"; + let summary = [{ + Indicate promotion of an operand with setting an xor swizzle value. + }]; + let description = [{ + During matmul operand promotion, we generate copies associated to a + particular matmul operand with specific lowering configuration optimized + for coalesced loads. This attribute carries information on how accesses to + memrefs or tensors associated to a particular copy should be swizzled. + + This information is used to create a swizzle hint op on the alloc + associated with the copy. Ultimately, this aims to modify memory accesses + to minimize bank conflicts. For example, + + ```mlir + %0 = tensor_ext.dispatch.tensor.load : tensor + %1 = linalg.matmul ins(%0, ...) + ``` + + Becomes with `#iree_gpu.swizzle_operand<#iree_gpu.use_global_load_dma>` + + ```mlir + %0 = tensor_ext.dispatch.tensor.load : tensor + %1 = tensor.empty() + %2 = swizzle_hint_op %1 xor_shuffle(256, 32) + %3 = linalg.copy lowering_config = #iree_gpu.use_global_load_dma ins(%0) outs(%1) + %4 = linalg.matmul ins(%3, ...) + ``` + + With intelligent selection of `row_width` and `access_width`, this should + minimize bank conflicts. 
+ }]; + let assemblyFormat = "`<` struct(params) `>`"; + let parameters = (ins + "Attribute":$copy_config, + IREECodegen_SwizzleAttrParameter:$swizzle + ); +} + //===----------------------------------------------------------------------===// // GPU Workgroup Processor (WGP) Level Feature/Limit Attributes //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp index 080c2e14485c..6f683bf87f2d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp @@ -8,9 +8,11 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "llvm/Support/DebugLog.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -18,13 +20,14 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/TilingInterface.h" +#define DEBUG_TYPE "iree-codegen-promotion-utils" + namespace mlir::iree_compiler::IREE::GPU { /// Helper to insert copy with the specified attr. 
Value promoteValue(OpBuilder &builder, Location loc, Value v, Attribute attr) { auto tensorType = cast(v.getType()); SmallVector mixedSizes = tensor::getMixedSizes(builder, loc, v); - Value empty = tensor::EmptyOp::create(builder, loc, mixedSizes, tensorType.getElementType()); auto copy = linalg::CopyOp::create(builder, loc, v, empty); @@ -32,28 +35,35 @@ Value promoteValue(OpBuilder &builder, Location loc, Value v, Attribute attr) { return copy.getResult(0); } -/// Inserts a `linalg.copy` directly before the given operation on the -/// specified operand, for example with operand index = 1: -/// -/// %2 = linalg.matmul ins(%0, %1) -/// -/// becomes -/// -/// %empty = tensor.empty() -/// %copy = linalg.copy %1 to %empty { -/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} -/// linalg.matmul ins(%0, %copy) -/// -/// If the producer is already a tilable op, the producer is just annotated with -/// the underlying attribute. -/// Additionally we can also promote results so in above example we will -/// generate for index = 2 : -/// %out_buffer = bufferization.alloc_tensor -/// %copy1 = linalg.copy %2 to %out_buffer -/// %copy2 = linalg.copy %copy1 to %empty { -/// lowering_config = #iree_gpu.derived_thread_config} -Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, - Attribute attr) { +// Helper to insert a swizzle hint op and flatten the associated alloc. +Value swizzlePromoteValue(OpBuilder &builder, Location loc, Value v, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle) { + auto tensorType = cast(v.getType()); + int64_t numElements = tensorType.getNumElements(); + SmallVector sizes = tensor::getMixedSizes(builder, loc, v); + bool hasStaticShape = tensorType.hasStaticShape(); + if (hasStaticShape) { + sizes = {builder.getIndexAttr(numElements)}; + } + Value alloc = + tensor::EmptyOp::create(builder, loc, sizes, tensorType.getElementType()); + + // Only generate a swizzle hint op if the shape is static. 
+ if (hasStaticShape) { + Value swizzled = + IREE::Codegen::SwizzleHintOp::create(builder, loc, alloc, swizzle); + alloc = tensor::ExpandShapeOp::create( + builder, loc, tensorType, swizzled, + {llvm::to_vector(llvm::seq(tensorType.getRank()))}); + } + auto copy = linalg::CopyOp::create(builder, loc, v, alloc); + setLoweringConfig(copy, attr); + return copy.getResult(0); +} + +std::optional promotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr) { if (auto producer = operand.get().getDefiningOp()) { // Skip promotion of fills. if (isa(producer)) { @@ -78,11 +88,73 @@ Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, if (!tensorType) { return operand.get(); } + return std::nullopt; +} +/// Inserts a `linalg.copy` directly before the given operation on the +/// specified operand, for example with operand index = 1: +/// +/// ```mlir +/// %2 = linalg.matmul ins(%0, %1) +/// ``` +/// +/// becomes +/// +/// ```mlir +/// %empty = tensor.empty() +/// %copy = linalg.copy %1 to %empty { +/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} +/// linalg.matmul ins(%0, %copy) +/// ``` +/// +/// If the producer is already a tilable op, the producer is just annotated with +/// the underlying attribute. 
+/// Additionally we can also promote results so in above example we will +/// generate for index = 2 : +/// +/// ```mlir +/// %out_buffer = bufferization.alloc_tensor +/// %copy1 = linalg.copy %2 to %out_buffer +/// %copy2 = linalg.copy %copy1 to %empty { +/// lowering_config = #iree_gpu.derived_thread_config} +/// ``` +Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr) { + std::optional promotedValue = promotionImpl(builder, operand, attr); + if (promotedValue.has_value()) { + return promotedValue.value(); + } return promoteValue(builder, operand.getOwner()->getLoc(), operand.get(), attr); } +/// Inserts a `linalg.copy` directly before the given operation on the +/// specified operand, similar to the defaultPromotionImpl. +/// The difference is this also assigns a `iree_codegen.swizzle_hint` op +/// to the generated `tensor.empty` op. +/// For example: +/// ```mlir +/// %2 = linalg.matmul ins(%0, %1) +/// ``` +/// becomes +/// ```mlir +/// %empty = tensor.empty() +/// %swizzle = iree_codegen.swizzle_hint %empty[...] +/// %copy = linalg.copy %1 to %swizzle { +/// lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}} +/// linalg.matmul ins(%0, %copy) +/// ``` +Value swizzlePromotionImpl(OpBuilder &builder, OpOperand &operand, + Attribute attr, + Codegen::SwizzleAttrInterface swizzle) { + std::optional promotedValue = promotionImpl(builder, operand, attr); + if (promotedValue.has_value()) { + return promotedValue.value(); + } + return swizzlePromoteValue(builder, operand.getOwner()->getLoc(), + operand.get(), attr, swizzle); +} + /// Inserts a `linalg.copy` directly before the given operation on the /// specified operand, and also inserts a buffer_resource_cast on the producing /// dispatch input if possible. 
From 312c7a67da9fcf09c22d5298313798ab5b677d32 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 15 Jan 2026 17:14:38 -0500 Subject: [PATCH 55/71] Implement to-nearest-even for denormals (#23134) As noticed in https://github.com/iree-org/iree/pull/23119#issuecomment-3752401681, the rounding of denormals was not implementing ties-to-nearest-even. Signed-off-by: Benoit Jacob --- runtime/src/iree/base/internal/math.h | 30 +++++++++++-------- runtime/src/iree/base/internal/math_test.cc | 32 +++++++++++++++++++++ 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h index eb8c9b17ecfb..bcbbcc729557 100644 --- a/runtime/src/iree/base/internal/math.h +++ b/runtime/src/iree/base/internal/math.h @@ -334,6 +334,15 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, (src_exp >> src_exp_shift) - src_exp_bias - src_mantissa_bits); } +// Helper for rounding to nearest-even. Does not right-shift. Returns the +// biased value suitable for right-shifting. +static inline uint32_t bias_to_nearest_even(uint32_t input, int shift_amount) { + uint32_t even_bit = 1u << shift_amount; + uint32_t odd_bit = even_bit >> 1; + uint32_t bias = (input & even_bit) ? (odd_bit) : (odd_bit - 1); + return input + bias; +} + // Generic conversion from f32 to any less-than-32-bit floating-point format, // rounding to nearest-even. The return value is typed as a uint32_t for // genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits. @@ -370,8 +379,8 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( // can remain nonzero. This happens only with the bf16 type. // Just divide the mantissa (rounding shift). 
int shift_amount = f32_mantissa_bits - dst_mantissa_bits; - uint32_t rounding_term = 1 << (shift_amount - 1); - dst_mantissa = (f32_mantissa + rounding_term) >> shift_amount; + dst_mantissa = + bias_to_nearest_even(f32_mantissa, shift_amount) >> shift_amount; } // The destination type has fewer exponent bits, so f32 subnormal values // become exactly zero. Leave the mantissa zero. @@ -398,21 +407,18 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( dst_mantissa = 0; } else { // Source f32 value is normal so has an implied 1... leading bit. - int effective_f32_mantissa = (1 << f32_mantissa_bits) + f32_mantissa; - // Add this term to achieve rounding to nearest instead of truncation - // towards zero. - int rounding_term = 1 << (shift_amount - 1); - // Finally compute the destination mantissa as a rounded right shift. - dst_mantissa = (effective_f32_mantissa + rounding_term) >> shift_amount; + uint32_t effective_f32_mantissa = + (1u << f32_mantissa_bits) + f32_mantissa; + dst_mantissa = + bias_to_nearest_even(effective_f32_mantissa, shift_amount) >> + shift_amount; } } else { // Normal case. // Implement round-to-nearest-even, by adding a bias before truncating. - int even_bit = 1u << (f32_mantissa_bits - dst_mantissa_bits); - int odd_bit = even_bit >> 1; + int shift_amount = f32_mantissa_bits - dst_mantissa_bits; uint32_t biased_f32_mantissa = - f32_mantissa + - ((f32_mantissa & even_bit) ? (odd_bit) : (odd_bit - 1)); + bias_to_nearest_even(f32_mantissa, shift_amount); // Adding the bias may cause an exponent increment. 
if (biased_f32_mantissa > f32_mantissa_mask) { // Note: software implementations that try to be fast tend to get this diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc index 0c2f05008ca2..222555da7f47 100644 --- a/runtime/src/iree/base/internal/math_test.cc +++ b/runtime/src/iree/base/internal/math_test.cc @@ -313,6 +313,22 @@ TEST(BF16ConversionTest, F32ToBF16) { EXPECT_EQ(0xff80, iree_math_f32_to_bf16(-FLT_MAX)); EXPECT_EQ(0x0080, iree_math_f32_to_bf16(FLT_MIN)); EXPECT_EQ(0x8080, iree_math_f32_to_bf16(-FLT_MIN)); + // Test some round-to-nearest-even. F32->BF16 is interesting because F32 + // denormals can round to nonzero BF16 denormals. + EXPECT_EQ(0x0000, iree_math_f32_to_bf16(FLT_MIN * 1.0f / 256.f)); + EXPECT_EQ(0x0001, iree_math_f32_to_bf16(FLT_MIN * 2.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 3.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 4.0f / 256.f)); + EXPECT_EQ(0x0002, iree_math_f32_to_bf16(FLT_MIN * 5.0f / 256.f)); + EXPECT_EQ(0x0003, iree_math_f32_to_bf16(FLT_MIN * 6.0f / 256.f)); + EXPECT_EQ(0x0004, iree_math_f32_to_bf16(FLT_MIN * 7.0f / 256.f)); + EXPECT_EQ(0x8000, iree_math_f32_to_bf16(FLT_MIN * -1.0f / 256.f)); + EXPECT_EQ(0x8001, iree_math_f32_to_bf16(FLT_MIN * -2.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -3.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -4.0f / 256.f)); + EXPECT_EQ(0x8002, iree_math_f32_to_bf16(FLT_MIN * -5.0f / 256.f)); + EXPECT_EQ(0x8003, iree_math_f32_to_bf16(FLT_MIN * -6.0f / 256.f)); + EXPECT_EQ(0x8004, iree_math_f32_to_bf16(FLT_MIN * -7.0f / 256.f)); } TEST(BF16ConversionTest, Denormals) { @@ -503,6 +519,22 @@ TEST(F8E4M3FNConversionTest, F32ToF8E4M3FN) { EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fn(304.0f)); EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fn(336.0f)); EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3fn(368.0f)); + // Test round-to-nearest-even for denormals. 
+ EXPECT_EQ(0x00, iree_math_f32_to_f8e4m3fn(0.5f / 512.f)); + EXPECT_EQ(0x01, iree_math_f32_to_f8e4m3fn(1.f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(1.5f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(2.f / 512.f)); + EXPECT_EQ(0x02, iree_math_f32_to_f8e4m3fn(2.5f / 512.f)); + EXPECT_EQ(0x03, iree_math_f32_to_f8e4m3fn(3.f / 512.f)); + EXPECT_EQ(0x04, iree_math_f32_to_f8e4m3fn(3.5f / 512.f)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fn(-0.5f / 512.f)); + EXPECT_EQ(0x81, iree_math_f32_to_f8e4m3fn(-1.f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-1.5f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-2.f / 512.f)); + EXPECT_EQ(0x82, iree_math_f32_to_f8e4m3fn(-2.5f / 512.f)); + EXPECT_EQ(0x83, iree_math_f32_to_f8e4m3fn(-3.f / 512.f)); + EXPECT_EQ(0x84, iree_math_f32_to_f8e4m3fn(-3.5f / 512.f)); + // Important case to test: overflow due to rounding to nearest-even of 465 // to 512, while 464 gets rounded to nearest-even 448, not overflowing. EXPECT_EQ(0x7E, iree_math_f32_to_f8e4m3fn(464.f)); From 676c0acc29d70fefe69e855b3ef722808d7eaf07 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 15 Jan 2026 17:14:49 -0500 Subject: [PATCH 56/71] Disable bf16 ukernel Clang workaround on newer Clang. (#23071) bf16 ukernels use inline asm to work around multiple bugs in the `_mm512_dpbf16_ps` intrinsic on Clang, but the last such bug was fixed at least as far as Clang 20, so disable the workaround on newer Clang. 
Signed-off-by: Benoit Jacob --- .../builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c index 97eae890a3de..55f30291ca2e 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c @@ -7,7 +7,8 @@ #include "iree/builtins/ukernel/arch/x86_64/common_x86_64.h" #include "iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_internal.h" -#if defined(IREE_UK_COMPILER_CLANG) && !defined(IREE_UK_COMPILER_MSVC) +#if defined(IREE_UK_COMPILER_CLANG) && !defined(IREE_UK_COMPILER_MSVC) && \ + !IREE_UK_COMPILER_CLANG_VERSION_AT_LEAST(20, 0) // This inline-asm function is a work-around for: // 1. https://github.com/llvm/llvm-project/issues/68117 // Summary: LLVM crash affecting Clang 16-17. Fixed in Clang 18. 
From 331a3d02100c3d1c4ad2163852f7510b844b49ba Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 15 Jan 2026 17:23:07 -0500 Subject: [PATCH 57/71] [CI] Update iree-test-suite (#23135) This update adds new quality tests to torch_ops which include: * Tests for GEMM with interesting shapes https://github.com/iree-org/iree-test-suites/pull/146 * Skip's compiling i8 test due to https://github.com/iree-org/iree/issues/23136 * Skip's running some vulkan tests issue here https://github.com/iree-org/iree/issues/23141 ci-extra: test_torch --- .github/workflows/pkgci_test_onnx.yml | 4 ++-- .github/workflows/pkgci_test_sharktank.yml | 4 ++-- .github/workflows/pkgci_test_torch.yml | 4 ++-- .../torch_ops/torch_ops_cpu_llvm_sync.json | 4 +++- .../torch_ops/torch_ops_gpu_hip_gfx1100_O3.json | 4 +++- .../torch_ops/torch_ops_gpu_hip_gfx942_O3.json | 4 +++- .../torch_ops/torch_ops_gpu_vulkan_O3.json | 9 +++++++-- 7 files changed, 22 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pkgci_test_onnx.yml b/.github/workflows/pkgci_test_onnx.yml index db1e1ddb1571..9e228f45341a 100644 --- a/.github/workflows/pkgci_test_onnx.yml +++ b/.github/workflows/pkgci_test_onnx.yml @@ -103,7 +103,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install ONNX ops test suite requirements run: | @@ -189,7 +189,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install ONNX models test suite requirements run: | diff --git a/.github/workflows/pkgci_test_sharktank.yml b/.github/workflows/pkgci_test_sharktank.yml index 4f048bcb52d1..031327dd2a4e 100644 --- 
a/.github/workflows/pkgci_test_sharktank.yml +++ b/.github/workflows/pkgci_test_sharktank.yml @@ -88,7 +88,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites lfs: true @@ -197,7 +197,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites lfs: true diff --git a/.github/workflows/pkgci_test_torch.yml b/.github/workflows/pkgci_test_torch.yml index 6780ef009ceb..3627b31a4c3d 100644 --- a/.github/workflows/pkgci_test_torch.yml +++ b/.github/workflows/pkgci_test_torch.yml @@ -74,7 +74,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites - name: Install Torch ops test suite requirements run: | @@ -138,7 +138,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: repository: iree-org/iree-test-suites - ref: dc50625f4ac9d561f52ced410b8470b8168ed8a1 + ref: 17ead09be6d84bf46d80e6192dc12e45ba776045 path: iree-test-suites # Don't need lfs for torch models yet. 
lfs: false diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json index 78bb259473b5..376e5cfc1ebc 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_cpu_llvm_sync.json @@ -6,7 +6,9 @@ "--iree-llvmcpu-target-cpu=host" ], "iree_run_module_flags": [], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], "skip_run_tests": [ "AB/8192x8192xf32_bench", "AB/4096x4096xf32_bench", diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json index 79637b8c210c..4f8f6ff4b11c 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1100_O3.json @@ -8,7 +8,9 @@ "iree_run_module_flags": [ "--device=hip" ], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], "skip_run_tests": [ "ABPlusC/64x64xf16", "ATB/64x64xf16" diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json index 22d007362704..858a5eadcc0c 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx942_O3.json @@ -8,7 +8,9 @@ "iree_run_module_flags": [ "--device=hip" ], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], "skip_run_tests": [ "ATB/64x64xf16" ], diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json index 87e32aba227f..d511f7942c3b 100644 --- 
a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_vulkan_O3.json @@ -14,9 +14,14 @@ "ReluABPlusC/64x64xf16", "GeluABPlusC/64x64xf16", "AB/64x64xf16", - "AB/Nx64xf16_64xNxf16" + "AB/Nx64xf16_64xNxf16", + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], + "skip_run_tests": [ + "InterestingShapesBiasAdd/1152x997xf16_matmul_997x576xf16_NN", + "InterestingShapesBiasAdd/6144x419xbf16_matmul_419x384xbf16_NT", + "InterestingShapesBiasAdd/997x997xf16_NT_bias" ], - "skip_run_tests": [], "expected_compile_failures": [], "expected_run_failures": [], "golden_times_ms": { From 71c48c84712fe3146539cb06e82bc390f97f2d03 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 15 Jan 2026 22:18:05 -0500 Subject: [PATCH 58/71] [CI] Add rdna4 / gfx1201 / r9700 tests (#23156) These should run on shark75-ci. --- .github/workflows/pkgci.yml | 7 ++ .github/workflows/pkgci_test_amd_r9700.yml | 67 +++++++++++++++++++ .github/workflows/pkgci_test_onnx.yml | 8 +++ .github/workflows/pkgci_test_torch.yml | 3 + .../onnx_models_gpu_hip_rdna4.json | 30 +++++++++ .../onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json | 27 ++++++++ .../torch_ops_gpu_hip_gfx1201_O3.json | 19 ++++++ 7 files changed, 161 insertions(+) create mode 100644 .github/workflows/pkgci_test_amd_r9700.yml create mode 100644 tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json create mode 100644 tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json create mode 100644 tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json diff --git a/.github/workflows/pkgci.yml b/.github/workflows/pkgci.yml index 48013b8e6757..8ebe09c4b330 100644 --- a/.github/workflows/pkgci.yml +++ b/.github/workflows/pkgci.yml @@ -64,6 +64,12 @@ jobs: if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_amd_w7900') uses: ./.github/workflows/pkgci_test_amd_w7900.yml + test_amd_r9700: + name: Test AMD R9700 + needs: 
[setup, build_packages] + if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_amd_r9700') + uses: ./.github/workflows/pkgci_test_amd_r9700.yml + # TODO(#18238): migrate to new runner cluster # test_nvidia_t4: # name: Test NVIDIA T4 @@ -135,6 +141,7 @@ jobs: - test_amd_mi250 - test_amd_mi325 - test_amd_w7900 + - test_amd_r9700 # - test_nvidia_t4 - test_android - test_riscv64 diff --git a/.github/workflows/pkgci_test_amd_r9700.yml b/.github/workflows/pkgci_test_amd_r9700.yml new file mode 100644 index 000000000000..736769da82f7 --- /dev/null +++ b/.github/workflows/pkgci_test_amd_r9700.yml @@ -0,0 +1,67 @@ +# Copyright 2026 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: PkgCI Test AMD R9700 +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + workflow_dispatch: + inputs: + artifact_run_id: + type: string + default: "" + +jobs: + test_r9700: + runs-on: [Linux, X64, iree-r9700] + env: + PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages + BUILD_DIR: build-tests + VENV_DIR: ${{ github.workspace }}/.venv + GH_TOKEN: ${{ github.token }} + IREE_CPU_DISABLE: 1 + IREE_VULKAN_DISABLE: 0 + IREE_CUDA_ENABLE: 0 + IREE_HIP_ENABLE: 1 + IREE_HIP_TEST_TARGET_CHIP: "gfx1201" + steps: + - name: Check out repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + submodules: false + - name: Check out runtime submodules + run: ./build_tools/scripts/git/update_runtime_submodules.sh + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + # Must match the subset of versions built in pkgci_build_packages. 
+ python-version: "3.11" + - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + if: ${{ inputs.artifact_run_id == '' }} + with: + name: linux_x86_64_release_packages + path: ${{ env.PACKAGE_DOWNLOAD_DIR }} + - name: Setup base venv + run: | + ./build_tools/pkgci/setup_venv.py ${VENV_DIR} \ + --artifact-path=${PACKAGE_DOWNLOAD_DIR} \ + --fetch-gh-workflow=${{ inputs.artifact_run_id }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Build tests + run: ./build_tools/pkgci/build_tests_using_package.sh ${VENV_DIR}/bin + - name: Run GPU tests + env: + CTEST_PARALLEL_LEVEL: 1 + IREE_CTEST_LABEL_REGEX: ^requires-gpu|^driver=vulkan$|^driver=hip$ + IREE_AMD_RDNA4_TESTS_DISABLE: 0 + IREE_NVIDIA_GPU_TESTS_DISABLE: 0 + IREE_NVIDIA_SM80_TESTS_DISABLE: 1 + IREE_MULTI_DEVICE_TESTS_DISABLE: 0 + run: ./build_tools/cmake/ctest_all.sh ${BUILD_DIR} diff --git a/.github/workflows/pkgci_test_onnx.yml b/.github/workflows/pkgci_test_onnx.yml index 9e228f45341a..1c2dc7fcd399 100644 --- a/.github/workflows/pkgci_test_onnx.yml +++ b/.github/workflows/pkgci_test_onnx.yml @@ -47,6 +47,11 @@ jobs: numprocesses: 1 config-file: onnx_ops_gpu_hip_rdna3_O3.json runs-on: [Linux, X64, gfx1100] + # TODO(#23160): Fix the onnx ops test suite for gfx1201. + # - name: amdgpu_hip_rdna4_O3 + # numprocesses: 1 + # config-file: onnx_ops_gpu_hip_rdna4_O3.json + # runs-on: [Linux, X64, gfx1201] - name: amdgpu_vulkan_O0 numprocesses: 1 config-file: onnx_ops_gpu_vulkan_O0.json @@ -154,6 +159,9 @@ jobs: - name: amdgpu_hip_rdna3 config-file: onnx_models_gpu_hip_rdna3.json runs-on: [Linux, X64, gfx1100, persistent-cache] + - name: amdgpu_hip_rdna4 + config-file: onnx_models_gpu_hip_rdna4.json + runs-on: [Linux, X64, gfx1201, persistent-cache] - name: amdgpu_vulkan config-file: onnx_models_gpu_vulkan.json # TODO(#22579): Remove `shark10-ci` label. There are vulkan driver issues on other runners. 
diff --git a/.github/workflows/pkgci_test_torch.yml b/.github/workflows/pkgci_test_torch.yml index 3627b31a4c3d..d58bf432be78 100644 --- a/.github/workflows/pkgci_test_torch.yml +++ b/.github/workflows/pkgci_test_torch.yml @@ -37,6 +37,9 @@ jobs: - name: amdgpu_hip_gfx1100_O3 config-file: torch_ops_gpu_hip_gfx1100_O3.json runs-on: [Linux, X64, gfx1100] + - name: amdgpu_hip_gfx1201_O3 + config-file: torch_ops_gpu_hip_gfx1201_O3.json + runs-on: [Linux, X64, gfx1201] - name: amdgpu_vulkan_O3 config-file: torch_ops_gpu_vulkan_O3.json # TODO(#22579): Remove `shark10-ci` label. There are vulkan driver issues on other runners. diff --git a/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json b/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json new file mode 100644 index 000000000000..cca9d5e3cc8f --- /dev/null +++ b/tests/external/iree-test-suites/onnx_models/onnx_models_gpu_hip_rdna4.json @@ -0,0 +1,30 @@ +{ + "config_name": "gpu_hip_rdna4", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201" + ], + "iree_run_module_flags": [ + "--device=hip" + ], + "tests_and_expected_outcomes": { + "default": "skip", + "tests/model_zoo/validated/vision/body_analysis_models_test.py::test_models[age_gender/models/age_googlenet.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[alexnet/model/bvlcalexnet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[caffenet/model/caffenet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[densenet-121/model/densenet-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[efficientnet-lite4/model/efficientnet-lite4-11.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/googlenet/model/googlenet-12.onnx]": "pass", + 
"tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v2/model/inception-v2-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[mnist/model/mnist-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[mobilenet/model/mobilenetv2-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[resnet/model/resnet50-v1-12.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[resnet/model/resnet50-v2-7.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-9.onnx]": "pass", + "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-v2-12.onnx]": "pass", + "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[tiny-yolov2/model/tinyyolov2-8.onnx]": "pass", + "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[yolov2-coco/model/yolov2-coco-9.onnx]": "pass", + "tests/model_zoo/validated/vision/super_resolution_models_test.py::test_models[sub_pixel_cnn_2016/model/super-resolution-10.onnx]": "pass" + } +} diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json new file mode 100644 index 000000000000..b7afbda04c6a --- /dev/null +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json @@ -0,0 +1,27 @@ +{ + "config_name": "gpu_hip_rdna4", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201", + "--iree-input-demote-f64-to-f32=false", + "--iree-opt-level=O3" + ], + 
"iree_run_module_flags": [ + "--device=hip" + ], + "skip_compile_tests": [ + "onnx/node/generated/test_dequantizelinear", + "onnx/node/generated/test_einsum_inner_prod", + "onnx/node/generated/test_group_normalization_epsilon_expanded", + "onnx/node/generated/test_group_normalization_example_expanded", + "onnx/node/generated/test_nonmaxsuppression_two_batches", + "onnx/node/generated/test_constantofshape_int_shape_zero" + ], + "skip_run_tests": [ + "onnx/node/generated/test_top_k", + "onnx/node/generated/test_top_k_negative_axis", + "onnx/node/generated/test_top_k_smallest" + ], + "expected_compile_failures": [], + "expected_run_failures": [] +} diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json new file mode 100644 index 000000000000..a1dd3b1519cf --- /dev/null +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json @@ -0,0 +1,19 @@ +{ + "config_name": "gpu_hip_gfx1201", + "iree_compile_flags": [ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx1201", + "--iree-opt-level=O3" + ], + "iree_run_module_flags": [ + "--device=hip" + ], + "skip_compile_tests": [], + "skip_run_tests": [ + "ABPlusC/64x64xf16", + "ATB/64x64xf16" + ], + "expected_compile_failures": [], + "expected_run_failures": [], + "golden_times_ms": {} +} From 3d20ed0f00871cea3f51fbfaf481f6c54790984c Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Thu, 15 Jan 2026 23:14:46 -0500 Subject: [PATCH 59/71] [Codegen][MXFP4] Add folding patterns for tensor.empty op that can bypass SwizzleHintOps (#23084) This is the second of a series of PRs that together implement support in IREE for XOR swizzling through the SwizzleHintOp. There are four PRs that need to be merged: 1) Allow rank > 1 swizzle hint op operands and add a pass to flatten swizzle hint allocs. 
2) Add patterns which can fold reshapes and `extract_slice` ops into empty ops through swizzle hint ops. 3) Add swizzle hint attribute to be set in `lowering_config` and consumed in `GPUPromoteMatmulOperandsPass`. 4) Update `LLVMGPUSelectLoweringStrategy` Pass to set xor swizzles for MXFP4 GEMMs. This is PR 2, which does two things: - duplicates folding patterns for tensor.empty op from upstream llvm-project in IREE, but with support for swizzle hint ops. - Adds these patterns to the `GPUApplyTilingPass`. --------- Signed-off-by: Muzammiluddin Syed --- .../Common/GPU/GPUApplyTilingLevel.cpp | 3 +- .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/CMakeLists.txt | 1 + .../GPU/test/gpu_apply_tiling_level.mlir | 39 ++++++ .../GPU/test/gpu_fold_swizzle_hint_ops.mlir | 120 +++++++++++++++++ .../Dialect/GPU/Transforms/Transforms.cpp | 127 +++++++++++++++++- .../Dialect/GPU/Transforms/Transforms.h | 3 + 7 files changed, 291 insertions(+), 3 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp index 8ab4d99bbf73..b5c51c5cad45 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp @@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLForwardCompat.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -61,7 +62,6 @@ getTiledOps(Operation *funcOp, IREE::GPU::TilingLevel tilingLevel) { void GPUApplyTilingLevelPass::runOnOperation() { FunctionOpInterface 
funcOp = getOperation(); - if (!llvm::is_contained({IREE::GPU::TilingLevel::Reduction, IREE::GPU::TilingLevel::Thread, IREE::GPU::TilingLevel::Subgroup, @@ -107,6 +107,7 @@ void GPUApplyTilingLevelPass::runOnOperation() { // Apply cleanup patterns. { RewritePatternSet patterns(context); + IREE::GPU::populateFoldSwizzleHintOpPatterns(patterns); // Merge consecutive insert/extract slice ops to simplify later loop // hoisting patterns. tensor::populateFoldTensorEmptyPatterns(patterns); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 5e9bbe2b9fed..4f1aeda88479 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -38,6 +38,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir", "gpu_distribute_shared_memory.mlir", "gpu_expand_dimensions.mlir", + "gpu_fold_swizzle_hint_ops.mlir", "gpu_fuse_and_hoist_forall.mlir", "gpu_generalize_named_ops.mlir", "gpu_greedily_distribute_to_threads.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 4ce6c005b783..dde2c3d34120 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -33,6 +33,7 @@ iree_lit_test_suite( "gpu_distribute_scf_for.mlir" "gpu_distribute_shared_memory.mlir" "gpu_expand_dimensions.mlir" + "gpu_fold_swizzle_hint_ops.mlir" "gpu_fuse_and_hoist_forall.mlir" "gpu_generalize_named_ops.mlir" "gpu_greedily_distribute_to_threads.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir index 348bc4db92be..d5e21133494c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir +++ 
b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir @@ -735,3 +735,42 @@ module { // SERIAL: linalg.generic // SERIAL: scf.forall.in_parallel // SERIAL-NOT: mapping + +// ----- + +func.func @matmul_transpose_b_with_swizzle(%5: tensor<64x64xf32>, %6: tensor<64x1280xf16>, %7: tensor<64x1280xf16>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c1280 = arith.constant 1280 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32> + %9 = tensor.empty() : tensor<64x1280xf16> + %swizzle_9 = iree_codegen.swizzle_hint %9[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x1280xf16> + %10 = tensor.empty() : tensor<64x1280xf16> + %swizzle_10 = iree_codegen.swizzle_hint %10[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x1280xf16> + %11 = scf.for %arg0 = %c0 to %c1280 step %c4 iter_args(%arg1 = %8) -> (tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %6[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %extracted_slice_0 = tensor.extract_slice %swizzle_9[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %12 = linalg.copy {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1]}>} ins(%extracted_slice : tensor<64x4xf16>) outs(%extracted_slice_0 : tensor<64x4xf16>) -> tensor<64x4xf16> + %extracted_slice_1 = tensor.extract_slice %7[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %extracted_slice_2 = tensor.extract_slice %swizzle_10[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16> + %13 = linalg.copy {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1]}>} ins(%extracted_slice_1 : tensor<64x4xf16>) outs(%extracted_slice_2 : tensor<64x4xf16>) -> tensor<64x4xf16> + %14 = linalg.matmul + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ] + 
{lowering_config = #iree_gpu.lowering_config<{thread = [4, 4]}>} + ins(%12, %13 : tensor<64x4xf16>, tensor<64x4xf16>) + outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32> + scf.yield %14 : tensor<64x64xf32> + } + return %11 : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @matmul_transpose_b_with_swizzle + +// THREAD-LABEL: func.func @matmul_transpose_b_with_swizzle +// THREAD: %2 = tensor.empty() : tensor<64x4xf16> +// THREAD: %3 = iree_codegen.swizzle_hint %2[#iree_codegen.xor_shuffle<256, 32>] : tensor<64x4xf16> diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir new file mode 100644 index 000000000000..4fdd41f9e6cf --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fold_swizzle_hint_ops.mlir @@ -0,0 +1,120 @@ +// RUN: iree-opt --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level, canonicalize, cse))" %s | FileCheck %s + +// Test: tensor.extract_slice of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) with the sliced shape. +func.func @fold_extract_slice_of_swizzle_hint() -> tensor<16x32xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [16, 32] [1, 1] : tensor<64x64xf32> to tensor<16x32xf32> + return %slice : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @fold_extract_slice_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<16x32xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<16x32xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.extract_slice with dynamic sizes should fold correctly. 
+func.func @fold_extract_slice_dynamic(%size0: index, %size1: index) -> tensor { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.xor_shuffle<128, 16>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [%size0, %size1] [1, 1] : tensor<64x64xf32> to tensor + return %slice : tensor +} + +// CHECK-LABEL: func.func @fold_extract_slice_dynamic +// CHECK-SAME: %[[SIZE0:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[SIZE1:[A-Za-z0-9]+]]: index +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[SIZE0]], %[[SIZE1]]) : tensor +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.xor_shuffle<128, 16>] : tensor +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.expand_shape of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) with the expanded shape. +func.func @fold_expand_shape_of_swizzle_hint() -> tensor<4x16x64xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %expanded = tensor.expand_shape %swizzle [[0, 1], [2]] output_shape [4, 16, 64] : tensor<64x64xf32> into tensor<4x16x64xf32> + return %expanded : tensor<4x16x64xf32> +} + +// CHECK-LABEL: func.func @fold_expand_shape_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x16x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<4x16x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: tensor.collapse_shape of swizzle_hint(tensor.empty) should fold +// to swizzle_hint(tensor.empty) with the collapsed shape. 
+func.func @fold_collapse_shape_of_swizzle_hint() -> tensor<64x64xf32> { + %empty = tensor.empty() : tensor<4x16x4x16xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<4x16x4x16xf32> + %collapsed = tensor.collapse_shape %swizzle [[0, 1], [2, 3]] : tensor<4x16x4x16xf32> into tensor<64x64xf32> + return %collapsed : tensor<64x64xf32> +} + +// CHECK-LABEL: func.func @fold_collapse_shape_of_swizzle_hint +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Negative test: extract_slice of swizzle_hint without tensor.empty source +// should NOT fold. +func.func @no_fold_extract_slice_non_empty(%arg0: tensor<64x64xf32>) -> tensor<16x32xf32> { + %swizzle = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [16, 32] [1, 1] : tensor<64x64xf32> to tensor<16x32xf32> + return %slice : tensor<16x32xf32> +} + +// CHECK-LABEL: func.func @no_fold_extract_slice_non_empty +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SWIZZLE]] +// CHECK: return %[[SLICE]] + +// Negative test: expand_shape of swizzle_hint without tensor.empty source +// should NOT fold. 
+func.func @no_fold_expand_shape_non_empty(%arg0: tensor<64x64xf32>) -> tensor<4x16x64xf32> { + %swizzle = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %expanded = tensor.expand_shape %swizzle [[0, 1], [2]] output_shape [4, 16, 64] : tensor<64x64xf32> into tensor<4x16x64xf32> + return %expanded : tensor<4x16x64xf32> +} + +// CHECK-LABEL: func.func @no_fold_expand_shape_non_empty +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<64x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SWIZZLE]] +// CHECK: return %[[EXPANDED]] + +// Test: XOR shuffle swizzle attribute is preserved through folding. +func.func @fold_xor_shuffle_swizzle() -> tensor<8x64xf32> { + %empty = tensor.empty() : tensor<16x128xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.xor_shuffle<128, 16>] : tensor<16x128xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [8, 64] [1, 1] : tensor<16x128xf32> to tensor<8x64xf32> + return %slice : tensor<8x64xf32> +} + +// CHECK-LABEL: func.func @fold_xor_shuffle_swizzle +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.xor_shuffle<128, 16>] : tensor<8x64xf32> +// CHECK: return %[[SWIZZLE]] + +// Test: Rank-reducing extract_slice should work correctly. 
+func.func @fold_rank_reducing_extract_slice() -> tensor<32xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<64, 4>] : tensor<64x64xf32> + %slice = tensor.extract_slice %swizzle[0, 0] [1, 32] [1, 1] : tensor<64x64xf32> to tensor<32xf32> + return %slice : tensor<32xf32> +} + +// CHECK-LABEL: func.func @fold_rank_reducing_extract_slice +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SWIZZLE:.+]] = iree_codegen.swizzle_hint %[[EMPTY]][#iree_codegen.rotate_rows<64, 4>] : tensor<32xf32> +// CHECK: return %[[SWIZZLE]] + +#encoding = #iree_encoding.encoding (m, k)>, affine_map<(m, n, k) -> (k, n)>, affine_map<(m, n, k) -> (m, n)>], iteration_sizes = [?, ?, ?]> +func.func @fold_swizzle_hint_of_encoding() -> tensor<16xbf16,#encoding> { + %empty = tensor.empty() : tensor<8x16xbf16, #encoding> + %swizzle = iree_codegen.swizzle_hint %empty[#iree_codegen.rotate_rows<8, 4>] : tensor<8x16xbf16, #encoding> + %slice = tensor.extract_slice %swizzle[0, 0] [1, 16] [1, 1] : tensor<8x16xbf16, #encoding> to tensor<16xbf16,#encoding> + return %slice : tensor<16xbf16,#encoding> +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index 5f53373f4cb9..a61f075ba3ec 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -101,24 +101,44 @@ static FailureOr createSharedAllocDestination(RewriterBase &rewriter, return failure(); } - auto empty = forallOp.getDpsInits()[0].getDefiningOp(); + // Skip swizzle hint ops. 
+ Operation *destination = forallOp.getDpsInits()[0].getDefiningOp(); + if (auto swizzleOp = dyn_cast(destination)) { + destination = swizzleOp->getOperand(0).getDefiningOp(); + } + // Fail if the destination is not a `tensor.empty` op and cannot be trivially // converted to a `bufferization.alloc_tensor`. + auto empty = dyn_cast(destination); if (!empty) { return failure(); } // Create a `bufferization.alloc_tensor` op with memory space // `#gpu.address_space`. + Location loc = empty->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(empty); Attribute sharedMemoryAddrSpace = gpu::AddressSpaceAttr::get( rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); auto allocTensor = bufferization::AllocTensorOp::create( - rewriter, empty->getLoc(), cast(empty.getResult().getType()), + rewriter, loc, cast(empty.getResult().getType()), empty.getDynamicSizes(), /*copy=*/Value(), /*size_hint=*/Value(), /*memory_space=*/sharedMemoryAddrSpace); + + // If the original `tensor.empty` has a swizzle hint, apply it to the new + // allocation. Note that if there is a swizzle hint, it will be the only user + // of the `tensor.empty` op. 
+ if (auto swizzleHintOp = + dyn_cast(*empty->getUsers().begin())) { + assert(swizzleHintOp->hasOneUse() && + "a tensor.empty op with a swizzle hint applied, should have the " + "swizzle hint as its only user"); + auto newSwizzle = IREE::Codegen::SwizzleHintOp::create( + rewriter, loc, allocTensor.getResult(), swizzleHintOp.getSwizzle()); + return newSwizzle.getResult(); + } return allocTensor.getResult(); } @@ -2070,4 +2090,107 @@ void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +//===----------------------------------------------------------------------===// +// SwizzleHintOp Fold Patterns +//===----------------------------------------------------------------------===// + +// The following patterns are adapted from the populateFoldTensorEmptyPatterns +// in upstream llvm-project. The main change is to add support for folding with +// swizzle_hint ops from IREE. Once swizzle_hint ops are more widely used and +// proven stable, we could consider upstreaming this extension. + +namespace { +struct FoldSwizzleHintOpWithExtractSliceOp final + : OpRewritePattern { + using Base::Base; + LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const override { + // Check for swizzle_hint op source. + auto swizzleHintOp = + sliceOp.getSource().getDefiningOp(); + if (!swizzleHintOp) { + return failure(); + } + + // Check for tensor.empty source. + auto emptyOp = swizzleHintOp.getOperand().getDefiningOp(); + if (!emptyOp) { + return failure(); + } + + // Check for single use. + if (!emptyOp->hasOneUse()) { + return failure(); + } + + // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; + // its dynamic sizes must be preserved as well as its result type. 
+ Location loc = sliceOp.getLoc(); + auto sliceType = cast(sliceOp.getType()); + auto tensorType = + RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding()); + auto newEmptyOp = + tensor::EmptyOp::create(rewriter, loc, tensorType, sliceOp.getSizes()); + rewriter.replaceOpWithNewOp( + sliceOp, newEmptyOp, swizzleHintOp.getSwizzle()); + return success(); + } +}; + +template +struct FoldSwizzleHintOpWithReshapeOp final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOp reshapeOp, + PatternRewriter &rewriter) const override { + auto swizzleHintOp = + reshapeOp.getSrc() + .template getDefiningOp(); + if (!swizzleHintOp) { + return failure(); + } + auto emptyOp = + swizzleHintOp.getOperand().template getDefiningOp(); + if (!emptyOp) { + return failure(); + } + + // Check for single use. + if (!emptyOp->hasOneUse()) { + return failure(); + } + + // Reify result shape. + Location loc = reshapeOp.getLoc(); + ReifiedRankedShapedTypeDims resultShapes; + if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || + !llvm::hasSingleElement(resultShapes)) { + return failure(); + } + + // Create new tensor.empty op. 
+ Value emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultShapes[0], + reshapeOp.getResultType().getElementType(), + reshapeOp.getResultType().getEncoding()); + Value newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create( + rewriter, loc, emptyTensor, swizzleHintOp.getSwizzle()); + if (newSwizzleHintOp.getType() != reshapeOp.getResultType()) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), newSwizzleHintOp); + } else { + rewriter.replaceOp(reshapeOp, newSwizzleHintOp); + } + return success(); + } +}; + +} // namespace + +void populateFoldSwizzleHintOpPatterns(RewritePatternSet &patterns) { + patterns.add, + FoldSwizzleHintOpWithReshapeOp, + FoldSwizzleHintOpWithExtractSliceOp>(patterns.getContext()); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h index 70d9c3522b73..dcdd11f4232a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h @@ -195,6 +195,9 @@ void populateIREEGPUVectorUnrollPatterns( void populateIREEGPUVectorUnrollPatterns(RewritePatternSet &patterns); void populateIREEGPUVectorizationPatterns(RewritePatternSet &patterns); +// Populate patterns to fold tensor.empty ops through swizzle hint ops. +void populateFoldSwizzleHintOpPatterns(RewritePatternSet &patterns); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMS_TRANSFORMS_H_ From 1cb1cf6c441aa2b445b66386baacc2499ec71f54 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Thu, 15 Jan 2026 23:15:01 -0500 Subject: [PATCH 60/71] [Codegen][MXFP4] Adding SwizzleHintOp alloc flattening pass (#23083) This is the first of a series of PRs that together implement support in IREE for XOR swizzling through the SwizzleHintOp. 
There are four PRs that need to be merged: 1) Allow rank > 1 swizzle hint op operands and add a pass to flatten swizzle hint allocs. 2) Add patterns which can fold reshapes and `extract_slice` ops into empty ops through swizzle hint ops. 3) Add swizzle hint attribute to be set in `lowering_config` and consumed in `GPUPromoteMatmulOperandsPass`. 4) Update `LLVMGPUSelectLoweringStrategy` Pass to set xor swizzles for MXFP4 GEMMs. This is PR 1, which does three things: - Loosens the restriction on SwizzleHintOp inputs needing to be a Shaped type of rank 1. We do this because things are a lot simpler during tiling when you can fold arbitrary shapes into the swizzle hint op and then flatten later. - Introduces a pass to flatten allocs associated to `SwizzleHintOps`. - Moves the verification of flatness of swizzle hint ops to the `ResolveSwizzleHintOps` pass, prior to removal. --------- Signed-off-by: Muzammiluddin Syed --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Common/FlattenSwizzleHintAllocs.cpp | 87 +++++++++++++++++ .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/CMakeLists.txt | 1 + .../GPU/test/flatten_swizzle_hint_allocs.mlir | 96 +++++++++++++++++++ .../iree/compiler/Codegen/Common/Passes.td | 5 + .../Codegen/Common/ResolveSwizzleHints.cpp | 22 ++++- .../Common/test/resolve_swizzle_hints.mlir | 21 ++++ .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 2 + 10 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 050485d60e78..deff1e88ced2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ 
-105,6 +105,7 @@ iree_compiler_cc_library( "FissionTransferOpsInControlFlow.cpp", "FlattenMemRefSubspanPass.cpp", "FlattenMemRefs.cpp", + "FlattenSwizzleHintAllocs.cpp", "FoldAffineMinInDistributedLoops.cpp", "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp", "FoldTensorExtractOpPass.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 5ec758dba5d0..e17a7cd2cd9d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -98,6 +98,7 @@ iree_cc_library( "FissionTransferOpsInControlFlow.cpp" "FlattenMemRefSubspanPass.cpp" "FlattenMemRefs.cpp" + "FlattenSwizzleHintAllocs.cpp" "FoldAffineMinInDistributedLoops.cpp" "FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp" "FoldTensorExtractOpPass.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp new file mode 100644 index 000000000000..235ab5176d23 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/FlattenSwizzleHintAllocs.cpp @@ -0,0 +1,87 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_FLATTENSWIZZLEHINTALLOCSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { +struct FlattenSwizzleHintAllocsPass final + : impl::FlattenSwizzleHintAllocsPassBase { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +/// This pass flattens swizzle hint ops that operate on allocations of rank > 1. +/// This is required since swizzle hint op indices require flat memrefs. +/// +/// Example: +/// ``` +/// %0 = iree.alloc() : tensor<512x32xf4E2M1FN> +/// %1 = iree.swizzle_hint %0 : tensor<512x32xf4E2M1FN> -> +/// tensor<512x32xf4E2M1FN> +/// ``` +/// +/// is flattened to: +/// ``` +/// %0 = iree.alloc() : tensor<16384xf4E2M1FN> +/// %1 = iree.swizzle_hint %0 : tensor<16384xf4E2M1FN> -> tensor<16384xf4E2M1FN> +/// %2 = iree.expand_shape %1 : tensor<16384xf4E2M1FN> -> +/// tensor<512x32xf4E2M1FN> +/// ``` +static void flattenSwizzleHintAllocs(RewriterBase &rewriter, + IREE::Codegen::SwizzleHintOp hintOp) { + auto allocOp = hintOp.getOperand().getDefiningOp(); + if (!allocOp || !allocOp->hasOneUse()) { + return; + } + MemRefType resultType = allocOp.getType(); + if (resultType.getRank() == 1 || !resultType.getLayout().isIdentity() || + !memref::isStaticShapeAndContiguousRowMajor(resultType)) { + return; + } + + SmallVector newResultShape = {resultType.getNumElements()}; + auto newResultType = + MemRefType::get(newResultShape, resultType.getElementType(), AffineMap(), + resultType.getMemorySpace()); + rewriter.setInsertionPoint(hintOp); + ReassociationIndices reassoc = + llvm::to_vector(llvm::seq(resultType.getRank())); + 
auto newAllocOp = + memref::AllocOp::create(rewriter, hintOp.getLoc(), newResultType); + auto newSwizzleHintOp = IREE::Codegen::SwizzleHintOp::create( + rewriter, hintOp.getLoc(), newAllocOp.getResult(), hintOp.getSwizzle()); + auto expandShape = memref::ExpandShapeOp::create(rewriter, hintOp.getLoc(), + resultType.getShape(), + newSwizzleHintOp, {reassoc}); + rewriter.replaceOp(hintOp, expandShape); +} + +void FlattenSwizzleHintAllocsPass::runOnOperation() { + FunctionOpInterface funcOp = getOperation(); + // Collect all swizzle hint ops that operate on allocations. + // Flatten all allocs of rank > 1. + SmallVector hintOps; + funcOp.walk( + [&](IREE::Codegen::SwizzleHintOp hint) { hintOps.push_back(hint); }); + + IRRewriter rewriter(funcOp->getContext()); + for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + flattenSwizzleHintAllocs(rewriter, hintOp); + } +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 4f1aeda88479..ed2b9576967e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( [ "amdgpu_lower_coalesced_dma_to_gather_lds.mlir", "decompose_horizontally_fused_gemms.mlir", + "flatten_swizzle_hint_allocs.mlir", "gpu_alloc_private_memory_for_dps_ops.mlir", "gpu_apply_derived_thread_config.mlir", "gpu_apply_padding_online_attention.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index dde2c3d34120..9e626933d760 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "amdgpu_lower_coalesced_dma_to_gather_lds.mlir" "decompose_horizontally_fused_gemms.mlir" + 
"flatten_swizzle_hint_allocs.mlir" "gpu_alloc_private_memory_for_dps_ops.mlir" "gpu_apply_derived_thread_config.mlir" "gpu_apply_padding_online_attention.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir new file mode 100644 index 000000000000..b571927fa818 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/flatten_swizzle_hint_allocs.mlir @@ -0,0 +1,96 @@ +// RUN: iree-opt --allow-unregistered-dialect --pass-pipeline="builtin.module(func.func(iree-codegen-flatten-swizzle-hint-allocs))" \ +// RUN: --mlir-print-local-scope %s | FileCheck %s + +// Test: 1D alloc should NOT be flattened (already 1D). +func.func @skip_1d_alloc() { + %alloc = memref.alloc() : memref<2048xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> + "test.use"(%0) : (memref<2048xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_1d_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: 2D alloc with swizzle hint should be flattened to 1D. 
+func.func @flatten_2d_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_2d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<2048xf32, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [32, 64] : memref<2048xf32, #gpu.address_space> into memref<32x64xf32, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<32x64xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<32x64xf32 +// CHECK: return + +// Test: 3D alloc with swizzle hint should be flattened to 1D. +func.func @flatten_3d_alloc() { + %alloc = memref.alloc() : memref<4x8x16xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<4x8x16xf32, #gpu.address_space> + "test.use"(%0) : (memref<4x8x16xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_3d_alloc +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<512xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.rotate_rows<64, 4>] : memref<512xf32, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1, 2{{\]\]}} output_shape [4, 8, 16] : memref<512xf32, #gpu.address_space> into memref<4x8x16xf32, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<4x8x16xf32 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<4x8x16xf32 +// CHECK: return + +// Test: Non-alloc operand should NOT be affected. 
+func.func @skip_non_alloc(%arg0: memref<32x64xf32, #gpu.address_space>) { + %0 = iree_codegen.swizzle_hint %arg0[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_non_alloc +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: memref<32x64xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ARG0]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[HINT]]) + +// Test: Alloc with multiple uses should NOT be flattened. +func.func @skip_multi_use_alloc() { + %alloc = memref.alloc() : memref<32x64xf32, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> + "test.use"(%alloc) : (memref<32x64xf32, #gpu.address_space>) -> () + "test.use"(%0) : (memref<32x64xf32, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @skip_multi_use_alloc +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC]][#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, #gpu.address_space> +// CHECK-NOT: memref.expand_shape +// CHECK: "test.use"(%[[ALLOC]]) +// CHECK: "test.use"(%[[HINT]]) + +// Test: XOR shuffle swizzle attribute. 
+func.func @flatten_xor_shuffle() { + %alloc = memref.alloc() : memref<16x128xi8, #gpu.address_space> + %0 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<16x128xi8, #gpu.address_space> + "test.use"(%0) : (memref<16x128xi8, #gpu.address_space>) -> () + return +} + +// CHECK-LABEL: func @flatten_xor_shuffle +// CHECK: %[[ALLOC1D:.+]] = memref.alloc() : memref<2048xi8, #gpu.address_space> +// CHECK: %[[HINT:.+]] = iree_codegen.swizzle_hint %[[ALLOC1D]][#iree_codegen.xor_shuffle<128, 16>] : memref<2048xi8, #gpu.address_space> +// CHECK: %[[EXPAND:.+]] = memref.expand_shape %[[HINT]] {{\[\[}}0, 1{{\]\]}} output_shape [16, 128] : memref<2048xi8, #gpu.address_space> into memref<16x128xi8, #gpu.address_space> +// CHECK: "test.use"(%[[EXPAND]]) +// CHECK-NOT: memref.alloc() : memref<16x128xi8 +// CHECK-NOT: iree_codegen.swizzle_hint {{.*}} : memref<16x128xi8 +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index e604891e593a..fc518ef1842c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -496,6 +496,11 @@ def FlattenMemRefSubspanPass : Pass<"iree-codegen-flatten-memref-subspan", "Modu }]; } +def FlattenSwizzleHintAllocsPass : + InterfacePass<"iree-codegen-flatten-swizzle-hint-allocs", "mlir::FunctionOpInterface"> { + let summary = "Flattens allocations associated with iree_codegen.swizzle_hint ops"; +} + def FoldAffineMinInDistributedLoopsPass : InterfacePass<"iree-codegen-fold-affinemin-in-distributed-loops", "mlir::FunctionOpInterface"> { let summary = "Fold `affine.min` ops in distributed loops"; diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp index bc2f6aef0c0c..b6f5994755e5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -157,6 +158,22 @@ static void swizzleGatherToLDS(RewriterBase &rewriter, }); } +static LogicalResult +verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) { + auto memrefType = cast(hintOp.getOperand().getType()); + // Swizzle hints require flat (rank 1) memrefs. + // For rank 1, allow dynamic memrefs or static contiguous row-major memrefs. + if ((memrefType.getRank() != 1 || !memrefType.getLayout().isIdentity()) || + (memrefType.hasStaticShape() && + !memref::isStaticShapeAndContiguousRowMajor(memrefType))) { + hintOp.emitError() + << "swizzle hint operand must be a contiguous flat memref, got " + << hintOp.getOperand().getType(); + return failure(); + } + return success(); +} + /// Resolves all hints. Walks all direct users and splits them into loads and /// stores. If any user is not a swizzle-able load or store, bail out and /// silently drop the optimization hint. @@ -189,7 +206,7 @@ static void resolveHintOp(RewriterBase &rewriter, } if (auto gatherToLDSOp = dyn_cast(user)) { // Ignore swizzleHint on Dst Operand. Gather_to_lds writes elements of a - // subgroup contiguously in order of lane ID + // subgroup contiguously in order of lane ID. if (gatherToLDSOp.getDst() == hintOp) { continue; } @@ -242,6 +259,9 @@ void ResolveSwizzleHintsPass::runOnOperation() { // silently pass through for that hint. 
IRRewriter rewriter(funcOp->getContext()); for (IREE::Codegen::SwizzleHintOp hintOp : hintOps) { + if (failed(verifyFlatContiguousSwizzleHintOp(hintOp))) { + return signalPassFailure(); + } resolveHintOp(rewriter, hintOp); } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir index 4a77f8d8d2f3..6b6b9030a3ba 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir @@ -322,3 +322,24 @@ func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #a // CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index // CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space> // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]] + +// ----- + +// Verify that swizzle_hint fails on non-flat (rank > 1) memrefs. +func.func @swizzle_hint_non_flat_memref_error(%src: memref<32x64xf32>) -> vector<4xf32> { + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32>, vector<4xf32> + return %1: vector<4xf32> +} + +// Verify that swizzle_hint fails on non-contiguous memrefs. 
+func.func @swizzle_hint_non_contiguous_memref_error() -> vector<4xf32> { + %src = memref.alloc() : memref<32x64xf32, strided<[2, 1], offset: 0>> + // expected-error @+1 {{swizzle hint operand must be a contiguous flat memref, got 'memref<32x64xf32, strided<[2, 1]>>'}} + %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<32x64xf32, strided<[2, 1], offset: 0>> + %offset = arith.constant 0 : index + %1 = vector.load %0[%offset, %offset] : memref<32x64xf32, strided<[2, 1], offset: 0>>, vector<4xf32> + return %1: vector<4xf32> +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 98d047592d4f..7cb5f22183af 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" #include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h" #include "iree/compiler/Codegen/LLVMGPU/Passes.h" @@ -582,6 +583,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); // Step 9. Remaining post-bufferization optimizations/lowerings. 
+ funcPassManager.addPass(createFlattenSwizzleHintAllocsPass()); funcPassManager.addPass(createPropagateDispatchSizeBoundsPass()); funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass()); funcPassManager.addPass(createUnrollAnnotatedLoopsPass()); From 68284a2e816214a402e7bc9d70f22be2f79336f8 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 16 Jan 2026 09:46:55 -0500 Subject: [PATCH 61/71] Add build flag to enable reverse iteration over llvm maps (#23161) You can enable it with `-DIREE_REVERSE_ITERATION=On`. I found 4 failing tests but there might be more non-determinism. ``` iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting.mlir iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting_scf.mlir iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir ``` Once fixed, I plan to enable this in CI. --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab747f26c7c6..1d15a10b04bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -529,6 +529,7 @@ option(IREE_ENABLE_THIN_ARCHIVES "Enables thin ar archives (elf systems only). 
D option(IREE_LINK_COMPILER_SHARED_LIBRARY "Links IREE tools using the compiler compiled into a shared library" ON) option(IREE_ENABLE_WERROR_FLAG "Enable `-Werror` flag, treat error as warning" ON) option(IREE_ENABLE_POSITION_INDEPENDENT_CODE "Enable position independent code" TRUE) +option(IREE_REVERSE_ITERATION "Reverse iteration over unordered LLVM containers" OFF) if(IREE_LINK_COMPILER_SHARED_LIBRARY AND IREE_ENABLE_COMPILER_TRACING) message(SEND_ERROR @@ -570,6 +571,10 @@ if(IREE_ENABLE_RUNTIME_COVERAGE AND NOT _UPPERCASE_CMAKE_BUILD_TYPE STREQUAL "DE message(FATAL_ERROR "IREE_ENABLE_*_COVERAGE requires building in Debug") endif() +if(IREE_REVERSE_ITERATION) + set(LLVM_ENABLE_REVERSE_ITERATION ON CACHE BOOL "" FORCE) +endif() + #------------------------------------------------------------------------------- # IREE assertions # We don't love the way this is done, but we have to line it up with how LLVM From fc6f68038ea7964985e8bc4ed1db9ca58156c926 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 16 Jan 2026 10:20:06 -0500 Subject: [PATCH 62/71] [Codegen] Fix compiler errors in LinkTuningSpecsPass (#23169) Pass booleans instead of `nullptr`; the latter confuses some compilers because both `bool` and `Value` are constructible with `nullptr`. Also clean up comments and needlessly complicated code just above.
Fixes: https://github.com/iree-org/iree/issues/23164 --- .../compiler/Codegen/Common/LinkTuningSpecsPass.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp index f66fb308e4e4..0039de5fd7e3 100644 --- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp @@ -409,12 +409,8 @@ static FailureOr emitLinkedDefaultTuningSpec(ModuleOp module) { SmallVector mergedActions; for (ForeachMatchOp foreachMatchOp : foreachMatchOps) { - ArrayAttr matchers = foreachMatchOp.getMatchers(); - ArrayAttr actions = foreachMatchOp.getActions(); - for (auto [matcher, action] : llvm::zip_equal(matchers, actions)) { - mergedMatchers.push_back(cast(matcher)); - mergedActions.push_back(cast(action)); - } + llvm::append_range(mergedMatchers, foreachMatchOp.getMatchers()); + llvm::append_range(mergedActions, foreachMatchOp.getActions()); } Region ®ion = newEntryPoint.getRegion(); @@ -423,8 +419,8 @@ static FailureOr emitLinkedDefaultTuningSpec(ModuleOp module) { builder.setInsertionPointToStart(body); auto mergedForeachMatch = ForeachMatchOp::create( builder, loc, resultTypes, newEntryPoint.getArgument(0), - /* forwarded_inputs = */ ValueRange(), - /* restrictRoot = */ nullptr, /* flattenResults = */ nullptr, + /*forwarded_inputs=*/ValueRange(), + /*restrict_root=*/false, /*flatten_results=*/false, builder.getArrayAttr(mergedMatchers), builder.getArrayAttr(mergedActions)); transform::YieldOp::create(builder, loc, mergedForeachMatch->getResult(0)); From 97a9fd869eef840bb1f9eb25193e5d7e8e607f77 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Fri, 16 Jan 2026 10:25:02 -0500 Subject: [PATCH 63/71] [CI] Update torch_ops config file and run tests when config files are modified (#23168) * Updates torch_ops configuration file to skip running some 
tests (new tests added without golden_value and a new failing test that was not skipped). * Adds a new rule to configure_ci.py to run torch tests whenever configuration files are modified. This is because otherwise one needs to remember to add ci-extra to test relevant tests. (onnx and sharktank are not included here since they are always run on pre-submit) --- build_tools/github_actions/configure_ci.py | 4 ++++ .../torch_ops_gpu_hip_gfx1201_O3.json | 23 ++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/build_tools/github_actions/configure_ci.py b/build_tools/github_actions/configure_ci.py index 86978ce4fcd0..7a62b1d633c5 100755 --- a/build_tools/github_actions/configure_ci.py +++ b/build_tools/github_actions/configure_ci.py @@ -191,6 +191,10 @@ def contains(cls, val): ".github/worklflows/ci_windows_x64_msvc.yml", ], ), + ( + "test_torch", + ["tests/external/iree-test-suites/torch*"], + ), ] PR_DESCRIPTION_TEMPLATE = string.Template("${title}\n\n${body}") diff --git a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json index a1dd3b1519cf..baebf6555755 100644 --- a/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json +++ b/tests/external/iree-test-suites/torch_ops/torch_ops_gpu_hip_gfx1201_O3.json @@ -8,12 +8,29 @@ "iree_run_module_flags": [ "--device=hip" ], - "skip_compile_tests": [], + "skip_compile_tests": [ + "InterestingShapesBiasAdd/997x997xi8_NN_bias" + ], "skip_run_tests": [ "ABPlusC/64x64xf16", - "ATB/64x64xf16" + "ATB/64x64xf16", + "AB/1024x1024xf32_bench", + "AB/128x128xf32_bench", + "AB/2048x2048xf32_bench", + "AB/256x256xf32_bench", + "AB/4096x4096xf32_bench", + "AB/512x512xf32_bench", + "AB/8192x8192xf32_bench" ], "expected_compile_failures": [], "expected_run_failures": [], - "golden_times_ms": {} + "golden_times_ms": { + "AB/1024x1024xf32_bench" : 0.0, + "AB/128x128xf32_bench" : 0.0, 
"AB/2048x2048xf32_bench" : 0.0, + "AB/256x256xf32_bench" : 0.0, + "AB/4096x4096xf32_bench" : 0.0, + "AB/512x512xf32_bench" : 0.0, + "AB/8192x8192xf32_bench" : 0.0 + } } From 604eb3e954ff8a235a89f5edce7624b036ac5c24 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Fri, 16 Jan 2026 11:05:17 -0500 Subject: [PATCH 64/71] [Codegen] fixes typo in assert statement (#23170) Signed-off-by: Muzammiluddin Syed --- .../iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index a61f075ba3ec..190a1c27e401 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -132,7 +132,7 @@ static FailureOr createSharedAllocDestination(RewriterBase &rewriter, // of the `tensor.empty` op. if (auto swizzleHintOp = dyn_cast(*empty->getUsers().begin())) { - assert(swizzleHintOp->hasOneUse() && + assert(empty->hasOneUse() && "a tensor.empty op with a swizzle hint applied, should have the " "swizzle hint as its only user"); auto newSwizzle = IREE::Codegen::SwizzleHintOp::create( From 487d98f9f154015daa08765b868db3662b703d0b Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 16 Jan 2026 11:37:16 -0500 Subject: [PATCH 65/71] [Codegen] Add pass to remove iree_codegen.index_hint ops (#23139) Adds a pass to remove iree_codegen.index_hint operations. The pass unconditionally drops all index_hint ops, and should be used once the compiler is done using them for optimizations. The ops can get in the way of later optimizations, so this pass should be used to drop them once they are no longer needed. 
The pass is not added to any pipelines, because we are not generating index_hint ops anywhere yet, but this pass will be added later once index_hints start to be used. --------- Signed-off-by: Max Dawkins --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../iree/compiler/Codegen/Common/Passes.td | 15 +++++++ .../Codegen/Common/RemoveIndexHints.cpp | 41 +++++++++++++++++++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../Common/test/remove_index_hints.mlir | 34 +++++++++++++++ 7 files changed, 94 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index deff1e88ced2..f87cd7ddc22a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -148,6 +148,7 @@ iree_compiler_cc_library( "PropagateReshapesByExpansion.cpp", "ReconcileTranslationInfo.cpp", "RematerializeParallelOps.cpp", + "RemoveIndexHints.cpp", "RemoveSingleIterationLoop.cpp", "ReplaceSlowMinMaxOps.cpp", "ReshapePatterns.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index e17a7cd2cd9d..48497a85be7f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -141,6 +141,7 @@ iree_cc_library( "PropagateReshapesByExpansion.cpp" "ReconcileTranslationInfo.cpp" "RematerializeParallelOps.cpp" + "RemoveIndexHints.cpp" "RemoveSingleIterationLoop.cpp" "ReplaceSlowMinMaxOps.cpp" "ReshapePatterns.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 
fc518ef1842c..b8a6fd88a0fd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -886,6 +886,21 @@ def RematerializeParallelOpsPass : let summary = "Pass to rematerialize and merge parallel ops into consumers."; } +def RemoveIndexHintsPass : + InterfacePass<"iree-codegen-remove-index-hints", "mlir::FunctionOpInterface"> { + let summary = "Remove iree_codegen.index_hint operations"; + let description = [{ + This pass removes all iree_codegen.index_hint operations by replacing + them with their input values (pass-through semantics). + + Index hints are used to convey optimization information to downstream + passes and should be cleaned up once that information has been consumed. + }]; + let dependentDialects = [ + "IREE::Codegen::IREECodegenDialect" + ]; +} + def RemoveSingleIterationLoopPass : InterfacePass<"iree-codegen-remove-single-iteration-loop", "mlir::FunctionOpInterface"> { let summary = "Remove distributed loop with single iteration."; diff --git a/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp b/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp new file mode 100644 index 000000000000..27d5bf87f91d --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/RemoveIndexHints.cpp @@ -0,0 +1,41 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_REMOVEINDEXHINTSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +/// Pass to remove all iree_codegen.index_hint operations by replacing them +/// with their input values. 
+struct RemoveIndexHintsPass final + : impl::RemoveIndexHintsPassBase { + void runOnOperation() override { + FunctionOpInterface funcOp = getOperation(); + IRRewriter rewriter(funcOp.getContext()); + + SmallVector indexHintOps; + funcOp.walk([&](IREE::Codegen::IndexHintOp hintOp) { + indexHintOps.push_back(hintOp); + }); + + for (auto hintOp : indexHintOps) { + hintOp.getResult().replaceAllUsesWith(hintOp.getInput()); + rewriter.eraseOp(hintOp); + } + } +}; + +} // namespace + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 702a432bde7b..af3c70e2ef28 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -109,6 +109,7 @@ iree_lit_test_suite( "reductions.mlir", "rematerialize_parallel_ops.mlir", "remove_dead_allocs.mlir", + "remove_index_hints.mlir", "remove_single_iteration_loop.mlir", "repeated_matcher_use.mlir", "replace_slow_min_max_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 2c095be422ef..0295c5c09fb0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -104,6 +104,7 @@ iree_lit_test_suite( "reductions.mlir" "rematerialize_parallel_ops.mlir" "remove_dead_allocs.mlir" + "remove_index_hints.mlir" "remove_single_iteration_loop.mlir" "repeated_matcher_use.mlir" "replace_slow_min_max_ops.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir b/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir new file mode 100644 index 000000000000..198a2193df57 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/remove_index_hints.mlir @@ -0,0 +1,34 @@ +// RUN: iree-opt --split-input-file 
--pass-pipeline='builtin.module(func.func(iree-codegen-remove-index-hints))' %s | FileCheck %s + +// Test: index_hint with lane_constant is removed. +// CHECK-LABEL: func.func @remove_lane_constant_hint +// CHECK-NOT: iree_codegen.index_hint +// CHECK: return %arg0 +func.func @remove_lane_constant_hint(%arg0: index) -> index { + %hint = iree_codegen.index_hint %arg0(#iree_gpu.lane_constant<16>) : index + return %hint : index +} + +// ----- + +// Test: index_hint with lane_increment is removed. +// CHECK-LABEL: func.func @remove_lane_increment_hint +// CHECK-NOT: iree_codegen.index_hint +// CHECK: return %arg0 +func.func @remove_lane_increment_hint(%arg0: index) -> index { + %hint = iree_codegen.index_hint %arg0(#iree_gpu.lane_increment<16>) : index + return %hint : index +} + +// ----- + +// Test: Multiple hints in sequence are all removed. +// CHECK-LABEL: func.func @remove_multiple_hints +// CHECK-NOT: iree_codegen.index_hint +// CHECK: arith.addi %arg0, %arg1 +func.func @remove_multiple_hints(%arg0: index, %arg1: index) -> index { + %hint0 = iree_codegen.index_hint %arg0(#iree_gpu.lane_constant<16>) : index + %hint1 = iree_codegen.index_hint %arg1(#iree_gpu.lane_increment<16>) : index + %sum = arith.addi %hint0, %hint1 : index + return %sum : index +} From 46557fd1c368cbb271e05aef65f431c1ffb30aac Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Fri, 16 Jan 2026 11:14:29 -0800 Subject: [PATCH 66/71] Enable passing e2e tests for ROCM, VMVX, and Vulkan backends (#23174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable tests that were previously excluded but now pass: ROCM/HIP (tests/e2e/linalg): - conv2d, narrow_n_matmuls, subbyte_to_fp, fp_to_subbyte, fp4_f32_conversion, index VMVX (tests/e2e/linalg): - argmax, index VMVX (tests/e2e/linalg_ext_ops): - attention Vulkan (tests/e2e/linalg): - argmax, index Vulkan (tests/e2e/linalg_ext_ops): - map_gather, map_scatter, top-k Vulkan (tests/e2e/stablehlo_ops): - reverse 
Below is the additional testing time on my machine (using gfx1100): ``` ● Test execution times for newly enabled tests: ┌──────────┬───────┬────────────┐ │ Backend │ Tests │ Total Time │ ├──────────┼───────┼────────────┤ │ ROCM/HIP │ 6 │ 3.06 sec │ ├──────────┼───────┼────────────┤ │ VMVX │ 3 │ 0.28 sec │ ├──────────┼───────┼────────────┤ │ Vulkan │ 6 │ 0.58 sec │ ├──────────┼───────┼────────────┤ │ Total │ 15 │ ~3.9 sec │ └──────────┴───────┴────────────┘ Individual test breakdown: ROCM/HIP: - conv2d: 0.28s - fp4_f32_conversion: 0.39s - fp_to_subbyte: 0.43s - index: 0.27s - narrow_n_matmuls: 0.97s - subbyte_to_fp: 0.72s VMVX: - argmax: 0.04s - index: 0.04s - attention: 0.20s Vulkan: - argmax: 0.05s - index: 0.05s - map_gather: 0.13s - map_scatter: 0.12s - top-k: 0.19s - reverse: 0.05s All tests are fast (under 1 second each). The slowest is narrow_n_matmuls on ROCM at ~1 second. ``` Signed-off-by: hanhanW --- tests/e2e/linalg/BUILD.bazel | 20 ++++++++++---------- tests/e2e/linalg/CMakeLists.txt | 10 ++++++++++ tests/e2e/linalg_ext_ops/BUILD.bazel | 8 ++++---- tests/e2e/linalg_ext_ops/CMakeLists.txt | 4 ++++ tests/e2e/stablehlo_ops/BUILD.bazel | 2 +- tests/e2e/stablehlo_ops/CMakeLists.txt | 1 + 6 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/e2e/linalg/BUILD.bazel b/tests/e2e/linalg/BUILD.bazel index ab8360219b06..80d1f03eb50b 100644 --- a/tests/e2e/linalg/BUILD.bazel +++ b/tests/e2e/linalg/BUILD.bazel @@ -73,8 +73,10 @@ iree_check_single_backend_test_suite( VMVX_SRCS = enforce_glob( # keep sorted [ + "argmax.mlir", "conv2d.mlir", "gather_like_ops.mlir", + "index.mlir", "narrow_n_matmuls.mlir", "pack.mlir", "pack_dynamic_inner_tiles.mlir", @@ -84,10 +86,8 @@ VMVX_SRCS = enforce_glob( ], include = ["*.mlir"], exclude = [ - "argmax.mlir", "fp_to_subbyte.mlir", "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", "subbyte_to_fp.mlir", ], @@ -124,18 +124,18 @@ iree_check_single_backend_test_suite( VULKAN_SRCS = enforce_glob( # 
keep sorted [ + "argmax.mlir", "conv2d.mlir", "gather_like_ops.mlir", + "index.mlir", "narrow_n_matmuls.mlir", "softmax.mlir", "subbyte_to_fp.mlir", ], include = ["*.mlir"], exclude = [ - "argmax.mlir", "fp_to_subbyte.mlir", "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", "pack.mlir", "pack_dynamic_inner_tiles.mlir", @@ -221,20 +221,20 @@ ROCM_SRCS = enforce_glob( # keep sorted [ "argmax.mlir", + "conv2d.mlir", + "fp4_f32_conversion.mlir", + "fp_to_subbyte.mlir", "gather_like_ops.mlir", + "index.mlir", + "narrow_n_matmuls.mlir", "pack_i8.mlir", "softmax.mlir", + "subbyte_to_fp.mlir", "unpack.mlir", ], include = ["*.mlir"], exclude = [ - "conv2d.mlir", - "fp_to_subbyte.mlir", - "fp4_f32_conversion.mlir", - "index.mlir", "large_linalg_matmul.mlir", - "narrow_n_matmuls.mlir", - "subbyte_to_fp.mlir", # https://github.com/llvm/llvm-project/issues/131386 causes # See bug #20294 "pack.mlir", diff --git a/tests/e2e/linalg/CMakeLists.txt b/tests/e2e/linalg/CMakeLists.txt index 9f941a2f6024..33db43f627b4 100644 --- a/tests/e2e/linalg/CMakeLists.txt +++ b/tests/e2e/linalg/CMakeLists.txt @@ -55,8 +55,10 @@ iree_check_single_backend_test_suite( NAME check_vmvx_local-task SRCS + "argmax.mlir" "conv2d.mlir" "gather_like_ops.mlir" + "index.mlir" "narrow_n_matmuls.mlir" "pack.mlir" "pack_dynamic_inner_tiles.mlir" @@ -89,8 +91,10 @@ iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan SRCS + "argmax.mlir" "conv2d.mlir" "gather_like_ops.mlir" + "index.mlir" "narrow_n_matmuls.mlir" "softmax.mlir" "subbyte_to_fp.mlir" @@ -156,9 +160,15 @@ iree_check_single_backend_test_suite( check_rocm_hip SRCS "argmax.mlir" + "conv2d.mlir" + "fp4_f32_conversion.mlir" + "fp_to_subbyte.mlir" "gather_like_ops.mlir" + "index.mlir" + "narrow_n_matmuls.mlir" "pack_i8.mlir" "softmax.mlir" + "subbyte_to_fp.mlir" "unpack.mlir" TARGET_BACKEND "rocm" diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel index 38442ae241e3..68295287a35f 
100644 --- a/tests/e2e/linalg_ext_ops/BUILD.bazel +++ b/tests/e2e/linalg_ext_ops/BUILD.bazel @@ -69,6 +69,7 @@ VMVX_SRCS = enforce_glob( # keep sorted [ "arg_compare.mlir", + "attention.mlir", "gather.mlir", "map_gather.mlir", "map_scatter.mlir", @@ -81,7 +82,6 @@ VMVX_SRCS = enforce_glob( ], include = ["*.mlir"], exclude = [ - "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", ], @@ -195,9 +195,12 @@ iree_check_single_backend_test_suite( [ "arg_compare.mlir", "gather.mlir", + "map_gather.mlir", + "map_scatter.mlir", "scan.mlir", "scatter.mlir", "sort.mlir", + "top-k.mlir", "winograd_input.mlir", "winograd_output.mlir", ], @@ -206,9 +209,6 @@ iree_check_single_backend_test_suite( "attention.mlir", "attention_i1_mask.mlir", "attention_i1_mask_encoding.mlir", - "map_gather.mlir", - "map_scatter.mlir", - "top-k.mlir", ], ), driver = "vulkan", diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt index ce5279321bc7..8ce864dcfdb2 100644 --- a/tests/e2e/linalg_ext_ops/CMakeLists.txt +++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt @@ -57,6 +57,7 @@ iree_check_single_backend_test_suite( check_vmvx_local-task SRCS "arg_compare.mlir" + "attention.mlir" "gather.mlir" "map_gather.mlir" "map_scatter.mlir" @@ -138,9 +139,12 @@ iree_check_single_backend_test_suite( SRCS "arg_compare.mlir" "gather.mlir" + "map_gather.mlir" + "map_scatter.mlir" "scan.mlir" "scatter.mlir" "sort.mlir" + "top-k.mlir" "winograd_input.mlir" "winograd_output.mlir" TARGET_BACKEND diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel b/tests/e2e/stablehlo_ops/BUILD.bazel index 9118e78eb292..c8297028824f 100644 --- a/tests/e2e/stablehlo_ops/BUILD.bazel +++ b/tests/e2e/stablehlo_ops/BUILD.bazel @@ -241,6 +241,7 @@ iree_check_single_backend_test_suite( "reduce_window.mlir", "remainder.mlir", "reshape.mlir", + "reverse.mlir", "rng_normal.mlir", "rng_uniform.mlir", "round.mlir", @@ -264,7 +265,6 @@ iree_check_single_backend_test_suite( exclude = [ 
"exponential_fp16.mlir", "fft.mlir", # TODO(#9583) - "reverse.mlir", # TODO(#12415): disabled due to miscompilation on Pixel 6. ], ), compiler_flags = [ diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt index 2d357621cee5..c76b852331cb 100644 --- a/tests/e2e/stablehlo_ops/CMakeLists.txt +++ b/tests/e2e/stablehlo_ops/CMakeLists.txt @@ -287,6 +287,7 @@ iree_check_single_backend_test_suite( "reduce_window.mlir" "remainder.mlir" "reshape.mlir" + "reverse.mlir" "rng_normal.mlir" "rng_uniform.mlir" "round.mlir" From 75d7ce6df1308ffedd85505bac694c76c7c6357a Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:17:45 -0500 Subject: [PATCH 67/71] [GPU] Inject index hints during MMA lane distribution (#23152) Injects iree_codegen.index_hint ops on offsets in the populateOperandOffsetsSizesStrides functions for MMAAttrs. We inject the hints here, because the semantic information about the offsets is readily available, and can easily carry down to the later optimization pass that converts loads into transpose loads using these hints. These hints are intended for load to transpose load optimizations, but they are set unconditionally regardless of transpositions for simplicity. The later optimization pass is responsible for determining when the loads are transposed, since it is more explicit at that point. The hint ops will be dropped right after LLVMGPULowerExecutableTarget, since at that point the index_hint ops should already have been used. Currently, the pass that consumes these hint ops is not enabled, so the hint ops will be doing nothing until the pass is added. 
--------- Signed-off-by: Max Dawkins --- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 63 +++++- .../test/distribute_inner_tiled.mlir | 26 ++- .../test/distribute_inner_tiled_to_lanes.mlir | 208 ++++++++++-------- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 3 +- 4 files changed, 198 insertions(+), 102 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 75379573633b..888df6b06d0f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -7,10 +7,12 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h" #include "iree/compiler/Utils/EncodingUtils.h" #include "iree/compiler/Utils/Indexing.h" @@ -782,6 +784,59 @@ MMAAttr::buildUnderlyingOperations(OpBuilder &builder, Location loc, return failure(); } +/// Creates index_hint ops wrapping delinearized lane ID values. +/// The `delinearizedLaneId` values come from delinearizing the lane ID using +/// `basis`, with the innermost/fastest-varying dimension last. +/// +/// Non-final indices get lane_constant hints (uniform across lane groups). +/// The final index gets lane_increment hint (increments within lane group). +/// The group size is derived from the innermost basis element. +/// Indices with a unit basis are ignored, and given a lane_constant hint. 
+static SmallVector +createTransposeLoadIndexHint(OpBuilder &builder, Location loc, + ValueRange delinearizedLaneId, + ArrayRef basis) { + // Need at least 2 dimensions for transpose load pattern. + if (delinearizedLaneId.size() < 2) { + return SmallVector(delinearizedLaneId.begin(), + delinearizedLaneId.end()); + } + + // Find the index of the innermost non-unit (> 1) basis element. + // This determines which result gets the lane-increment hint. + // Size-1 dimensions produce constant 0 outputs regardless of lane ID, + // so they don't contribute to the meaningful group structure. + int64_t groupSize = 1; + size_t incrementResultIdx = delinearizedLaneId.size() - 1; + // The delinearized indices could have N or N + 1 results, and the basis + // elements are aligned with the last N results, so iterate backwards + // together. + for (size_t i = 1; i <= basis.size(); ++i) { + groupSize = basis[basis.size() - i]; + incrementResultIdx = delinearizedLaneId.size() - i; + if (groupSize > 1) { + break; + } + } + + auto laneConstantAttr = + IREE::GPU::LaneConstantAttr::get(builder.getContext(), groupSize); + auto laneIncrementAttr = IREE::GPU::LaneIncrementAttr::get( + builder.getContext(), groupSize, /*step=*/1); + + SmallVector results; + for (auto [i, value] : llvm::enumerate(delinearizedLaneId)) { + // The result corresponding to innermost non-unit basis gets lane-increment; + // all other results get lane-constant hints. + Attribute hint = (i == incrementResultIdx) ? 
Attribute(laneIncrementAttr) + : Attribute(laneConstantAttr); + auto hintOp = IREE::Codegen::IndexHintOp::create(builder, loc, value, hint); + results.push_back(hintOp.getResult()); + } + + return results; +} + static LogicalResult populateCanonicalOffsetsSizesAndStrides( OpBuilder &builder, Location loc, Value laneId, ArrayRef permutation, MMASingleSubgroupLayout subgroupLayout, @@ -819,6 +874,12 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides( auto splitLaneId = affine::AffineDelinearizeIndexOp::create( builder, loc, laneId, vtidBasis, /*hasOuterBound=*/false); + // Wrap delinearize results with index_hint ops for transpose load. + // The delinearize results are already in the correct order + // (innermost/fastest-varying dimension is last). + SmallVector hintedSplitLaneId = createTransposeLoadIndexHint( + builder, loc, splitLaneId.getResults(), vtidBasis); + // Each thread grabs `element` contiguous data, so the vtid needs to be // multiplied by `element` to get the next bunch of data. // vtid: virtual thread id @@ -830,7 +891,7 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides( // worsen the generated code quality. 
for (auto [splitResultIdx, element] : llvm::zip_equal(dimToVtid, subgroupLayout.element)) { - Value vtid = splitLaneId.getResult(splitResultIdx); + Value vtid = hintedSplitLaneId[splitResultIdx]; int64_t vtidLen = vtidBasis[splitResultIdx - 1]; if (element != 1) { vtid = affine::AffineLinearizeIndexOp::create( diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir index 8d1224dc9bc6..a5ba4ea8910d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_inner_tiled.mlir @@ -35,17 +35,19 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xf32> // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64) shared_outs(%[[ITER_ARG:.+]] = %[[ACC]]) -> (tensor<2x2x16x16xf32>) // CHECK: %[[ID:.+]]:3 = affine.delinearize_index %[[LANE_ID]] into (4, 16) -// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 4] [1, 1, 1, 1] : tensor<2x2x16x16xf16> to tensor<2x2x1x4xf16> -// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xf16> to tensor<2x2x4x1xf16> -// 
CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xf32> to tensor<2x2x4x1xf32> // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<2x2x1x4xf16>, tensor<2x2x4x1xf16> into tensor<2x2x4x1xf32> // CHECK: scf.forall.in_parallel -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1]], %[[ID]]#2] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x4x1xf32> into tensor<2x2x16x16xf32> // CHECK: mapping = [#iree_gpu.lane_id<0>] @@ -87,17 +89,19 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: tensor<2x2x16x16xi32> // CHECK: scf.forall (%[[LANE_ID:.+]]) in (64) shared_outs(%[[ITER_ARG:.+]] = %[[ACC]]) -> (tensor<2x2x16x16xi32>) // CHECK: %[[ID:.+]]:3 = affine.delinearize_index %[[LANE_ID]] into (4, 16) -// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 8) -// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK: %[[ID1:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 8) +// CHECK: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 8] [1, 1, 1, 1] : tensor<2x2x16x32xi8> to tensor<2x2x1x8xi8> -// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[ID1]]] +// CHECK: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[ID1]]] // CHECK-SAME: [2, 2, 1, 8] [1, 1, 1, 1] : tensor<2x2x16x32xi8> 
to tensor<2x2x1x8xi8> -// CHECK: %[[ID1_2:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[ID]]#2] +// CHECK: %[[ID1_2:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x16x16xi32> to tensor<2x2x4x1xi32> // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<2x2x1x8xi8>, tensor<2x2x1x8xi8> into tensor<2x2x4x1xi32> // CHECK: scf.forall.in_parallel -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[ID]]#2] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ITER_ARG]][0, 0, %[[ID1_2]], %[[COL]]] // CHECK-SAME: [2, 2, 4, 1] [1, 1, 1, 1] : tensor<2x2x4x1xi32> into tensor<2x2x16x16xi32> // CHECK: mapping = [#iree_gpu.lane_id<0>] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir index 211ba3232414..5b1abd255d0a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_inner_tiled_to_lanes.mlir @@ -97,15 +97,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x4x8x32xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = 
tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x4x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -137,15 +139,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x32x4x8xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 
0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[COL]], 0, %[[IDY]]] [2, 2, 1, 4, 4] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x1x4x4xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[COL]], 0, %[[IDY]]] [2, 2, 1, 4, 4] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -177,15 +181,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x4x8x32xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : 
index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 4] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 4] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x4xi8>, tensor<8x2x1x4xi8> into tensor<2x2x4x4x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[COL]]] [2, 2, 4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -217,16 +223,19 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xf32>) // CHECK-DAG: %[[ID_1:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID_1]]#1, 0] [2, 8, 1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID_1]]#1, 0] [8, 2, 1, 16] +// CHECK-DAG: %[[ROW_1:.+]] = iree_codegen.index_hint %[[ID_1]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ROW_1]], 0] [2, 8, 1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ROW_1]], 0] [8, 2, 1, 16] // CHECK-DAG: %[[ID_2:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) +// CHECK-DAG: %[[ROW_2:.+]] = iree_codegen.index_hint 
%[[ID_2]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL_2:.+]] = iree_codegen.index_hint %[[ID_2]]#2(#iree_gpu.lane_increment<16>) : index // Note: ID_2#1 and I_2#2 should not be delinearize outputs once we move to linearized indexing -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ID_2]]#1, %[[ID_2]]#2] [2, 2, 8, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ROW_2]], %[[COL_2]]] [2, 2, 8, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x16xf16>, tensor<8x2x1x16xf16> into tensor<2x2x8x1x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ID_2]]#1, %[[ID_2]]#2] [2, 2, 8, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ROW_2]], %[[COL_2]]] [2, 2, 8, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -251,14 +260,16 @@ func.func @distribute_MFMA_F32_16x16x4_F32(%lhs: tensor<16x4xf32>, %rhs: tensor< // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x16xf32> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[ID]]#1] [1, 1] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[ID]]#1, %[[ID]]#2] [1, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: 
%[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[ROW]]] [1, 1] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[ROW]], %[[COL]]] [1, 1] +// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x1xf32>, tensor<1x1xf32> into tensor<4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -283,15 +294,17 @@ func.func @distribute_F32_16x16x32_F8E4M3FNUZ(%lhs: tensor<16x32xf8E4M3FNUZ>, %r // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<32x16xf8E4M3FNUZ> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: 
%[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] +// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xf8E4M3FNUZ>, tensor<8x1xf8E4M3FNUZ> into tensor<4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[ID]]#2] [4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDZ]], %[[COL]]] [4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -316,15 +329,17 @@ func.func @distribute_I32_32x32x16_I8(%lhs: tensor<32x16xi8>, %rhs: tensor<16x32 // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x32xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<4x8x32xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDZ]], %[[ID]]#2] [4, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] 
+// CHECK-DAG: %[[IDZ:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[IDZ]], %[[COL]]] [4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xi8>, tensor<8x1xi8> into tensor<4x4x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDZ]], %[[ID]]#2] [4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[IDZ]], %[[COL]]] [4, 4, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -349,13 +364,15 @@ func.func @distribute_WMMAR3_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: ten // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x8x2xf16>) // CHECK-DAG: %[[ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#1, 0] [1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[ID]]#1] [16, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[ID]]#1] [16, 1, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %c0(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ROW]], 0] [1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, %[[ROW]]] [16, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, %[[COL]], %[[ROW]]] [16, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x16xf16>, tensor<16x1xf16> into tensor<16x1x1xf16> -// CHECK: tensor.parallel_insert_slice %[[MMA]] 
into %[[ACC]][0, 0, %[[ID]]#1] [16, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, %[[COL]], %[[ROW]]] [16, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -387,15 +404,18 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x8x2x16xi32>) // CHECK-DAG: %[[ID:.+]]:2 = affine.delinearize_index %[[LANEID]] into (16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#1, 0] [2, 8, 1, 16] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#1, 0] [8, 2, 1, 16] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ROW]], 0] [2, 8, 1, 16] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ROW]], 0] [8, 2, 1, 16] // CHECK-DAG: %[[ID_ACC:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ID_ACC]]#1, %[[ID_ACC]]#2] [2, 2, 8, 1, 1] +// CHECK-DAG: %[[ROW_ACC:.+]] = iree_codegen.index_hint %[[ID_ACC]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL_ACC:.+]] = iree_codegen.index_hint %[[ID_ACC]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[ROW_ACC]], %[[COL_ACC]]] [2, 2, 8, 1, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x16xi8>, tensor<8x2x1x16xi8> into tensor<2x2x8x1x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[ID_ACC]]#1, %[[ID_ACC]]#2] [2, 2, 8, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, 
%[[ROW_ACC]], %[[COL_ACC]]] [2, 2, 8, 1, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -420,14 +440,16 @@ func.func @distribute_WMMAR4_F16_16x16x16_F16(%lhs: tensor<16x16xf16>, %rhs: ten // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<16x16xf16> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf16>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDY]]] [1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[ID]]#2] [8, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDY]]] [1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDY]], %[[COL]]] [8, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x8xf16>, tensor<8x1xf16> into tensor<8x1xf16> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -459,15 +481,17 @@ module { // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x16x16xi8> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> 
(tensor<2x2x16x16xi32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 8] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 8] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[COL]], %[[IDY]]] [2, 8, 1, 8] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[COL]], %[[IDY]]] [8, 2, 1, 8] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [2, 2, 8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<2x8x1x8xi8>, tensor<8x2x1x8xi8> into tensor<2x2x8x1xi32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [2, 2, 8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [2, 2, 8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -492,15 +516,17 @@ func.func @distribute_WMMA_F32_16x16x4_F32(%lhs: tensor<16x4xf32>, %rhs: tensor< // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<4x16xf32> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) 
-// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 2) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDX]]] [1, 2] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[ID]]#2] [2, 1] -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 2) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDX]]] [1, 2] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[COL]]] [2, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x2xf32>, tensor<2x1xf32> into tensor<8x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -525,15 +551,17 @@ func.func @distribute_WMMA_F32_16x16x128_F8E4M3FN(%lhs: tensor<16x128xf8E4M3FN>, // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<128x16xf8E4M3FN> // CHECK: scf.forall (%[[LANEID:.+]]) in (32) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 16) -// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 64) -// 
CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[ID]]#2, %[[IDX]]] [1, 64] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[ID]]#2] [64, 1] -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 8) -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[IDX:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 64) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][%[[COL]], %[[IDX]]] [1, 64] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][%[[IDX]], %[[COL]]] [64, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (2, 8) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: kind = #iree_gpu.mma_layout // CHECK-SAME: : tensor<1x64xf8E4M3FN>, tensor<64x1xf8E4M3FN> into tensor<8x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[ID]]#2] [8, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][%[[IDY]], %[[COL]]] [8, 1] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -1087,16 +1115,18 @@ func.func @scaled_matmul_f32_16x16x128_b32_fp4_fp8(%lhs: tensor<3x5x1x16x4x32xf4 // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: tensor<5x7x4x16xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x16x16xf32>) // CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice 
%[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 1, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#1, 0, %[[ID]]#2] [5, 1, 7, 1, 32, 1] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#1, %[[ID]]#2] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK-DAG: %[[ROW:.+]] = iree_codegen.index_hint %[[ID]]#1(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: %[[COL:.+]] = iree_codegen.index_hint %[[ID]]#2(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[COL]], %[[ROW]], 0] [3, 5, 1, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ROW]], 0, %[[COL]]] [5, 1, 7, 1, 32, 1] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[COL]], %[[ROW]]] [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ROW]], %[[COL]]] [5, 7, 1, 1] +// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ROW]], %c0] by (4, 4) +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [3, 7, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]], #[[$MAP4]]] // CHECK-SAME: : tensor<3x5x1x1x1x32xf4E2M1FN>, tensor<5x1x7x1x32x1xf8E4M3FN>, tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[COL]]] [3, 7, 4, 1] // CHECK: 
mapping = [#iree_gpu.lane_id<0>] // ----- @@ -1137,16 +1167,16 @@ func.func @scaled_matmul_trb_f32_16x16x128_b32_fp4_fp8(%lhs: tensor<3x5x4x16x4x3 // CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: tensor<3x5x16x4xf8E8M0FNU> // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: tensor<5x7x16x4xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x16x16xf32>) -// CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (4, 16) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (4, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 4, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [5, 4, 7, 1, 1, 32] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_constant<16>) : index +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_increment<16>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]{{.*}} [3, 5, 4, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]{{.*}} [5, 4, 7, 1, 1, 32] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]]{{.*}} [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]]{{.*}} [5, 7, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]]{{.*}} [3, 7, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<3x5x4x1x1x32xf4E2M1FN>, tensor<5x4x7x1x1x32xf8E4M3FN>, 
tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- @@ -1183,16 +1213,16 @@ func.func @scaled_matmul_trb_f32_32x32x64_b32_fp4_fp8(%lhs: tensor<3x5x1x32x2x32 // CHECK-SAME: %[[LHS_SCALE:[A-Za-z0-9]+]]: tensor<3x5x32x2xf8E8M0FNU> // CHECK-SAME: %[[RHS_SCALE:[A-Za-z0-9]+]]: tensor<5x7x32x2xf8E8M0FNU> // CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<3x7x4x8x32xf32>) -// CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32) -// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [3, 5, 1, 1, 1, 32] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, 0, %[[ID]]#2, %[[ID]]#1, 0] [5, 1, 7, 1, 1, 32] -// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [3, 5, 1, 1] -// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]][0, 0, %[[ID]]#2, %[[ID]]#1] [5, 7, 1, 1] -// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 4, 1] +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_constant<32>) : index +// CHECK-DAG: iree_codegen.index_hint {{.*}}(#iree_gpu.lane_increment<32>) : index +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]{{.*}} [3, 5, 1, 1, 1, 32] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]{{.*}} [5, 1, 7, 1, 1, 32] +// CHECK-DAG: %[[LHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[LHS_SCALE]]{{.*}} [3, 5, 1, 1] +// CHECK-DAG: %[[RHS_SCALE_SLICE:.+]] = tensor.extract_slice %[[RHS_SCALE]]{{.*}} [5, 7, 1, 1] +// CHECK-DAG: %[[ACC_SLICE:.+]] = 
tensor.extract_slice %[[ACC]]{{.*}} [3, 7, 4, 4, 1] // CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS_SLICE]], %[[RHS_SLICE]], %[[LHS_SCALE_SLICE]], %[[RHS_SCALE_SLICE]]) outs(%[[ACC_SLICE]]) // CHECK-SAME: : tensor<3x5x1x1x1x32xf4E2M1FN>, tensor<5x1x7x1x1x32xf8E4M3FN>, tensor<3x5x1x1xf8E8M0FNU>, tensor<5x7x1x1xf8E8M0FNU> into tensor<3x7x4x4x1xf32> -// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, 0, %[[IDY]], %[[ID]]#2] [3, 7, 4, 4, 1] +// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]] // CHECK: mapping = [#iree_gpu.lane_id<0>] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 7cb5f22183af..5b84dd181b25 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1146,7 +1146,8 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager, FunctionLikeNest(modulePassManager) .addPass( [&] { return createLLVMGPULowerExecutableTargetPass(options); }) - .addPass(createVerifyWorkgroupDistributionPass); + .addPass(createVerifyWorkgroupDistributionPass) + .addPass(createRemoveIndexHintsPass); if (clPatchFuncOps) { modulePassManager.addPass(createPatchFuncOpsPass()); } From d92664a5687430b9f5ddbec7e5d39c63a30dcea5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 16 Jan 2026 20:18:44 -0500 Subject: [PATCH 68/71] Fix test failures revealed by reverse iteration (nondeterminism) (#23162) Using https://github.com/iree-org/iree/pull/23161 --- .../Transforms/AutomaticReferenceCounting.cpp | 48 ++++--- .../Util/Transforms/HoistIntoGlobals.cpp | 126 +++++++++++------- 2 files changed, 106 insertions(+), 68 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp index 005c5e0daed3..9e5377c5e673 100644 --- 
a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp @@ -306,31 +306,40 @@ struct LastUseSet { } }; +// Returns the timepoints sorted by their order in the block (textual order). +// All timepoints must be in the same block. +static SmallVector getSortedTimepointsInBlock(TimepointSet &timepoints) { + auto sorted = llvm::to_vector_of(timepoints); + llvm::sort(sorted, [](Value a, Value b) { + Operation *opA = a.getDefiningOp(); + Operation *opB = b.getDefiningOp(); + if (!opA && !opB) { + // Both are block arguments, compare by argument number. + return cast(a).getArgNumber() < + cast(b).getArgNumber(); + } + if (!opA) { + return true; // Block argument comes before operation. + } + if (!opB) { + return false; // Block argument comes before operation. + } + return opA->isBeforeInBlock(opB); + }); + return sorted; +} + // Returns the last defined SSA value in the block in |timepoints| (textual // order within the block). All timepoints must be in the same block. 
static Value getLastTimepointInBlock(TimepointSet &timepoints) { if (timepoints.empty()) { return nullptr; - } else if (timepoints.size() == 1) { - return *timepoints.begin(); } - Value lastTimepoint; - for (auto timepoint : timepoints) { - if (!lastTimepoint) { - lastTimepoint = timepoint; - } else { - auto *timepointOp = timepoint.getDefiningOp(); - auto *lastTimepointOp = lastTimepoint.getDefiningOp(); - if (!timepointOp) { - continue; // block arg - } else if (!lastTimepointOp) { - lastTimepoint = timepoint; // last found was a block arg, this isn't - } else if (lastTimepointOp->isBeforeInBlock(timepointOp)) { - lastTimepoint = timepoint; - } - } + if (timepoints.size() == 1) { + return *timepoints.begin(); } - return lastTimepoint; + SmallVector sorted = getSortedTimepointsInBlock(timepoints); + return sorted.back(); } // Returns a FusedLoc with the location of all |timepoints| and the base |loc|. @@ -595,8 +604,7 @@ static void insertDeallocations(LastUseSet &lastUseSet, AsmState *asmState, auto joinOp = IREE::Stream::TimepointJoinOp::create( builder, timepointsLoc, builder.getType(), - llvm::map_to_vector(timepoints, - [](Value timepoint) { return timepoint; })); + getSortedTimepointsInBlock(timepoints)); auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( builder, timepointsLoc, builder.getType(), resource, diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index e19758e4d9bb..007552e38259 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -98,19 +98,25 @@ class HoistIntoGlobalsPass file.close(); } - // Maps original values to newly materialized values. - HoistedValueMap hoistedMap; - // Walk all operations in the program and hoist any escapes from // const-expr values into globals. 
Note that we must walk the const-exprs // in topological order so that corresponding initializers will be created // in order without depending on globals that have not been initialized // yet. + OpBuilder builder(&getContext()); for (auto funcOp : getOperation().getOps()) { // Ignore initializers. if (isa(funcOp.getOperation())) { continue; } + + // Maps original values to newly materialized globals (per-function). + HoistedValueMap hoistedMap; + + // Operation order for deterministic sorting (per-function). + llvm::DenseMap opOrder; + unsigned orderIdx = 0; + auto walkRes = funcOp.walk([&](Operation *iterOp) { // We only want to look at const-expr ops (non roots) since they may // have interesting escapes. Early exit here for efficiency. @@ -118,6 +124,10 @@ class HoistIntoGlobalsPass if (!iterInfo) { return WalkResult::advance(); } + + // Record operation order for deterministic sorting. Since we walk in + // PreOrder, producers are visited before their users. + opOrder[iterOp] = orderIdx++; for (Value constExprResult : iterOp->getResults()) { auto *resultInfo = constExprs.lookup(constExprResult); assert(resultInfo && "must have const-expr info"); @@ -126,7 +136,7 @@ class HoistIntoGlobalsPass continue; } if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return WalkResult::interrupt(); } } @@ -135,35 +145,42 @@ class HoistIntoGlobalsPass if (walkRes.wasInterrupted()) { return signalPassFailure(); } - } - // Apply any remaining RAUW cleanups. We have to do these at the cleanup - // phase since modifying the source program can invalidate the analysis. - // Up to this point, we have only been cloning. 
- OpBuilder builder(&getContext()); - for (auto [originalValue, globalOp] : hoistedMap) { - builder.setInsertionPointAfterValue(originalValue); - auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); - if (!originalValue.getDefiningOp() - ->getParentOfType()) { - loadOp.setGlobalImmutable(true); - } - Value loadedValue = loadOp.getLoadedGlobalValue(); - // Call user hook to cast back to the original type. - if (auto hoistableType = dyn_cast( - originalValue.getType())) { - loadedValue = hoistableType.decodeStorageType( - builder, loadedValue.getLoc(), originalValue.getType(), - loadedValue); - } - if (loadedValue.getType() != originalValue.getType()) { - getOperation().emitError() - << "Unresolved conflict between casted global of type " - << loadedValue.getType() << " and original type " - << originalValue.getType(); - return signalPassFailure(); + // Apply RAUW cleanups for this function. We do this after cloning to + // avoid invalidating the analysis during the walk. + // Sort the hoisted values by program order for deterministic output. + using HoistedValue = std::pair; + auto sortedHoisted = llvm::to_vector_of(hoistedMap); + llvm::sort(sortedHoisted, + [&opOrder](const HoistedValue &lhs, const HoistedValue &rhs) { + return opOrder[lhs.first.getDefiningOp()] < + opOrder[rhs.first.getDefiningOp()]; + }); + + for (auto [originalValue, globalOp] : sortedHoisted) { + builder.setInsertionPointAfterValue(originalValue); + auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); + if (!originalValue.getDefiningOp() + ->getParentOfType()) { + loadOp.setGlobalImmutable(true); + } + Value loadedValue = loadOp.getLoadedGlobalValue(); + // Call user hook to cast back to the original type. 
+ if (auto hoistableType = dyn_cast( + originalValue.getType())) { + loadedValue = hoistableType.decodeStorageType( + builder, loadedValue.getLoc(), originalValue.getType(), + loadedValue); + } + if (loadedValue.getType() != originalValue.getType()) { + getOperation().emitError() + << "Unresolved conflict between casted global of type " + << loadedValue.getType() << " and original type " + << originalValue.getType(); + return signalPassFailure(); + } + originalValue.replaceAllUsesWith(loadedValue); } - originalValue.replaceAllUsesWith(loadedValue); } cleanupDeadOps(constExprs); } @@ -177,9 +194,11 @@ class HoistIntoGlobalsPass return op; } - LogicalResult hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { IREE::Util::GlobalOp existingGlobal = hoistedMap.lookup(originalValue); if (existingGlobal) { return success(); @@ -202,7 +221,7 @@ class HoistIntoGlobalsPass if (failed(cloneConstExprInto(initializerOp.getLoc(), moduleBuilder, initializerBuilder, originalValue, dialectAttrs, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return failure(); } @@ -218,7 +237,8 @@ class HoistIntoGlobalsPass cloneProducerTreeInto(OpBuilder &initializerBuilder, const ConstExprAnalysis::ConstValueInfo *producerInfo, HoistedValueMap &hoistedMap, IRMapping &cloneMapping, - const ConstExprAnalysis &constExprs) { + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { if (cloneMapping.contains(producerInfo->constValue)) { return; } @@ -243,10 +263,20 @@ class HoistIntoGlobalsPass return; } - // Materialize all producers recursively. 
- for (auto *producerInfo : producerInfo->producers) { - cloneProducerTreeInto(initializerBuilder, producerInfo, hoistedMap, - cloneMapping, constExprs); + // Materialize all producers recursively. Sort producers by their program + // order for deterministic output. + auto sortedProducers = + llvm::to_vector_of( + producerInfo->producers); + llvm::sort(sortedProducers, + [&opOrder](ConstExprAnalysis::ConstValueInfo *lhs, + ConstExprAnalysis::ConstValueInfo *rhs) { + return opOrder.lookup(lhs->constValue.getDefiningOp()) < + opOrder.lookup(rhs->constValue.getDefiningOp()); + }); + for (ConstExprAnalysis::ConstValueInfo *prodInfo : sortedProducers) { + cloneProducerTreeInto(initializerBuilder, prodInfo, hoistedMap, + cloneMapping, constExprs, opOrder); } // And clone the requested op. @@ -264,13 +294,13 @@ class HoistIntoGlobalsPass // Clones the const expr tree rooted at `constExprValue` into the given // initializer, noting any new hoisted value mappings that result. At // a minimum, a mapping will be created for the requested value. - LogicalResult cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, - OpBuilder &initializerBuilder, - Value constExprValue, - NamedAttrList dialectAttrs, - HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, + OpBuilder &initializerBuilder, Value constExprValue, + NamedAttrList dialectAttrs, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { // Do a depth first traversal of the producers, emitting them in a valid // def-use order. Operation *rootOp = constExprValue.getDefiningOp(); @@ -281,7 +311,7 @@ class HoistIntoGlobalsPass // Clone the whole tree as needed. 
IRMapping cloneMapping; cloneProducerTreeInto(initializerBuilder, rootInfo, hoistedMap, - cloneMapping, constExprs); + cloneMapping, constExprs, opOrder); // And for each result, create a global and store into it. for (Value origResult : rootOp->getResults()) { From 1695fb91e95915da11c01c7526bebaddaea4a3b5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 16 Jan 2026 21:20:21 -0500 Subject: [PATCH 69/71] [CI] Enable reverse iteration in UBsan workflow (#23178) I don't want to add too many CI workflows, so adding together with ubsan. --- .github/workflows/ci_linux_x64_clang_ubsan.yml | 5 ++++- build_tools/cmake/build_and_test_ubsan.sh | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_linux_x64_clang_ubsan.yml b/.github/workflows/ci_linux_x64_clang_ubsan.yml index 956ab5fe18b9..b245e1d9d9b0 100644 --- a/.github/workflows/ci_linux_x64_clang_ubsan.yml +++ b/.github/workflows/ci_linux_x64_clang_ubsan.yml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - Linux x64 clang UBSan +name: CI - Linux x64 clang UBSan and Reverse Iteration on: workflow_call: @@ -29,6 +29,9 @@ jobs: # Use a modern clang explicitly. CC: clang-19 CXX: clang++-19 + # Enable reverse iteration of unordered LLVM containers. This helps + # catch non-determinism bugs. 
+ IREE_REVERSE_ITERATE: "ON" SCCACHE_AZURE_CONNECTION_STRING: "${{ secrets.AZURE_CCACHE_CONNECTION_STRING }}" SCCACHE_AZURE_BLOB_CONTAINER: ccache-container SCCACHE_CACHE_ZSTD_LEVEL: 10 diff --git a/build_tools/cmake/build_and_test_ubsan.sh b/build_tools/cmake/build_and_test_ubsan.sh index 9cc551bba448..27bde66c8504 100755 --- a/build_tools/cmake/build_and_test_ubsan.sh +++ b/build_tools/cmake/build_and_test_ubsan.sh @@ -21,6 +21,7 @@ set -xeuo pipefail BUILD_DIR="${1:-${IREE_UBSAN_BUILD_DIR:-build-ubsan}}" IREE_ENABLE_ASSERTIONS="${IREE_ENABLE_ASSERTIONS:-ON}" +IREE_REVERSE_ITERATE="${IREE_REVERSE_ITERATE:-OFF}" # Enable CUDA and HIP/ROCM compiler and runtime by default if not on Darwin. OFF_IF_DARWIN="$(uname | awk '{print ($1 == "Darwin") ? "OFF" : "ON"}')" IREE_HAL_DRIVER_CUDA="${IREE_HAL_DRIVER_CUDA:-${OFF_IF_DARWIN}}" @@ -45,6 +46,7 @@ CMAKE_ARGS=( "-DIREE_BUILD_PYTHON_BINDINGS=OFF" "-DIREE_ENABLE_ASSERTIONS=${IREE_ENABLE_ASSERTIONS}" + "-DIREE_REVERSE_ITERATE=${IREE_REVERSE_ITERATE}" "-DIREE_ENABLE_LLD=ON" "-DIREE_ENABLE_SPLIT_DWARF=ON" "-DIREE_ENABLE_THIN_ARCHIVES=ON" From c8d75155701248f9a2d8a23e07a85312a39c5d58 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 17 Jan 2026 08:07:46 -0500 Subject: [PATCH 70/71] [CPU] Fix nondeterminism in host cpu features (#23179) This is hard to test for because only the (dynamic) host feature list is unordered, unlike features for a specific target, and we can't assume a specific host in tests. 
--- .../target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp b/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp index 59da38c40bfa..1705ba96ff7e 100644 --- a/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp +++ b/compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.cpp @@ -6,6 +6,7 @@ #include "compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TargetParser/AArch64TargetParser.h" @@ -29,9 +30,18 @@ resolveHostCPUAndCPUFeatures(std::string &cpu, std::string &cpuFeatures) { return ResolveCPUAndCPUFeaturesStatus::InconsistentHost; } cpu = llvm::sys::getHostCPUName(); + // Sort features to ensure deterministic iteration order. The StringMap + // returned by getHostCPUFeatures() has non-deterministic iteration order. + llvm::StringMap hostFeatures = + llvm::sys::getHostCPUFeatures(); + auto sortedFeatures = + llvm::to_vector_of(hostFeatures.keys()); + llvm::sort(sortedFeatures); + + // Add all features in lexicographically sorted order. llvm::SubtargetFeatures features; - for (auto &feature : llvm::sys::getHostCPUFeatures()) { - features.AddFeature(feature.first(), feature.second); + for (llvm::StringRef feature : sortedFeatures) { + features.AddFeature(feature, hostFeatures.lookup(feature)); } cpuFeatures = features.getString(); return ResolveCPUAndCPUFeaturesStatus::OK; From b47101f3e272d526e1596f873ae278c85cb137dd Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 17 Jan 2026 13:15:40 -0500 Subject: [PATCH 71/71] Simplify quantifiers all_of, any_of, none_of. NFC. 
(#23180) * Use `llvm::IsaPred` instead of lambdas where possible * `!any_of` --> `none_of` --- .../Conversion/Preprocessing/StableHLOToStableHLO.cpp | 6 +++--- .../Bindings/Native/Transforms/WrapEntryPoints.cpp | 2 +- .../compiler/Codegen/Common/DecomposePackUnPackOps.cpp | 5 ++--- .../Codegen/Common/MaterializeEncodingPatterns.cpp | 2 +- .../Common/TransformExtensions/CommonExtensions.cpp | 4 +--- .../compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp | 5 ++--- .../Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp | 4 +--- compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp | 2 +- .../src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp | 2 +- compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp | 5 ++--- .../src/iree/compiler/Codegen/Transforms/Transforms.cpp | 5 ++--- .../iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp | 4 +--- .../Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp | 7 ++++--- .../compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp | 5 ++--- .../Dialect/Stream/Transforms/PropagateTimepoints.cpp | 2 +- .../Dialect/Util/Transforms/PropagateSubranges.cpp | 2 +- .../iree/compiler/DispatchCreation/FormDispatchRegions.cpp | 7 +++---- .../DispatchCreation/FuseHorizontalContractions.cpp | 2 +- .../compiler/GlobalOptimization/ExpandTensorShapes.cpp | 2 +- 19 files changed, 31 insertions(+), 42 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp index 7251188e0b0a..c67ceb8e5bae 100644 --- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp @@ -921,9 +921,9 @@ struct ScatterMaterializeInsertedDim final } } - llvm::ArrayRef toInsertDims = + auto toInsertDims = llvm::ArrayRef(isInsertDims).drop_front(frontInsertedDims); - if (!llvm::any_of(toInsertDims, [](auto d) { return d; })) { + 
if (llvm::none_of(toInsertDims, [](bool d) { return d; })) { return rewriter.notifyMatchFailure(op, "no dimensions to insert"); } @@ -931,7 +931,7 @@ struct ScatterMaterializeInsertedDim final SmallVector reassociationMap; reassociationMap.push_back({rewriter.getAffineDimExpr(0)}); - for (auto it : llvm::enumerate(llvm::ArrayRef(toInsertDims))) { + for (auto it : llvm::enumerate(toInsertDims)) { if (!it.value()) { reassociationMap.push_back({}); } diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index 0fb8ba7e0375..2aa6f8bbc35a 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -413,7 +413,7 @@ static void formatIOAttr(DictionaryAttr attrs, llvm::raw_ostream &os) { auto shouldIncludeAttr = [](const NamedAttribute &attr) { return attr.getName().getValue() != "iree.abi.name"; }; - if (!llvm::any_of(attrs, shouldIncludeAttr)) { + if (llvm::none_of(attrs, shouldIncludeAttr)) { return; } os << " {"; diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index 736e69212536..cce497024012 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -344,9 +344,8 @@ static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { // If all consumers are dispatch tensor stores, then the `op` is decomposable // if it is an UnPackOp. 
if (isa(op) && - llvm::all_of(op->getUsers(), [&](Operation *user) { - return isa(user); - })) { + llvm::all_of(op->getUsers(), + llvm::IsaPred)) { return success(); } return failure(); diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp index b4db56291403..b87bfc53e342 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp @@ -734,7 +734,7 @@ void populateMaterializeEncodingPatterns( return resultType == typeConverter.convertType(resultType); }); target.addDynamicallyLegalOp([](func::ReturnOp returnOp) { - return !llvm::any_of(returnOp.getOperandTypes(), + return llvm::none_of(returnOp.getOperandTypes(), isRankedTensorTypeWithEncoding); }); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 5af5331db67f..ac26d8146758 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -423,9 +423,7 @@ static LogicalResult rewriteForallToWorkgroup(RewriterBase &rewriter, } SmallVector blockMapping = llvm::to_vector(forallOp.getMapping()->getValue()); - if (llvm::any_of(blockMapping, [](Attribute map) { - return !isa(map); - })) { + if (!llvm::all_of(blockMapping, llvm::IsaPred)) { return forallOp->emitError("mapping must be #gpu.block"); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index 190a1c27e401..8f95eed4c316 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ 
-485,9 +485,8 @@ fuseNestedLaneAndWarpForalls(RewriterBase &rewriter, scf::ForallOp warpForallOp, scf::ForallOp laneForallOp) { // Verify mappings. if (!warpForallOp.getMapping() || - !llvm::all_of(*warpForallOp.getMapping(), [](Attribute mappingAttr) { - return isa(mappingAttr); - })) { + !llvm::all_of(*warpForallOp.getMapping(), + llvm::IsaPred)) { return rewriter.notifyMatchFailure(warpForallOp, "not a warp forall op"); } if (!laneForallOp.getMapping() || laneForallOp.getMapping()->size() != 1 || diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp index cffd8cf605c4..344026ded253 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp @@ -303,9 +303,7 @@ NestedLayoutAttr::getRecombinedLayout(ArrayRef layouts, ArrayRef maps, AffineMap resultMap) { constexpr int64_t kInvalid = -1; - if (llvm::any_of(layouts, [](VectorLayoutInterface layout) { - return !isa(layout); - })) { + if (!llvm::all_of(layouts, llvm::IsaPred)) { return NestedLayoutAttr(); } MLIRContext *context = resultMap.getContext(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp index bbd2f35d5af9..51158f9939d4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp @@ -798,7 +798,7 @@ MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal, // requested range is valid. 
auto [strides, offset] = memRefType.getStridesAndOffset(); if (memRefType.hasStaticShape() && - !llvm::any_of(strides, ShapedType::isDynamic) && + llvm::none_of(strides, ShapedType::isDynamic) && ShapedType::isStatic(offset)) { return MemRefDescriptor::fromStaticShape(builder, loc, *typeConverter, memRefType, basePtrValue); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index 4fa209a2aa7c..7f44dd67bcfe 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -408,7 +408,7 @@ struct ConvertIREEBindingSubspanOp final auto [strides, offset] = memrefType.getStridesAndOffset(); if (memrefType.hasStaticShape() && - !llvm::any_of(strides, ShapedType::isDynamic) && + llvm::none_of(strides, ShapedType::isDynamic) && ShapedType::isStatic(offset)) { auto desc = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 5b84dd181b25..d3f8639a497d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -403,9 +403,8 @@ LogicalResult isAtBoundary(Operation *op) { return success(); } } else if (isa(op)) { - if (llvm::all_of(op->getUsers(), [](Operation *user) { - return isa(user); - })) { + if (llvm::all_of(op->getUsers(), + llvm::IsaPred)) { return success(); } } diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index 8165a7743a74..c49671bb9c5e 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -661,9 +661,8 @@ struct FoldSplitReductionForallWithWorkgroupForall } std::optional 
workgroupMapping = workgroupLoop.getMapping(); if (!workgroupMapping || - llvm::any_of(workgroupMapping->getValue(), [](Attribute attr) { - return !isa(attr); - })) { + !llvm::all_of(workgroupMapping->getValue(), + llvm::IsaPred)) { return rewriter.notifyMatchFailure( workgroupLoop, "nested loop is not a workgroup mapping loop"); } diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index 3cdde7c627c9..80cb384caaaf 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -229,9 +229,7 @@ static FailureOr getComposedAffineMap(Attribute attr) { return AffineMap(); } // All entries should have type `AffineMapAttr`. - if (!llvm::all_of(mapsAttr, [](Attribute attr) { - return isa(attr); - })) { + if (!llvm::all_of(mapsAttr, llvm::IsaPred)) { return failure(); } AffineMap map = diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp index 133ed3bea247..f1214d07f4e2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp @@ -101,9 +101,10 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups( llvm::SetVector argumentsSet; mlir::getUsedValuesDefinedAbove(region, argumentsSet); // Unranked tensors are not supported. - assert(!llvm::any_of(argumentsSet, [](Value v) { - return isa(v.getType()); - }) && "unranked tensors are not supported"); + assert(llvm::none_of( + argumentsSet, + [](Value v) { return isa(v.getType()); }) && + "unranked tensors are not supported"); // Compute dimensions of tensor args. 
SmallVector argumentDims; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp index 9e461995781a..b023712db9b6 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp @@ -643,9 +643,8 @@ FailureOr hoistOutOfDispatch(RewriterBase &rewriter, return producer && producer->getParentOfType(); })) { rewriter.setInsertionPoint(dispatchRegionOp); - } else if (llvm::all_of(op->getUsers(), [&](Operation *user) { - return isa(user); - })) { + } else if (llvm::all_of(op->getUsers(), + llvm::IsaPred)) { rewriter.setInsertionPointAfter(dispatchRegionOp); } else { return rewriter.notifyMatchFailure( diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp index 3e51b04f75b4..ace1ef6ac0bd 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp @@ -254,7 +254,7 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // Update all block arguments. 
auto timepointType = IREE::Stream::TimepointType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) { + if (llvm::none_of(block.getArgumentTypes(), isResourceType)) { continue; } if (block.isEntryBlock() && !canModifyEntryBlock) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp index bacff11291e6..c6fffa1f66f1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp @@ -230,7 +230,7 @@ static void expandRegion(Region ®ion, bool canModifyEntryBlock, // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isResourceType)) { + if (llvm::none_of(block.getArgumentTypes(), isResourceType)) { continue; } diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 9df5a62a2d96..370fd631b096 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -737,7 +737,7 @@ isFusableWithConsumer(OpOperand &fusedOperand, const FusionTracker &tracker, continue; } if (isa(producer) && - !llvm::any_of( + llvm::none_of( consumerDstOp.getDpsInitsMutable(), [&](OpOperand &initOperand) { return canUseInOperandAsInitOperand(inputOperand, &initOperand); })) { @@ -979,9 +979,8 @@ decideFusableLinalgOps(Region ®ion, DominanceInfo const &dominanceInfo, // by the `isClonableIntoDispatchOp` call above, but for now this is done // as a point fix. 
if (IREE::LinalgExt::isGatherlikeOp(&op) && - llvm::all_of(op.getUsers(), [](Operation *op) { - return isa(op); - })) { + llvm::all_of(op.getUsers(), + llvm::IsaPred)) { continue; } diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index ed03dc8891ea..68315e478288 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -253,7 +253,7 @@ static bool isHorizontalToGroup(Operation *op, llvm::SetVector slice; [[maybe_unused]] LogicalResult result = getBackwardSlice(op, &slice, options); assert(result.succeeded()); - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return llvm::none_of(currGroup, [&](Operation *groupedOp) { return slice.contains(groupedOp); }); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp index d52968c0b53a..e2b900ade36e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp @@ -213,7 +213,7 @@ static void expandRegion(Region ®ion, SymbolTable &symbolTable, // Update all block arguments. auto indexType = IndexType::get(region.getContext()); for (auto &block : region.getBlocks()) { - if (!llvm::any_of(block.getArgumentTypes(), isDynamicTensor)) { + if (llvm::none_of(block.getArgumentTypes(), isDynamicTensor)) { continue; }