Skip to content

Commit 50d5520

Browse files
committed
[SYCL] Fixes for multiple backends in the same program (intel#1252)
2 parents 5fc7ae1 + 63c0b40 commit 50d5520

File tree

7 files changed

+237
-71
lines changed

7 files changed

+237
-71
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ pi_result _pi_event::start() {
149149
}
150150

151151
isStarted_ = true;
152+
// let observers know that the event is "submitted"
153+
trigger_callback(get_execution_status());
152154
return result;
153155
}
154156

@@ -195,6 +197,22 @@ pi_result _pi_event::record() {
195197

196198
try {
197199
result = PI_CHECK_ERROR(cuEventRecord(evEnd_, cuStream));
200+
201+
result = cuda_piEventRetain(this);
202+
try {
203+
result = PI_CHECK_ERROR(cuLaunchHostFunc(
204+
cuStream,
205+
[](void *userData) {
206+
pi_event event = reinterpret_cast<pi_event>(userData);
207+
event->set_event_complete();
208+
cuda_piEventRelease(event);
209+
},
210+
this));
211+
} catch (...) {
212+
// If host function fails to enqueue we must release the event here
213+
result = cuda_piEventRelease(this);
214+
throw;
215+
}
198216
} catch (pi_result error) {
199217
result = error;
200218
}
@@ -215,6 +233,7 @@ pi_result _pi_event::wait() {
215233
if (is_native_event()) {
216234
try {
217235
retErr = PI_CHECK_ERROR(cuEventSynchronize(evEnd_));
236+
isCompleted_ = true;
218237
} catch (pi_result error) {
219238
retErr = error;
220239
}
@@ -226,30 +245,12 @@ pi_result _pi_event::wait() {
226245
retErr = PI_SUCCESS;
227246
}
228247

229-
return retErr;
230-
}
231-
232-
pi_event_status _pi_event::get_execution_status() const noexcept {
248+
auto is_success = retErr == PI_SUCCESS;
249+
auto status = is_success ? get_execution_status() : pi_int32(retErr);
233250

234-
if (!is_recorded()) {
235-
return PI_EVENT_SUBMITTED;
236-
}
237-
238-
if (is_native_event()) {
239-
// native event status
240-
241-
auto status = cuEventQuery(get());
242-
if (status == CUDA_ERROR_NOT_READY) {
243-
return PI_EVENT_RUNNING;
244-
} else if (status != CUDA_SUCCESS) {
245-
cl::sycl::detail::pi::die("Invalid CUDA event status");
246-
}
247-
return PI_EVENT_COMPLETE;
248-
} else {
249-
// user event status
251+
trigger_callback(status);
250252

251-
return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
252-
}
253+
return retErr;
253254
}
254255

255256
// iterates over the event wait list, returns correct pi_result error codes.
@@ -2530,24 +2531,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
25302531

25312532
switch (param_name) {
25322533
case PI_EVENT_INFO_COMMAND_QUEUE:
2533-
return getInfo<pi_queue>(param_value_size, param_value,
2534-
param_value_size_ret, event->get_queue());
2534+
return getInfo(param_value_size, param_value, param_value_size_ret,
2535+
event->get_queue());
25352536
case PI_EVENT_INFO_COMMAND_TYPE:
2536-
return getInfo<pi_command_type>(param_value_size, param_value,
2537-
param_value_size_ret,
2538-
event->get_command_type());
2537+
return getInfo(param_value_size, param_value, param_value_size_ret,
2538+
event->get_command_type());
25392539
case PI_EVENT_INFO_REFERENCE_COUNT:
2540-
return getInfo<pi_uint32>(param_value_size, param_value,
2541-
param_value_size_ret,
2542-
event->get_reference_count());
2540+
return getInfo(param_value_size, param_value, param_value_size_ret,
2541+
event->get_reference_count());
25432542
case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2544-
return getInfo<pi_event_status>(param_value_size, param_value,
2545-
param_value_size_ret,
2546-
event->get_execution_status());
2543+
return getInfo(param_value_size, param_value, param_value_size_ret,
2544+
static_cast<pi_event_status>(event->get_execution_status()));
25472545
}
25482546
case PI_EVENT_INFO_CONTEXT:
2549-
return getInfo<pi_context>(param_value_size, param_value,
2550-
param_value_size_ret, event->get_context());
2547+
return getInfo(param_value_size, param_value, param_value_size_ret,
2548+
event->get_context());
25512549
default:
25522550
PI_HANDLE_UNKNOWN_PARAM_NAME(param_name);
25532551
}
@@ -2582,13 +2580,21 @@ pi_result cuda_piEventGetProfilingInfo(
25822580
return {};
25832581
}
25842582

2585-
pi_result cuda_piEventSetCallback(
2586-
pi_event event, pi_int32 command_exec_callback_type,
2587-
void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2588-
void *user_data),
2589-
void *user_data) {
2590-
cl::sycl::detail::pi::die("cuda_piEventSetCallback not implemented");
2591-
return {};
2583+
pi_result cuda_piEventSetCallback(pi_event event,
2584+
pi_int32 command_exec_callback_type,
2585+
pfn_notify notify, void *user_data) {
2586+
2587+
assert(event);
2588+
assert(notify);
2589+
assert(command_exec_callback_type == PI_EVENT_SUBMITTED ||
2590+
command_exec_callback_type == PI_EVENT_RUNNING ||
2591+
command_exec_callback_type == PI_EVENT_COMPLETE);
2592+
event_callback callback(pi_event_status(command_exec_callback_type), notify,
2593+
user_data);
2594+
2595+
event->set_event_callback(callback);
2596+
2597+
return PI_SUCCESS;
25922598
}
25932599

25942600
pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
@@ -2601,7 +2607,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
26012607
}
26022608

26032609
if (execution_status == PI_EVENT_COMPLETE) {
2604-
return event->set_user_event_complete();
2610+
return event->set_event_complete();
26052611
} else if (execution_status < 0) {
26062612
// TODO: A negative integer value causes all enqueued commands that wait
26072613
// on this user event to be terminated.

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,39 @@ struct _pi_queue {
235235
pi_uint32 get_reference_count() const noexcept { return refCount_; }
236236
};
237237

238+
typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus,
239+
void *userData);
240+
241+
class event_callback {
242+
public:
243+
void trigger_callback(pi_event event, pi_int32 currentEventStatus) const {
244+
245+
auto validParameters = callback_ && event;
246+
247+
// As a pi_event_status value approaches 0, it gets closer to completion.
248+
// If the calling pi_event's status is less than or equal to the event
249+
// status the user is interested in, invoke the callback anyway. The event
250+
// will have passed through that state anyway.
251+
auto validStatus = currentEventStatus <= observedEventStatus_;
252+
253+
if (validParameters && validStatus) {
254+
255+
callback_(event, currentEventStatus, userData_);
256+
}
257+
}
258+
259+
event_callback(pi_event_status status, pfn_notify callback, void *userData)
260+
: observedEventStatus_{status}, callback_{callback}, userData_{userData} {
261+
}
262+
263+
pi_event_status get_status() const noexcept { return observedEventStatus_; }
264+
265+
private:
266+
pi_event_status observedEventStatus_;
267+
pfn_notify callback_;
268+
void *userData_;
269+
};
270+
238271
class _pi_event {
239272
public:
240273
using native_type = CUevent;
@@ -247,18 +280,39 @@ class _pi_event {
247280

248281
native_type get() const noexcept { return evEnd_; };
249282

250-
pi_result set_user_event_complete() noexcept {
283+
pi_result set_event_complete() noexcept {
251284

252285
if (isCompleted_) {
253286
return PI_INVALID_OPERATION;
254287
}
255288

256-
if (is_user_event()) {
257-
isRecorded_ = true;
258-
isCompleted_ = true;
259-
return PI_SUCCESS;
289+
isRecorded_ = true;
290+
isCompleted_ = true;
291+
292+
trigger_callback(get_execution_status());
293+
294+
return PI_SUCCESS;
295+
}
296+
297+
void trigger_callback(pi_int32 status) {
298+
299+
std::vector<event_callback> callbacks;
300+
301+
// Here we move all callbacks into local variable before we call them.
302+
// This is a defensive maneuver; if any of the callbacks attempt to
303+
// add additional callbacks, we will end up in a bad spot. Our mutex
304+
// will be locked twice and the vector will be modified as it is being
305+
// iterated over! By moving everything locally, we can call all of these
306+
// callbacks and let them modify the original vector without much worry.
307+
308+
{
309+
std::lock_guard<std::mutex> lock(mutex_);
310+
event_callbacks_.swap(callbacks);
311+
}
312+
313+
for (auto &event_callback : callbacks) {
314+
event_callback.trigger_callback(this, status);
260315
}
261-
return PI_INVALID_EVENT;
262316
}
263317

264318
pi_queue get_queue() const noexcept { return queue_; }
@@ -273,7 +327,27 @@ class _pi_event {
273327

274328
bool is_started() const noexcept { return isStarted_; }
275329

276-
pi_event_status get_execution_status() const noexcept;
330+
pi_int32 get_execution_status() const noexcept {
331+
332+
if (!is_recorded()) {
333+
return PI_EVENT_SUBMITTED;
334+
}
335+
336+
if (!is_completed()) {
337+
return PI_EVENT_RUNNING;
338+
}
339+
return PI_EVENT_COMPLETE;
340+
}
341+
342+
void set_event_callback(const event_callback &callback) {
343+
auto current_status = get_execution_status();
344+
if (current_status <= callback.get_status()) {
345+
callback.trigger_callback(this, current_status);
346+
} else {
347+
std::lock_guard<std::mutex> lock(mutex_);
348+
event_callbacks_.emplace_back(callback);
349+
}
350+
}
277351

278352
pi_context get_context() const noexcept { return context_; };
279353

@@ -343,6 +417,12 @@ class _pi_event {
343417
pi_context context_; // pi_context associated with the event. If this is a
344418
// native event, this will be the same context associated
345419
// with the queue_ member.
420+
421+
std::mutex mutex_; // Protect access to event_callbacks_. TODO: There might be
422+
// a lock-free data structure we can use here.
423+
std::vector<event_callback>
424+
event_callbacks_; // Callbacks that can be triggered when an event's state
425+
// changes.
346426
};
347427

348428
struct _pi_program {

sycl/source/detail/scheduler/commands.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,45 +161,50 @@ void EventCompletionClbk(RT::PiEvent, pi_int32, void *data) {
161161
EventImplPtr *Event = (reinterpret_cast<EventImplPtr *>(data));
162162
RT::PiEvent &EventHandle = (*Event)->getHandleRef();
163163
const detail::plugin &Plugin = (*Event)->getPlugin();
164-
Plugin.call<PiApiKind::piEventSetStatus>(EventHandle, CL_COMPLETE);
164+
Plugin.call<PiApiKind::piEventSetStatus>(EventHandle, PI_EVENT_COMPLETE);
165165
delete (Event);
166166
}
167167

168168
// Method prepares PI events from a list of sycl::events
169169
std::vector<EventImplPtr> Command::prepareEvents(ContextImplPtr Context) {
170170
std::vector<EventImplPtr> Result;
171171
std::vector<EventImplPtr> GlueEvents;
172-
for (EventImplPtr &Event : MDepsEvents) {
172+
for (EventImplPtr &DepEvent : MDepsEvents) {
173173
// Async work is not supported for host device.
174-
if (Event->is_host()) {
175-
Event->waitInternal();
174+
if (DepEvent->is_host()) {
175+
DepEvent->waitInternal();
176176
continue;
177177
}
178178
// The event handle can be null in case of, for example, alloca command,
179179
// which is currently synchronous, so don't generate an OpenCL event.
180-
if (Event->getHandleRef() == nullptr) {
180+
if (DepEvent->getHandleRef() == nullptr) {
181181
continue;
182182
}
183-
ContextImplPtr EventContext = Event->getContextImpl();
184-
const detail::plugin &Plugin = Event->getPlugin();
185-
// If contexts don't match - connect them using user event
186-
if (EventContext != Context && !Context->is_host()) {
183+
ContextImplPtr DepEventContext = DepEvent->getContextImpl();
187184

185+
// If contexts don't match - connect them using user event
186+
if (DepEventContext != Context && !Context->is_host()) {
188187
EventImplPtr GlueEvent(new detail::event_impl());
189188
GlueEvent->setContextImpl(Context);
189+
EventImplPtr *GlueEventCopy =
190+
new EventImplPtr(GlueEvent); // To increase the reference count by 1.
191+
190192
RT::PiEvent &GlueEventHandle = GlueEvent->getHandleRef();
193+
auto Plugin = Context->getPlugin();
194+
auto DepPlugin = DepEventContext->getPlugin();
195+
// Add an event on the current context that
196+
// is triggered when the DepEvent is complete
191197
Plugin.call<PiApiKind::piEventCreate>(Context->getHandleRef(),
192198
&GlueEventHandle);
193-
EventImplPtr *GlueEventCopy =
194-
new EventImplPtr(GlueEvent); // To increase the reference count by 1.
195-
Plugin.call<PiApiKind::piEventSetCallback>(
196-
Event->getHandleRef(), CL_COMPLETE, EventCompletionClbk,
199+
200+
DepPlugin.call<PiApiKind::piEventSetCallback>(
201+
DepEvent->getHandleRef(), PI_EVENT_COMPLETE, EventCompletionClbk,
197202
/*void *data=*/(GlueEventCopy));
198203
GlueEvents.push_back(GlueEvent);
199204
Result.push_back(std::move(GlueEvent));
200205
continue;
201206
}
202-
Result.push_back(Event);
207+
Result.push_back(DepEvent);
203208
}
204209
MDepsEvents.insert(MDepsEvents.end(), GlueEvents.begin(), GlueEvents.end());
205210
return Result;

sycl/test/basic_tests/buffer/buffer_dev_to_dev.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
// RUN: %GPU_RUN_PLACEHOLDER %t.out
55
// RUN: %ACC_RUN_PLACEHOLDER %t.out
66

7-
// TODO: pi_die: cuda_piEventSetCallback not implemented
8-
// XFAIL: cuda
9-
107
//==---------- buffer_dev_to_dev.cpp - SYCL buffer basic test --------------==//
118
//
129
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/DataMovement.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33
//
4-
// XFAIL: cuda
54
//==-------------------------- DataMovement.cpp ----------------------------==//
65
//
76
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/MultipleDevices.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33

4-
// TODO: pi_die: cuda_piEventSetCallback not implemented
5-
// XFAIL: cuda
6-
74
//===- MultipleDevices.cpp - Test checking multi-device execution --------===//
85
//
96
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

0 commit comments

Comments
 (0)