Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,11 @@ def CUFDeviceFuncTransform
/*default=*/"0", "CUDA compute capability version">];
}

def CUFFunctionRewrite : Pass<"cuf-function-rewrite", ""> {
let summary = "Convert some CUDA Fortran specific call";
let dependentDialects = ["fir::FIROpsDialect"];
}

def CUFLaunchAttachAttr : Pass<"cuf-launch-attach-attr", ""> {
let summary = "Attach CUDA attribute to CUF kernel generated launch";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_flang_library(FIRTransforms
ControlFlowConverter.cpp
CUDA/CUFAllocationConversion.cpp
CUDA/CUFDeviceFuncTransform.cpp
CUDA/CUFFunctionRewrite.cpp
CUDA/CUFLaunchAttachAttr.cpp
CUDA/CUFPredefinedVarToGPU.cpp
CUFAddConstructor.cpp
Expand Down
103 changes: 103 additions & 0 deletions flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//===-- CUFFUnctionRewrite.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include <string_view>

#define DEBUG_TYPE "flang-cuf-function-rewrite"

namespace fir {
#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

namespace {

using genFunctionType =
std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>;

class CallConversion : public OpRewritePattern<fir::CallOp> {
public:
CallConversion(MLIRContext *context)
: OpRewritePattern<fir::CallOp>(context) {}

LogicalResult
matchAndRewrite(fir::CallOp op,
mlir::PatternRewriter &rewriter) const override {
auto callee = op.getCallee();
if (!callee)
return failure();
auto name = callee->getRootReference().getValue();

if (genMappings_.contains(name)) {
auto fct = genMappings_.find(name);
mlir::Value result = fct->second(rewriter, op);
if (result)
rewriter.replaceOp(op, result);
else
rewriter.eraseOp(op);
return success();
}
return failure();
}

private:
static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
fir::CallOp op) {
assert(op.getArgs().size() == 0 && "expect 0 arguments");
mlir::Location loc = op.getLoc();
unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
mlir::Type i1Ty = rewriter.getIntegerType(1);
mlir::Value t = mlir::arith::ConstantOp::create(
rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
}

const llvm::StringMap<genFunctionType> genMappings_ = {
{"on_device", &genOnDevice}};
};

class CUFFunctionRewrite
: public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
public:
void runOnOperation() override {
auto *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);

patterns.insert<CallConversion>(patterns.getContext());

if (mlir::failed(
mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
"error in CUFFunctionRewrite op conversion\n");
signalPassFailure();
}
}
};

} // namespace
44 changes: 44 additions & 0 deletions flang/test/Fir/CUDA/cuda-function-rewrite.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: fir-opt --split-input-file --cuf-function-rewrite %s | FileCheck %s

gpu.module @cuda_device_mod {
func.func @_QMmtestsPdo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%0 = fir.dummy_scope : !fir.dscope
%5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
%8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
%13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
%14 = fir.convert %13 : (!fir.logical<4>) -> i1
fir.if %14 {
fir.store %c1_i32 to %5 : !fir.ref<i32>
} else {
fir.store %c2_i32 to %5 : !fir.ref<i32>
}
return
}
}

// CHECK-LABEL: gpu.module @cuda_device_mod
// CHECK: func.func @_QMmtestsPdo2
// CHECK: fir.if %true

// -----

func.func @_QMmtestsPdo3(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%0 = fir.dummy_scope : !fir.dscope
%5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
%8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
%13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
%14 = fir.convert %13 : (!fir.logical<4>) -> i1
fir.if %14 {
fir.store %c1_i32 to %5 : !fir.ref<i32>
} else {
fir.store %c2_i32 to %5 : !fir.ref<i32>
}
return
}

// CHECK-LABEL: func.func @_QMmtestsPdo3
// CHECK: fir.if %false