Skip to content

Commit ae89c1e

Browse files
authored
[OpenCL] Don't initialize OpenCL runtime on host (#15745)
* [OpenCL] Don't initialize OpenCL runtime on host. After the OpenCL wrapper was added, it became possible to build TVM with OpenCL support even on a host that doesn't have the OpenCL libraries. However, if you try to compile an OpenCL module for a remote device on such a host machine, you will see an error saying that the OpenCL library cannot be opened. To avoid this problem, OpenCL functions must be called only at runtime. Therefore, the call that initializes the OpenCL workspace was removed from `OpenCLModuleNode::Init`, and a new function, `IsProgramCreated`, was added. The new function is needed to prepare the vectors of OpenCL programs associated with OpenCL devices. Previously this was done during OpenCLModule initialization; now these vectors are created only at runtime, after the list of available OpenCL devices has been obtained. * Call workspace init function before all OpenCL API calls
1 parent ab3511a commit ae89c1e

File tree

4 files changed

+24
-9
lines changed

4 files changed

+24
-9
lines changed

src/runtime/opencl/opencl_common.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,7 @@ struct BufferDescriptor;
220220
class OpenCLWorkspace : public DeviceAPI {
221221
public:
222222
// type key
223-
std::string type_key;
223+
std::string type_key{"opencl"};
224224
// available platforms
225225
std::vector<cl_platform_id> platform_ids;
226226
// map platform to its context
@@ -253,7 +253,7 @@ class OpenCLWorkspace : public DeviceAPI {
253253
// Initialize the device.
254254
void Init(const std::string& type_key, const std::string& device_type,
255255
const std::string& platform_name = "");
256-
virtual void Init() { Init("opencl", "gpu"); }
256+
virtual void Init() { Init(this->type_key, "gpu"); }
257257
// Check whether the context is OpenCL or not.
258258
virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; }
259259
// get the queue of the device
@@ -465,6 +465,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase {
465465
: OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {}
466466

467467
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
468+
// Return true if OpenCL program for the requested function and device was created
469+
bool IsProgramCreated(const std::string& func_name, int device_id);
468470
void SaveToFile(const String& file_name, const String& format) final;
469471
void SaveToBinary(dmlc::Stream* stream) final;
470472
void SetPreCompiledPrograms(const std::string& bytes);

src/runtime/opencl/opencl_device_api.cc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,7 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
111111
}
112112

113113
cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) {
114+
this->Init();
114115
ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError();
115116
return devices[device_id];
116117
}
@@ -210,6 +211,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
210211

211212
void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device dev, size_t size) {
212213
#if defined(OPENCL_ENABLE_HOST_PTR)
214+
this->Init();
213215
cl_int err_code;
214216
desc->host_ptr = reinterpret_cast<cl_uchar*>(
215217
clEnqueueMapBuffer(this->GetQueue(dev), desc->buffer, CL_TRUE, CL_MAP_WRITE, 0,
@@ -300,6 +302,7 @@ void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) {
300302
}
301303

302304
void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
305+
this->Init();
303306
size_t nbytes = GetDataSize(*from);
304307
ICHECK_EQ(nbytes, GetDataSize(*to));
305308
ICHECK(IsContiguous(*from) && IsContiguous(*to))
@@ -379,6 +382,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand
379382
}
380383

381384
void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
385+
this->Init();
382386
ICHECK(stream == nullptr);
383387
OPENCL_CALL(clFinish(this->GetQueue(dev)));
384388
}

src/runtime/opencl/opencl_module.cc

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,6 @@ String OpenCLModuleNode::GetSource(const String& format) {
185185

186186
void OpenCLModuleNode::Init() {
187187
workspace_ = GetGlobalWorkspace();
188-
workspace_->Init();
189188
// initialize the kernel id, need to lock global table.
190189
std::lock_guard<std::mutex> lock(workspace_->mu);
191190
for (const auto& kv : fmap_) {
@@ -208,10 +207,17 @@ void OpenCLModuleNode::Init() {
208207
<< "delimiter was found.";
209208
ICHECK_EQ(fmap_.size(), parsed_kernels_.size())
210209
<< "The number of parsed kernel sources does not match the number of kernel functions";
210+
}
211+
212+
bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device_id) {
213+
auto size = programs_[func_name].size();
214+
if (size > 0 && programs_[func_name][device_id] != nullptr) return true;
215+
auto dev_size = GetGlobalWorkspace()->devices.size();
216+
ICHECK(device_id < static_cast<int>(dev_size))
217+
<< "Device id " << device_id << " is bigger than number of available devices";
211218
// zero initialize cl_program pointers for each device kernel
212-
for (auto& kv : parsed_kernels_) {
213-
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
214-
}
219+
if (size == 0) programs_[func_name].resize(dev_size, nullptr);
220+
return false;
215221
}
216222

217223
cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
@@ -220,7 +226,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
220226
int device_id = t->device.device_id;
221227
auto did = w->GetCLDeviceID(device_id);
222228
auto platform = w->device_to_platform[did];
223-
if (programs_[func_name][device_id] == nullptr) {
229+
if (!IsProgramCreated(func_name, device_id)) {
224230
// create program
225231
if (fmt_ == "cl") {
226232
const char* s = parsed_kernels_[func_name].c_str();
@@ -268,6 +274,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
268274
}
269275

270276
void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
277+
workspace_->Init();
271278
std::string data = bytes;
272279
dmlc::MemoryStringStream reader(&data);
273280
dmlc::Stream* strm = &reader;
@@ -280,7 +287,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
280287
std::vector<unsigned char> bin_vector;
281288
strm->Read(&name);
282289
strm->Read(&bin_vector);
283-
if (programs_[name][device_id] == nullptr) {
290+
if (!IsProgramCreated(name, device_id)) {
284291
cl_int err = 0;
285292
cl_int binaryStatus;
286293
size_t binarySize = bin_vector.size();
@@ -310,6 +317,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) {
310317
}
311318

312319
std::string OpenCLModuleNode::GetPreCompiledPrograms() {
320+
workspace_->Init();
313321
std::string data;
314322
dmlc::MemoryStringStream writer(&data);
315323
dmlc::Stream* strm = &writer;
@@ -319,7 +327,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() {
319327
cl::OpenCLThreadEntry* t = workspace_->GetThreadEntry();
320328
int device_id = t->device.device_id;
321329
t->kernel_table.resize(workspace_->num_registered_kernels);
322-
if (programs_[std::string(name)][device_id] == nullptr) {
330+
if (!IsProgramCreated(name, device_id)) {
323331
InstallKernel(workspace_, t, name, kid_map_[name]);
324332
}
325333
size_t size;

src/runtime/opencl/opencl_module.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ namespace runtime {
4242
* \param data The module data.
4343
* \param fmt The format of the data, can be "clbin", "cl"
4444
* \param fmap The map function information map of each function.
45+
* \param source Generated OpenCL kernels.
4546
*/
4647
Module OpenCLModuleCreate(std::string data, std::string fmt,
4748
std::unordered_map<std::string, FunctionInfo> fmap, std::string source);

0 commit comments

Comments
 (0)