Skip to content

[flang][cuda] Add CUFFunctionRewrite pass#174650

Merged
clementval merged 1 commit into llvm:main from
clementval:cuf_function_rewrite
Jan 6, 2026
Merged

[flang][cuda] Add CUFFunctionRewrite pass#174650
clementval merged 1 commit into llvm:main from
clementval:cuf_function_rewrite

Conversation

@clementval
Copy link
Contributor

This rewrites some CUDA Fortran specific functions, such as on_device, to constant boolean values.

@clementval clementval requested a review from wangzpgi January 6, 2026 21:10
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 6, 2026
@llvmbot
Copy link
Member

llvmbot commented Jan 6, 2026

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This rewrites some CUDA Fortran specific functions, such as on_device, to constant boolean values.


Full diff: https://github.com/llvm/llvm-project/pull/174650.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+5)
  • (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1)
  • (added) flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp (+103)
  • (added) flang/test/Fir/CUDA/cuda-function-rewrite.mlir (+44)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 47ffc4be93b33..dd2023223f1d2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -533,6 +533,11 @@ def CUFDeviceFuncTransform
                         /*default=*/"0", "CUDA compute capability version">];
 }
 
+def CUFFunctionRewrite : Pass<"cuf-function-rewrite", ""> {
+  let summary = "Convert some CUDA Fortran specific call";
+  let dependentDialects = ["fir::FIROpsDialect"];
+}
+
 def CUFLaunchAttachAttr : Pass<"cuf-launch-attach-attr", ""> {
   let summary = "Attach CUDA attribute to CUF kernel generated launch";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 4496e80aa7c40..4ee5eab6247e1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FIRTransforms
   ControlFlowConverter.cpp
   CUDA/CUFAllocationConversion.cpp
   CUDA/CUFDeviceFuncTransform.cpp
+  CUDA/CUFFunctionRewrite.cpp
   CUDA/CUFLaunchAttachAttr.cpp
   CUDA/CUFPredefinedVarToGPU.cpp
   CUFAddConstructor.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
new file mode 100644
index 0000000000000..d6f9dc097831c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
@@ -0,0 +1,103 @@
+//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include <string_view>
+
+#define DEBUG_TYPE "flang-cuf-function-rewrite"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+using genFunctionType =
+    std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>; // Generator signature: takes the rewriter and the matched fir.call, returns a (possibly null) value.
+
+class CallConversion : public OpRewritePattern<fir::CallOp> { // Rewrites fir.call ops whose callee name has an entry in genMappings_.
+public:
+  CallConversion(MLIRContext *context)
+      : OpRewritePattern<fir::CallOp>(context) {}
+
+  LogicalResult
+  matchAndRewrite(fir::CallOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto callee = op.getCallee(); // May be empty; checked right below.
+    if (!callee)
+      return failure(); // No callee symbol to look up: the pattern does not apply.
+    auto name = callee->getRootReference().getValue(); // Root symbol name, used as the lookup key.
+
+    if (genMappings_.contains(name)) {
+      auto fct = genMappings_.find(name);
+      mlir::Value result = fct->second(rewriter, op); // Run the generator registered for this name.
+      if (result)
+        rewriter.replaceOp(op, result); // Replace the call with the generated value.
+      else
+        rewriter.eraseOp(op); // Generator returned a null value: just remove the call.
+      return success();
+    }
+    return failure(); // Callee has no registered generator.
+  }
+
+private:
+  static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter, // Generator for "on_device": constant 1 inside a gpu.module, 0 otherwise.
+                                 fir::CallOp op) {
+    assert(op.getArgs().size() == 0 && "expect 0 arguments");
+    mlir::Location loc = op.getLoc();
+    unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0; // 1 when the call has an enclosing gpu::GPUModuleOp.
+    mlir::Type i1Ty = rewriter.getIntegerType(1);
+    mlir::Value t = mlir::arith::ConstantOp::create(
+        rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
+    return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t); // Convert the i1 constant to the call's own result type.
+  }
+
+  const llvm::StringMap<genFunctionType> genMappings_ = { // Callee name -> generator; "on_device" is the only entry.
+      {"on_device", &genOnDevice}};
+};
+
+class CUFFunctionRewrite // Pass that greedily applies CallConversion to the operation it runs on.
+    : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
+public:
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+
+    patterns.insert<CallConversion>(patterns.getContext()); // CallConversion is the only pattern registered.
+
+    if (mlir::failed(
+            mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(ctx), // Emitted with an unknown location, not tied to a specific op.
+                      "error in CUFFunctionRewrite op conversion\n");
+      signalPassFailure(); // Mark the pass as failed when the greedy driver reports failure.
+    }
+  }
+};
+
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-function-rewrite.mlir b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
new file mode 100644
index 0000000000000..da1d601a2eb8b
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
@@ -0,0 +1,44 @@
+// RUN: fir-opt --split-input-file --cuf-function-rewrite %s | FileCheck %s
+
+gpu.module @cuda_device_mod {
+  func.func @_QMmtestsPdo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+    %c2_i32 = arith.constant 2 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %0 = fir.dummy_scope : !fir.dscope
+    %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+    %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+    %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+    fir.if %14 {
+      fir.store %c1_i32 to %5 : !fir.ref<i32>
+    } else {
+      fir.store %c2_i32 to %5 : !fir.ref<i32>
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: gpu.module @cuda_device_mod
+// CHECK: func.func @_QMmtestsPdo2
+// CHECK: fir.if %true
+
+// -----
+
+func.func @_QMmtestsPdo3(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+  %c2_i32 = arith.constant 2 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.dummy_scope : !fir.dscope
+  %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+  %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+  %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+  fir.if %14 {
+    fir.store %c1_i32 to %5 : !fir.ref<i32>
+  } else {
+    fir.store %c2_i32 to %5 : !fir.ref<i32>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QMmtestsPdo3
+// CHECK: fir.if %false

@clementval clementval merged commit a19b464 into llvm:main Jan 6, 2026
11 of 12 checks passed
@clementval clementval deleted the cuf_function_rewrite branch January 6, 2026 21:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants