Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===- DIExpressionLegalization.h - DIExpression Legalization Patterns ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Declarations for known legalization patterns for DIExpressions that should
// be performed before translation into llvm.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H

#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"

namespace mlir {
namespace LLVM {

//===----------------------------------------------------------------------===//
// Rewrite Patterns
//===----------------------------------------------------------------------===//

/// Adjacent DW_OP_LLVM_fragment should be merged into one.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: please elaborate with a short psuedocode example, like expr(frag, frag) -> expr(frag) or something.

class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
public:
OpIterT match(OpIterRange operators) const override;
SmallVector<OperatorT> replace(OpIterRange operators) const override;
};

//===----------------------------------------------------------------------===//
// Runner
//===----------------------------------------------------------------------===//

/// Register all known legalization patterns declared here and apply them to
/// all ops in `op`.
void legalizeDIExpressionsRecursively(Operation *op);

} // namespace LLVM
} // namespace mlir

#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
66 changes: 66 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- DIExpressionRewriter.h - Rewriter for DIExpression operators -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A driver for running rewrite patterns on DIExpression operators.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include <deque>

namespace mlir {
namespace LLVM {

/// Rewriter for DIExpressionAttr.
///
/// Users of this rewriter register their own rewrite patterns. Each pattern
/// matches on a contiguous range of LLVM DIExpressionElemAttrs, and can be
/// used to rewrite it into a new range of DIExpressionElemAttrs of any length.
class DIExpressionRewriter {
public:
using OperatorT = LLVM::DIExpressionElemAttr;

class ExprRewritePattern {
public:
using OperatorT = DIExpressionRewriter::OperatorT;
using OpIterT = std::deque<OperatorT>::const_iterator;
using OpIterRange = llvm::iterator_range<OpIterT>;

virtual ~ExprRewritePattern() = default;
/// Check whether a particular prefix of operators matches this pattern.
/// The provided argument is guaranteed non-empty.
/// Return the iterator after the last matched element.
virtual OpIterT match(OpIterRange) const = 0;
/// Replace the operators with a new list of operators.
/// The provided argument is guaranteed to be the same length as returned
/// by the `match` function.
virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
};

/// Register a rewrite pattern with the simplifier.
/// Rewriter patterns are attempted in the order of registration.
void addPattern(std::unique_ptr<ExprRewritePattern> pattern);

/// Simplify a DIExpression according to all the patterns registered.
/// A non-negative `maxNumRewrites` will limit the number of rewrites this
/// simplifier applies.
LLVM::DIExpressionAttr simplify(LLVM::DIExpressionAttr expr,
int64_t maxNumRewrites = -1) const;

private:
/// The registered patterns.
SmallVector<std::unique_ptr<ExprRewritePattern>> patterns;
};

} // namespace LLVM
} // namespace mlir

#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRLLVMIRTransforms
AddComdats.cpp
DIExpressionLegalization.cpp
DIExpressionRewriter.cpp
DIScopeForLLVMFuncOp.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
Expand Down
61 changes: 61 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- DIExpressionLegalization.cpp - DIExpression Legalization Patterns --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"

#include "llvm/BinaryFormat/Dwarf.h"

using namespace mlir;
using namespace LLVM;

//===----------------------------------------------------------------------===//
// MergeFragments
//===----------------------------------------------------------------------===//

MergeFragments::OpIterT MergeFragments::match(OpIterRange operators) const {
OpIterT it = operators.begin();
if (it == operators.end() ||
it->getOpcode() != llvm::dwarf::DW_OP_LLVM_fragment)
return operators.begin();

++it;
if (it == operators.end() ||
it->getOpcode() != llvm::dwarf::DW_OP_LLVM_fragment)
return operators.begin();

return ++it;
}

SmallVector<MergeFragments::OperatorT>
MergeFragments::replace(OpIterRange operators) const {
OpIterT it = operators.begin();
OperatorT first = *(it++);
OperatorT second = *it;
// Add offsets & select the size of the earlier operator (the one closer to
// the IR value).
uint64_t offset = first.getArguments()[0] + second.getArguments()[0];
uint64_t size = first.getArguments()[1];
OperatorT newOp = OperatorT::get(
first.getContext(), llvm::dwarf::DW_OP_LLVM_fragment, {offset, size});
return SmallVector<OperatorT>{newOp};
}

//===----------------------------------------------------------------------===//
// Runner
//===----------------------------------------------------------------------===//

void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
LLVM::DIExpressionRewriter rewriter;
rewriter.addPattern(std::make_unique<MergeFragments>());

mlir::AttrTypeReplacer replacer;
replacer.addReplacement([&rewriter](LLVM::DIExpressionAttr expr) {
return rewriter.simplify(expr);
});
replacer.recursivelyReplaceElementsIn(op);
}
74 changes: 74 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace LLVM;

#define DEBUG_TYPE "llvm-di-expression-simplifier"

//===----------------------------------------------------------------------===//
// DIExpressionRewriter
//===----------------------------------------------------------------------===//

void DIExpressionRewriter::addPattern(
std::unique_ptr<ExprRewritePattern> pattern) {
patterns.emplace_back(std::move(pattern));
}

DIExpressionAttr DIExpressionRewriter::simplify(DIExpressionAttr expr,
int64_t maxNumRewrites) const {
ArrayRef<OperatorT> operators = expr.getOperations();

// `inputs` contains the unprocessed postfix of operators.
// `result` contains the already finalized prefix of operators.
// Invariant: concat(result, inputs) is equivalent to `operators` after some
// application of the rewrite patterns.
// Using a deque for inputs so that we have efficient front insertion and
// removal. Random access is not necessary for patterns.
std::deque<OperatorT> inputs(operators.begin(), operators.end());
SmallVector<OperatorT> result;

int64_t numRewrites = 0;
while (!inputs.empty() &&
(maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
(maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {

Wouldn't std::optional<uint64_t> be much nicer than using -1 in this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good 👍 was a bit hesitant to wrap an int since the mlir rewriter was not, but I like this 😂 .

bool foundMatch = false;
for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
if (matchEnd == inputs.begin())
continue;

foundMatch = true;
SmallVector<OperatorT> replacement =
pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
inputs.erase(inputs.begin(), matchEnd);
inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
++numRewrites;
break;
}

if (!foundMatch) {
// If no match, pass along the current operator.
result.push_back(inputs.front());
inputs.pop_front();
}
}

if (maxNumRewrites >= 0 && numRewrites >= maxNumRewrites) {
LLVM_DEBUG(llvm::dbgs()
<< "LLVMDIExpressionSimplifier exceeded max num rewrites ("
<< maxNumRewrites << ")\n");
// Skip rewriting the rest.
result.append(inputs.begin(), inputs.end());
}

return LLVM::DIExpressionAttr::get(expr.getContext(), result);
}
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -79,6 +80,7 @@ struct LegalizeForExportPass
: public LLVM::impl::LLVMLegalizeForExportBase<LegalizeForExportPass> {
void runOnOperation() override {
LLVM::ensureDistinctSuccessors(getOperation());
LLVM::legalizeDIExpressionsRecursively(getOperation());
}
};
} // namespace
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
Expand Down Expand Up @@ -1568,6 +1569,7 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
return nullptr;

LLVM::ensureDistinctSuccessors(module);
LLVM::legalizeDIExpressionsRecursively(module);

ModuleTranslation translator(module, std::move(llvmModule));
llvm::IRBuilder<> llvmBuilder(llvmContext);
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: mlir-opt -llvm-legalize-for-export --split-input-file %s | FileCheck %s -check-prefix=CHECK-OPT
// RUN: mlir-translate -mlir-to-llvmir --split-input-file %s | FileCheck %s -check-prefix=CHECK-TRANSLATE

#di_file = #llvm.di_file<"foo.c" in "/mlir/">
#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, producer = "MLIR", isOptimized = true, emissionKind = Full>
#di_subprogram = #llvm.di_subprogram<compileUnit = #di_compile_unit, scope = #di_file, name = "simplify", file = #di_file, subprogramFlags = Definition>
#i32_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i32", sizeInBits = 32, encoding = DW_ATE_unsigned>
#i8_type = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "i8", sizeInBits = 8, encoding = DW_ATE_unsigned>

// struct0: {i8, i32}
#struct0_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
#struct0_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct0_second", baseType = #i32_type, sizeInBits = 32, alignInBits = 32, offsetInBits = 32>
#struct0 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct0", sizeInBits = 64, alignInBits = 32, elements = #struct0_first, #struct0_second>

// struct1: {i8, struct0}
#struct1_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_first", baseType = #i8_type, sizeInBits = 8, alignInBits = 8>
#struct1_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct1_second", baseType = #struct0, sizeInBits = 64, alignInBits = 32>
#struct1 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct1", sizeInBits = 96, alignInBits = 32, elements = #struct1_first, #struct1_second>

// struct2: {i32, struct1}
#struct2_first = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_first", baseType = #i32_type, sizeInBits = 32, alignInBits = 32>
#struct2_second = #llvm.di_derived_type<tag = DW_TAG_member, name = "struct2_second", baseType = #struct1, sizeInBits = 96, alignInBits = 32>
#struct2 = #llvm.di_composite_type<tag = DW_TAG_structure_type, name = "struct2", sizeInBits = 128, alignInBits = 32, elements = #struct2_first, #struct2_second>

#var0 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct0_var", file = #di_file, line = 10, alignInBits = 32, type = #struct0>
#var1 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct1_var", file = #di_file, line = 10, alignInBits = 32, type = #struct1>
#var2 = #llvm.di_local_variable<scope = #di_subprogram, name = "struct2_var", file = #di_file, line = 10, alignInBits = 32, type = #struct2>

#loc = loc("test.mlir":0:0)

llvm.func @merge_fragments(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) {
// CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]>
// CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 32, 32))
llvm.intr.dbg.value #var0 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]> = %arg0 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
// CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(64, 32)]>
// CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 64, 32))
llvm.intr.dbg.value #var1 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)]> = %arg1 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
// CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(96, 32)]>
// CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 96, 32))
llvm.intr.dbg.value #var2 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64), DW_OP_LLVM_fragment(32, 96)]> = %arg2 : !llvm.ptr loc(fused<#di_subprogram>[#loc])
llvm.return
}