diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 47ffc4be93b33..dd2023223f1d2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -533,6 +533,16 @@ def CUFDeviceFuncTransform
                         /*default=*/"0", "CUDA compute capability version">];
 }
 
+def CUFFunctionRewrite : Pass<"cuf-function-rewrite", ""> {
+  let summary = "Convert some CUDA Fortran specific call";
+  let description = [{
+    Replace calls to selected CUDA Fortran functions with inline IR. A call to
+    `on_device` becomes a constant that is true when the call is nested in a
+    `gpu.module` and false otherwise.
+  }];
+  let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect"];
+}
+
 def CUFLaunchAttachAttr : Pass<"cuf-launch-attach-attr", ""> {
   let summary = "Attach CUDA attribute to CUF kernel generated launch";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 4496e80aa7c40..4ee5eab6247e1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FIRTransforms
   ControlFlowConverter.cpp
   CUDA/CUFAllocationConversion.cpp
   CUDA/CUFDeviceFuncTransform.cpp
+  CUDA/CUFFunctionRewrite.cpp
   CUDA/CUFLaunchAttachAttr.cpp
   CUDA/CUFPredefinedVarToGPU.cpp
   CUFAddConstructor.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
new file mode 100644
index 0000000000000..d6f9dc097831c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
@@ -0,0 +1,119 @@
+//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include <functional>
+
+#define DEBUG_TYPE "flang-cuf-function-rewrite"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+/// Signature of a generator that builds the replacement for a matched call.
+/// It returns the value that replaces the call, or a null value when the call
+/// should be erased without a replacement.
+using genFunctionType =
+    std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp)>;
+
+/// Rewrite `fir.call` operations whose callee name has a registered generator;
+/// other calls are left untouched.
+class CallConversion : public OpRewritePattern<fir::CallOp> {
+public:
+  CallConversion(MLIRContext *context)
+      : OpRewritePattern<fir::CallOp>(context) {}
+
+  LogicalResult
+  matchAndRewrite(fir::CallOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Calls without a callee symbol cannot be matched by name.
+    auto callee = op.getCallee();
+    if (!callee)
+      return failure();
+    auto name = callee->getRootReference().getValue();
+
+    // Look the callee up once; names without a generator do not match.
+    auto fct = genMappings_.find(name);
+    if (fct == genMappings_.end())
+      return failure();
+
+    // A null value from the generator means the call has no replacement.
+    mlir::Value result = fct->second(rewriter, op);
+    if (result)
+      rewriter.replaceOp(op, result);
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Replace a call to `on_device` with a constant converted to the call's
+  /// result type: true when the call is nested in a `gpu.module`, false
+  /// otherwise.
+  static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
+                                 fir::CallOp op) {
+    assert(op.getArgs().size() == 0 && "expect 0 arguments");
+    assert(op->getNumResults() == 1 && "expect 1 result");
+    mlir::Location loc = op.getLoc();
+    unsigned inGPUMod = op->getParentOfType<mlir::gpu::GPUModuleOp>() ? 1 : 0;
+    mlir::Type i1Ty = rewriter.getIntegerType(1);
+    mlir::Value t = mlir::arith::ConstantOp::create(
+        rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
+    return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
+  }
+
+  /// Callee name -> generator of its replacement.
+  const llvm::StringMap<genFunctionType> genMappings_ = {
+      {"on_device", &genOnDevice}};
+};
+
+/// Pass that applies CallConversion greedily to the operation it runs on and
+/// fails if the rewrite driver reports a failure.
+class CUFFunctionRewrite
+    : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
+public:
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+
+    patterns.insert<CallConversion>(patterns.getContext());
+
+    if (mlir::failed(
+            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(ctx),
+                      "error in CUFFunctionRewrite op conversion\n");
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-function-rewrite.mlir b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
new file mode 100644
index 0000000000000..da1d601a2eb8b
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
@@ -0,0 +1,49 @@
+// RUN: fir-opt --split-input-file --cuf-function-rewrite %s | FileCheck %s
+
+// Check that a call to on_device is folded to a constant: true inside a
+// gpu.module, false outside of one.
+// NOTE(review): the type and attribute parameters in the IR below were
+// reconstructed; confirm them against FIR emitted by the frontend.
+
+gpu.module @cuda_device_mod {
+  func.func @_QMmtestsPdo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+    %c2_i32 = arith.constant 2 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %0 = fir.dummy_scope : !fir.dscope
+    %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+    %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+    fir.if %14 {
+      fir.store %c1_i32 to %5 : !fir.ref<i32>
+    } else {
+      fir.store %c2_i32 to %5 : !fir.ref<i32>
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: gpu.module @cuda_device_mod
+// CHECK: func.func @_QMmtestsPdo2
+// CHECK: fir.if %true
+
+// -----
+
+func.func @_QMmtestsPdo3(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+  %c2_i32 = arith.constant 2 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.dummy_scope : !fir.dscope
+  %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+  %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+  fir.if %14 {
+    fir.store %c1_i32 to %5 : !fir.ref<i32>
+  } else {
+    fir.store %c2_i32 to %5 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QMmtestsPdo3
+// CHECK: fir.if %false