diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp index 424a8fd9d959b..352f8abde6093 100644 --- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp @@ -29,6 +29,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -49,9 +50,9 @@ namespace { static bool inDeviceContext(mlir::Operation *op) { if (op->getParentOfType()) return true; - if (auto funcOp = op->getParentOfType()) + if (op->getParentOfType()) return true; - if (auto funcOp = op->getParentOfType()) + if (auto funcOp = op->getParentOfType()) return true; if (auto funcOp = op->getParentOfType()) { if (auto cudaProcAttr = @@ -128,6 +129,9 @@ struct DeclareOpConversion : public mlir::OpRewritePattern { if (op.getResult().getUsers().empty()) return success(); if (auto addrOfOp = op.getMemref().getDefiningOp()) { + if (inDeviceContext(addrOfOp)) { + return failure(); + } if (auto global = symTab.lookup( addrOfOp.getSymbol().getRootReference().getValue())) { if (cuf::isRegisteredDeviceGlobal(global)) { diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir index 6f7816c9163cb..ae88af3d3c16c 100644 --- a/flang/test/Fir/CUDA/cuda-global-addr.mlir +++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir @@ -94,6 +94,33 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} { // ----- +// Check that we do not introduce call to _FortranACUFGetDeviceAddress when the +// address_of is inside an acc.parallel region (OffloadRegionOpInterface). + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry, dense<64> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry, dense<32> : vector<4xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<2xi64>>, #dlti.dl_entry : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} { +fir.global @_QMmod1Eadev_acc {data_attr = #cuf.cuda} : !fir.array<10xi32> { + %0 = fir.zero_bits !fir.array<10xi32> + fir.has_value %0 : !fir.array<10xi32> +} +func.func @_QQmain_acc() attributes {fir.bindc_name = "test_acc"} { + acc.parallel { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %3 = fir.address_of(@_QMmod1Eadev_acc) : !fir.ref> + %4 = fir.declare %3(%1) {data_attr = #cuf.cuda, uniq_name = "_QMmod1Eadev_acc"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + acc.yield + } + return +} + +// CHECK-LABEL: func.func @_QQmain_acc() +// CHECK: acc.parallel +// CHECK-NOT: fir.call {{.*}}GetDeviceAddress + +} + +// ----- + // Check that we do not introduce call to _FortranACUFGetDeviceAddress when the // value has no user.