[flang][cuda] Add CUFFunctionRewrite pass#174650
Merged
clementval merged 1 commit into llvm:main on Jan 6, 2026
Merged
Conversation
Member
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: This rewrites some CUDA Fortran specific calls, such as the `on_device` function, to constant boolean values. Full diff: https://github.com/llvm/llvm-project/pull/174650.diff 4 Files Affected:
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 47ffc4be93b33..dd2023223f1d2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -533,6 +533,11 @@ def CUFDeviceFuncTransform
/*default=*/"0", "CUDA compute capability version">];
}
+def CUFFunctionRewrite : Pass<"cuf-function-rewrite", ""> {
+  let summary = "Convert some CUDA Fortran specific calls";
+  // The rewrite materializes arith.constant ops, so the Arith dialect must be
+  // declared as a dependent dialect in addition to FIR.
+  let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect"];
+}
+
def CUFLaunchAttachAttr : Pass<"cuf-launch-attach-attr", ""> {
let summary = "Attach CUDA attribute to CUF kernel generated launch";
let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 4496e80aa7c40..4ee5eab6247e1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FIRTransforms
ControlFlowConverter.cpp
CUDA/CUFAllocationConversion.cpp
CUDA/CUFDeviceFuncTransform.cpp
+ CUDA/CUFFunctionRewrite.cpp
CUDA/CUFLaunchAttachAttr.cpp
CUDA/CUFPredefinedVarToGPU.cpp
CUFAddConstructor.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
new file mode 100644
index 0000000000000..d6f9dc097831c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
@@ -0,0 +1,103 @@
+//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include <string_view>
+
+#define DEBUG_TYPE "flang-cuf-function-rewrite"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+// Generator callback for an intercepted call: builds and returns the value
+// that replaces the fir.call result. A null return value means the call has
+// no result to replace and is simply erased.
+using genFunctionType =
+    std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>;
+
+/// Rewrites calls to selected CUDA Fortran specific functions into inline
+/// FIR. The callee name is looked up in genMappings_; when a generator is
+/// registered for it, the call is replaced by the generated value (or erased
+/// when the generator returns a null value).
+class CallConversion : public OpRewritePattern<fir::CallOp> {
+public:
+  CallConversion(MLIRContext *context)
+      : OpRewritePattern<fir::CallOp>(context) {}
+
+  LogicalResult
+  matchAndRewrite(fir::CallOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto callee = op.getCallee();
+    if (!callee)
+      return failure();
+    auto name = callee->getRootReference().getValue();
+
+    // Single lookup instead of contains() followed by find().
+    auto it = genMappings_.find(name);
+    if (it == genMappings_.end())
+      return failure();
+    mlir::Value result = it->second(rewriter, op);
+    if (result)
+      rewriter.replaceOp(op, result);
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Replace a call to on_device() with a constant i1 converted to the call's
+  /// result type: true when the call is nested inside a gpu.module (device
+  /// compilation), false otherwise.
+  static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
+                                 fir::CallOp op) {
+    assert(op.getArgs().empty() && "expect 0 arguments");
+    mlir::Location loc = op.getLoc();
+    unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
+    mlir::Type i1Ty = rewriter.getIntegerType(1);
+    mlir::Value t = mlir::arith::ConstantOp::create(
+        rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
+    return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
+  }
+
+  // Map of rewritable callee names to their generator functions.
+  const llvm::StringMap<genFunctionType> genMappings_ = {
+      {"on_device", &genOnDevice}};
+};
+
+/// Pass driver: greedily applies the CUDA Fortran call-rewrite patterns on
+/// the operation the pass is scheduled on.
+class CUFFunctionRewrite
+    : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
+public:
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+
+    patterns.insert<CallConversion>(ctx);
+
+    if (mlir::failed(
+            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      // Anchor the diagnostic on the processed operation rather than an
+      // unknown location; MLIR diagnostics are not newline-terminated.
+      mlir::emitError(getOperation()->getLoc(),
+                      "error in CUFFunctionRewrite op conversion");
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-function-rewrite.mlir b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
new file mode 100644
index 0000000000000..da1d601a2eb8b
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
@@ -0,0 +1,44 @@
+// RUN: fir-opt --split-input-file --cuf-function-rewrite %s | FileCheck %s
+
+// Call to on_device() inside a gpu.module (device compilation): the pass is
+// expected to fold it to the constant true (see CHECK lines below).
+gpu.module @cuda_device_mod {
+  func.func @_QMmtestsPdo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+    %c2_i32 = arith.constant 2 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %0 = fir.dummy_scope : !fir.dscope
+    %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+    %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+    fir.if %14 {
+      fir.store %c1_i32 to %5 : !fir.ref<i32>
+    } else {
+      fir.store %c2_i32 to %5 : !fir.ref<i32>
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: gpu.module @cuda_device_mod
+// CHECK: func.func @_QMmtestsPdo2
+// CHECK: fir.if %true
+
+// -----
+
+// Call to on_device() outside any gpu.module (host compilation): the pass is
+// expected to fold it to the constant false (see CHECK lines below).
+func.func @_QMmtestsPdo3(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+  %c2_i32 = arith.constant 2 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.dummy_scope : !fir.dscope
+  %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+  %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+  fir.if %14 {
+    fir.store %c1_i32 to %5 : !fir.ref<i32>
+  } else {
+    fir.store %c2_i32 to %5 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QMmtestsPdo3
+// CHECK: fir.if %false
|
wangzpgi
approved these changes
Jan 6, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This rewrites some CUDA Fortran specific calls, such as the
on_device function, to constant boolean values.