forked from NVIDIA/cccl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnvrtcc_common.h
114 lines (96 loc) · 4.04 KB
/
nvrtcc_common.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <fstream>
#include <iostream>
#include <regex>
#include <string>
#include <vector>
#define NVRTC_SAFE_CALL(x) \
do { \
nvrtcResult result = x; \
if (result != NVRTC_SUCCESS) { \
printf("\nNVRTC ERROR: %s failed with error %s\n", \
#x, nvrtcGetErrorString(result)); \
exit(1); \
} \
} while(0)
#define CUDA_SAFE_CALL(x) \
do { \
CUresult result = x; \
if (result != CUDA_SUCCESS) { \
const char *msg; \
cuGetErrorName(result, &msg); \
printf("\nCUDA ERROR: %s failed with error %s\n", #x, msg); \
exit(1); \
} \
} while(0)
#define CUDA_API_CALL(x) \
do { \
cudaError_t err = x; \
if (err != cudaSuccess) \
{ \
printf("\nCUDA ERROR: %s: %s\n", \
cudaGetErrorName(err), cudaGetErrorString(err)); \
exit(1); \
} \
} while(0)
static void write_output_file(const char* data, size_t datasz, const std::string& file) {
std::ofstream ostr(file, std::ios::binary);
assert(!!ostr);
ostr.write(data, datasz);
ostr.close();
}
static std::string load_input_file(const std::string& file) {
if (file == "-") {
return std::string(
std::istream_iterator<char>{std::cin},
std::istream_iterator<char>{} );
} else {
std::ifstream istr(file);
assert(!!istr);
return std::string(
std::istreambuf_iterator<char>{istr},
std::istreambuf_iterator<char>{} );
}
}
static int parse_int_assignment(const std::string& input, std::string var, int def) {
auto lineBegin = input.find(var);
auto lineEnd = input.find('\n', lineBegin);
if (lineBegin == std::string::npos || lineEnd == std::string::npos)
return def;
std::string line(input.begin()+lineBegin, input.begin()+lineEnd);
std::regex varRegex("^" + var + ".*([0-9]+);?$");
std::smatch match;
std::regex_match(line, match, varRegex);
if (match.size()) {
return std::stoi(match[1].str(), nullptr);
}
return def;
}
struct RunConfig {
int threadCount = 1;
int shmemSize = 0;
};
static RunConfig parse_run_config(const std::string& input) {
return RunConfig {
parse_int_assignment(input, "cuda_thread_count", 1),
parse_int_assignment(input, "cuda_block_shmem_size", 0),
};
}
// Fake main for adapting kernels
static const char * program = R"program(
__host__ __device__ int fake_main(int argc, char ** argv);
#define main fake_main
// extern "C" to stop the name from being mangled
extern "C" __global__ void main_kernel() {
fake_main(0, NULL);
}
)program";