@@ -149,6 +149,8 @@ pi_result _pi_event::start() {
149149 }
150150
151151 isStarted_ = true ;
152+ // let observers know that the event is "submitted"
153+ trigger_callback (get_execution_status ());
152154 return result;
153155}
154156
@@ -195,6 +197,22 @@ pi_result _pi_event::record() {
195197
196198 try {
197199 result = PI_CHECK_ERROR (cuEventRecord (evEnd_, cuStream));
200+
201+ result = cuda_piEventRetain (this );
202+ try {
203+ result = PI_CHECK_ERROR (cuLaunchHostFunc (
204+ cuStream,
205+ [](void *userData) {
206+ pi_event event = reinterpret_cast <pi_event>(userData);
207+ event->set_event_complete ();
208+ cuda_piEventRelease (event);
209+ },
210+ this ));
211+ } catch (...) {
212+ // If host function fails to enqueue we must release the event here
213+ result = cuda_piEventRelease (this );
214+ throw ;
215+ }
198216 } catch (pi_result error) {
199217 result = error;
200218 }
@@ -215,6 +233,7 @@ pi_result _pi_event::wait() {
215233 if (is_native_event ()) {
216234 try {
217235 retErr = PI_CHECK_ERROR (cuEventSynchronize (evEnd_));
236+ isCompleted_ = true ;
218237 } catch (pi_result error) {
219238 retErr = error;
220239 }
@@ -226,30 +245,12 @@ pi_result _pi_event::wait() {
226245 retErr = PI_SUCCESS;
227246 }
228247
229- return retErr;
230- }
231-
232- pi_event_status _pi_event::get_execution_status () const noexcept {
248+ auto is_success = retErr == PI_SUCCESS;
249+ auto status = is_success ? get_execution_status () : pi_int32 (retErr);
233250
234- if (!is_recorded ()) {
235- return PI_EVENT_SUBMITTED;
236- }
237-
238- if (is_native_event ()) {
239- // native event status
240-
241- auto status = cuEventQuery (get ());
242- if (status == CUDA_ERROR_NOT_READY) {
243- return PI_EVENT_RUNNING;
244- } else if (status != CUDA_SUCCESS) {
245- cl::sycl::detail::pi::die (" Invalid CUDA event status" );
246- }
247- return PI_EVENT_COMPLETE;
248- } else {
249- // user event status
251+ trigger_callback (status);
250252
251- return is_completed () ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
252- }
253+ return retErr;
253254}
254255
255256// iterates over the event wait list, returns correct pi_result error codes.
@@ -2530,24 +2531,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
25302531
25312532 switch (param_name) {
25322533 case PI_EVENT_INFO_COMMAND_QUEUE:
2533- return getInfo<pi_queue> (param_value_size, param_value,
2534- param_value_size_ret, event->get_queue ());
2534+ return getInfo (param_value_size, param_value, param_value_size_ret ,
2535+ event->get_queue ());
25352536 case PI_EVENT_INFO_COMMAND_TYPE:
2536- return getInfo<pi_command_type>(param_value_size, param_value,
2537- param_value_size_ret,
2538- event->get_command_type ());
2537+ return getInfo (param_value_size, param_value, param_value_size_ret,
2538+ event->get_command_type ());
25392539 case PI_EVENT_INFO_REFERENCE_COUNT:
2540- return getInfo<pi_uint32>(param_value_size, param_value,
2541- param_value_size_ret,
2542- event->get_reference_count ());
2540+ return getInfo (param_value_size, param_value, param_value_size_ret,
2541+ event->get_reference_count ());
25432542 case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2544- return getInfo<pi_event_status>(param_value_size, param_value,
2545- param_value_size_ret,
2546- event->get_execution_status ());
2543+ return getInfo (param_value_size, param_value, param_value_size_ret,
2544+ static_cast <pi_event_status>(event->get_execution_status ()));
25472545 }
25482546 case PI_EVENT_INFO_CONTEXT:
2549- return getInfo<pi_context> (param_value_size, param_value,
2550- param_value_size_ret, event->get_context ());
2547+ return getInfo (param_value_size, param_value, param_value_size_ret ,
2548+ event->get_context ());
25512549 default :
25522550 PI_HANDLE_UNKNOWN_PARAM_NAME (param_name);
25532551 }
@@ -2582,13 +2580,21 @@ pi_result cuda_piEventGetProfilingInfo(
25822580 return {};
25832581}
25842582
2585- pi_result cuda_piEventSetCallback (
2586- pi_event event, pi_int32 command_exec_callback_type,
2587- void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2588- void *user_data),
2589- void *user_data) {
2590- cl::sycl::detail::pi::die (" cuda_piEventSetCallback not implemented" );
2591- return {};
2583+ pi_result cuda_piEventSetCallback (pi_event event,
2584+ pi_int32 command_exec_callback_type,
2585+ pfn_notify notify, void *user_data) {
2586+
2587+ assert (event);
2588+ assert (notify);
2589+ assert (command_exec_callback_type == PI_EVENT_SUBMITTED ||
2590+ command_exec_callback_type == PI_EVENT_RUNNING ||
2591+ command_exec_callback_type == PI_EVENT_COMPLETE);
2592+ event_callback callback (pi_event_status (command_exec_callback_type), notify,
2593+ user_data);
2594+
2595+ event->set_event_callback (callback);
2596+
2597+ return PI_SUCCESS;
25922598}
25932599
25942600pi_result cuda_piEventSetStatus (pi_event event, pi_int32 execution_status) {
@@ -2601,7 +2607,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
26012607 }
26022608
26032609 if (execution_status == PI_EVENT_COMPLETE) {
2604- return event->set_user_event_complete ();
2610+ return event->set_event_complete ();
26052611 } else if (execution_status < 0 ) {
26062612 // TODO: A negative integer value causes all enqueued commands that wait
26072613 // on this user event to be terminated.
0 commit comments