Skip to content

Commit ac8ab30

Browse files
committed
[CIR][CUDA] Skeleton of NVPTX target lowering info
1 parent a0091e3 commit ac8ab30

File tree

5 files changed

+110
-0
lines changed

5 files changed

+110
-0
lines changed

clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_clang_library(TargetLowering
1313
TargetInfo.cpp
1414
TargetLoweringInfo.cpp
1515
Targets/AArch64.cpp
16+
Targets/NVPTX.cpp
1617
Targets/SPIR.cpp
1718
Targets/X86.cpp
1819
Targets/LoweringPrepareAArch64CXXABI.cpp

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ createTargetLoweringInfo(LowerModule &LM) {
8181
}
8282
case llvm::Triple::spirv64:
8383
return createSPIRVTargetLoweringInfo(LM);
84+
case llvm::Triple::nvptx64:
85+
return createNVPTXTargetLoweringInfo(LM);
8486
default:
8587
cir_cconv_unreachable("ABI NYI");
8688
}

clang/lib/CIR/Dialect/Transforms/TargetLowering/TargetInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ createAArch64TargetLoweringInfo(LowerModule &CGM, cir::AArch64ABIKind AVXLevel);
3030
std::unique_ptr<TargetLoweringInfo>
3131
createSPIRVTargetLoweringInfo(LowerModule &CGM);
3232

33+
std::unique_ptr<TargetLoweringInfo>
34+
createNVPTXTargetLoweringInfo(LowerModule &CGM);
35+
3336
} // namespace cir
3437

3538
#endif // LLVM_CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_TARGETINFO_H
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- NVPTX.cpp - TargetInfo for NVPTX -----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "ABIInfoImpl.h"
10+
#include "LowerFunctionInfo.h"
11+
#include "LowerTypes.h"
12+
#include "TargetInfo.h"
13+
#include "TargetLoweringInfo.h"
14+
#include "clang/CIR/ABIArgInfo.h"
15+
#include "clang/CIR/MissingFeatures.h"
16+
#include "llvm/Support/ErrorHandling.h"
17+
18+
using ABIArgInfo = cir::ABIArgInfo;
19+
using MissingFeature = cir::MissingFeatures;
20+
21+
namespace cir {
22+
23+
//===----------------------------------------------------------------------===//
24+
// NVPTX ABI Implementation
25+
//===----------------------------------------------------------------------===//
26+
27+
namespace {
28+
29+
class NVPTXABIInfo : public ABIInfo {
30+
public:
31+
NVPTXABIInfo(LowerTypes &lt) : ABIInfo(lt) {}
32+
33+
private:
34+
void computeInfo(LowerFunctionInfo &fi) const override {
35+
llvm_unreachable("NYI");
36+
}
37+
};
38+
39+
class NVPTXTargetLoweringInfo : public TargetLoweringInfo {
40+
public:
41+
NVPTXTargetLoweringInfo(LowerTypes &lt)
42+
: TargetLoweringInfo(std::make_unique<NVPTXABIInfo>(lt)) {}
43+
44+
unsigned getTargetAddrSpaceFromCIRAddrSpace(
45+
cir::AddressSpaceAttr addressSpaceAttr) const override {
46+
using Kind = cir::AddressSpaceAttr::Kind;
47+
switch (addressSpaceAttr.getValue()) {
48+
case Kind::offload_private:
49+
return 0;
50+
case Kind::offload_local:
51+
return 3;
52+
case Kind::offload_global:
53+
return 1;
54+
case Kind::offload_constant:
55+
return 2;
56+
case Kind::offload_generic:
57+
return 4;
58+
default:
59+
cir_cconv_unreachable("Unknown CIR address space for this target");
60+
}
61+
}
62+
};
63+
64+
} // namespace
65+
66+
std::unique_ptr<TargetLoweringInfo>
67+
createNVPTXTargetLoweringInfo(LowerModule &lowerModule) {
68+
return std::make_unique<NVPTXTargetLoweringInfo>(lowerModule.getTypes());
69+
}
70+
71+
} // namespace cir

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,34 @@
1010
// RUN: %s -o %t.cir
1111
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1212

13+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
14+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
15+
// RUN: %s -o %t.cir
16+
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.cir %s
17+
18+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
19+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
20+
// RUN: %s -o %t.cir
21+
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.cir %s
22+
1323
// Attribute for global_fn
1424
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
1525

1626
__host__ void host_fn(int *a, int *b, int *c) {}
1727
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
1828
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_
29+
// LLVM-HOST: void @_Z7host_fnPiS_S_
30+
// LLVM-DEVICE-NOT: void @_Z7host_fnPiS_S_
1931

2032
__device__ void device_fn(int* a, double b, float c) {}
2133
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
2234
// CIR-DEVICE: cir.func @_Z9device_fnPidf
35+
// LLVM-HOST-NOT: void @_Z9device_fnPidf
36+
// LLVM-DEVICE: void @_Z9device_fnPidf
2337

2438
__global__ void global_fn(int a) {}
2539
// CIR-DEVICE: @_Z9global_fni
40+
// LLVM-DEVICE: @_Z9global_fni
2641

2742
// Check for device stub emission.
2843

@@ -32,10 +47,16 @@ __global__ void global_fn(int a) {}
3247
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
3348
// CIR-HOST: cir.call @cudaLaunchKernel
3449

50+
// LLVM-HOST: void @_Z24__device_stub__global_fni
51+
// LLVM-HOST: alloca [1 x ptr], i64 1, align 16
52+
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
53+
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni, %struct.dim3 %{{[0-9]+}}, %struct.dim3 %{{[0-9]+}}, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, ptr %{{[0-9]+}})
54+
3555
int main() {
3656
global_fn<<<1, 1>>>(1);
3757
}
3858
// CIR-DEVICE-NOT: cir.func @main()
59+
// LLVM-DEVICE-NOT: i32 @main()
3960

4061
// CIR-HOST: cir.func @main()
4162
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
@@ -46,3 +67,15 @@ int main() {
4667
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
4768
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
4869
// CIR-HOST: }
70+
71+
// LLVM-HOST: i32 @main()
72+
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
73+
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
74+
// LLVM-HOST: [[PushLLVM:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
75+
// LLVM-HOST: [[ConfigOKLLVM:%[0-9]+]] = icmp ne i32 [[PushLLVM]], 0
76+
// LLVM-HOST: br i1 [[ConfigOKLLVM]], label %[[Ifso:[0-9]+]], label %[[Ifnot:[0-9]+]]
77+
// LLVM-HOST: [[Ifso]]:
78+
// LLVM-HOST: call void @_Z24__device_stub__global_fni(i32 1)
79+
// LLVM-HOST: br label %[[Ifnot]]
80+
// LLVM-HOST: [[Ifnot]]:
81+
// LLVM-HOST: ret i32

0 commit comments

Comments
 (0)