@@ -528,43 +528,57 @@ pi_result cuda_piPlatformsGet(pi_uint32 num_entries, pi_platform *platforms,
528528 pi_uint32 *num_platforms) {
529529
530530 try {
531- static constexpr pi_uint32 numPlatforms = 1 ;
531+ static std::once_flag initFlag;
532+ static pi_uint32 numPlatforms = 1 ;
533+ static _pi_platform platformId;
532534
533- if (num_platforms != nullptr ) {
534- *num_platforms = numPlatforms;
535+ if (num_entries == 0 and platforms != nullptr ) {
536+ return PI_INVALID_VALUE;
537+ }
538+ if (platforms == nullptr and num_platforms == nullptr ) {
539+ return PI_INVALID_VALUE;
535540 }
536541
537542 pi_result err = PI_SUCCESS;
538543
539- if (platforms != nullptr ) {
540-
541- assert (num_entries != 0 );
542-
543- static std::once_flag initFlag;
544- static _pi_platform platformId;
545- std::call_once (
546- initFlag,
547- [](pi_result &err) {
548- err = PI_CHECK_ERROR (cuInit (0 ));
549-
550- int numDevices = 0 ;
551- err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
544+ std::call_once (
545+ initFlag,
546+ [](pi_result &err) {
547+ if (cuInit (0 ) != CUDA_SUCCESS) {
548+ numPlatforms = 0 ;
549+ return ;
550+ }
551+ int numDevices = 0 ;
552+ err = PI_CHECK_ERROR (cuDeviceGetCount (&numDevices));
553+ if (numDevices == 0 ) {
554+ numPlatforms = 0 ;
555+ return ;
556+ }
557+ try {
552558 platformId.devices_ .reserve (numDevices);
553- try {
554- for (int i = 0 ; i < numDevices; ++i) {
555- CUdevice device;
556- err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
557- platformId.devices_ .emplace_back (
558- new _pi_device{device, &platformId});
559- }
560- } catch (...) {
561- // Clear and rethrow to allow retry
562- platformId.devices_ .clear ();
563- throw ;
559+ for (int i = 0 ; i < numDevices; ++i) {
560+ CUdevice device;
561+ err = PI_CHECK_ERROR (cuDeviceGet (&device, i));
562+ platformId.devices_ .emplace_back (
563+ new _pi_device{device, &platformId});
564564 }
565- },
566- err);
565+ } catch (const std::bad_alloc &) {
566+ // Signal out-of-memory situation
567+ platformId.devices_ .clear ();
568+ err = PI_OUT_OF_HOST_MEMORY;
569+ } catch (...) {
570+ // Clear and rethrow to allow retry
571+ platformId.devices_ .clear ();
572+ throw ;
573+ }
574+ },
575+ err);
567576
577+ if (num_platforms != nullptr ) {
578+ *num_platforms = numPlatforms;
579+ }
580+
581+ if (platforms != nullptr ) {
568582 *platforms = &platformId;
569583 }
570584
@@ -1110,12 +1124,30 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
11101124}
11111125
11121126/* Context APIs */
1113- pi_result cuda_piContextCreate (const cl_context_properties *properties,
1114- pi_uint32 num_devices, const pi_device *devices,
1115- void (*pfn_notify)(const char *errinfo,
1116- const void *private_info,
1117- size_t cb, void *user_data),
1118- void *user_data, pi_context *retcontext) {
1127+
1128+ // / Create a PI CUDA context.
1129+ // /
1130+ // / By default creates a scoped context and keeps the last active CUDA context
1131+ // / on top of the CUDA context stack.
1132+ // / With the PI_CONTEXT_PROPERTIES_CUDA_PRIMARY key/id and a value of PI_TRUE
1133+ // / creates a primary CUDA context and activates it on the CUDA context stack.
1134+ // /
1135+ // / @param[in] properties 0 terminated array of key/id-value combinations. Can
1136+ // / be nullptr. Only accepts property key/id PI_CONTEXT_PROPERTIES_CUDA_PRIMARY
1137+ // / with a pi_bool value.
1138+ // / @param[in] num_devices Number of devices to create the context for.
1139+ // / @param[in] devices Devices to create the context for.
1140+ // / @param[in] pfn_notify Callback, currently unused.
1141+ // / @param[in] user_data User data for callback.
1142+ // / @param[out] retcontext Set to created context on success.
1143+ // /
1144+ // / @return PI_SUCCESS on success, otherwise an error return code.
1145+ pi_result cuda_piContextCreate (const pi_context_properties *properties,
1146+ pi_uint32 num_devices, const pi_device *devices,
1147+ void (*pfn_notify)(const char *errinfo,
1148+ const void *private_info,
1149+ size_t cb, void *user_data),
1150+ void *user_data, pi_context *retcontext) {
11191151
11201152 assert (devices != nullptr );
11211153 // TODO: How to implement context callback?
@@ -1127,31 +1159,51 @@ pi_result cuda_piContextCreate(const cl_context_properties *properties,
11271159 assert (retcontext != nullptr );
11281160 pi_result errcode_ret = PI_SUCCESS;
11291161
1162+ // Parse properties.
1163+ bool property_cuda_primary = false ;
1164+ while (properties && (0 != *properties)) {
1165+ // Consume property ID.
1166+ pi_context_properties id = *properties;
1167+ ++properties;
1168+ // Consume property value.
1169+ pi_context_properties value = *properties;
1170+ ++properties;
1171+ switch (id) {
1172+ case PI_CONTEXT_PROPERTIES_CUDA_PRIMARY:
1173+ assert (value == PI_FALSE || value == PI_TRUE);
1174+ property_cuda_primary = static_cast <bool >(value);
1175+ break ;
1176+ default :
1177+ // Unknown property.
1178+ assert (!" Unknown piContextCreate property in property list" );
1179+ return PI_INVALID_VALUE;
1180+ }
1181+ }
1182+
11301183 std::unique_ptr<_pi_context> piContextPtr{nullptr };
11311184 try {
1132- if (properties && *properties != PI_CONTEXT_PROPERTIES_CUDA_PRIMARY) {
1133- throw pi_result (CL_INVALID_VALUE);
1134- } else if (!properties) {
1185+ if (property_cuda_primary) {
1186+ // Use the CUDA primary context and assume that we want to use it
1187+ // immediately as we want to forge context switches.
1188+ CUcontext Ctxt;
1189+ errcode_ret = PI_CHECK_ERROR (
1190+ cuDevicePrimaryCtxRetain (&Ctxt, devices[0 ]->cuDevice_ ));
1191+ piContextPtr = std::unique_ptr<_pi_context>(
1192+ new _pi_context{_pi_context::kind::primary, Ctxt, *devices});
1193+ errcode_ret = PI_CHECK_ERROR (cuCtxPushCurrent (Ctxt));
1194+ } else {
1195+ // Create a scoped context.
11351196 CUcontext newContext, current;
11361197 PI_CHECK_ERROR (cuCtxGetCurrent (¤t));
1137- errcode_ret = PI_CHECK_ERROR (cuCtxCreate (&newContext, CU_CTX_MAP_HOST,
1138- (* devices) ->cuDevice_ ));
1198+ errcode_ret = PI_CHECK_ERROR (
1199+ cuCtxCreate (&newContext, CU_CTX_MAP_HOST, devices[ 0 ] ->cuDevice_ ));
11391200 piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{
11401201 _pi_context::kind::user_defined, newContext, *devices});
1202+ // For scoped contexts keep the last active CUDA one on top of the stack
1203+ // as `cuCtxCreate` replaces it implicitly otherwise.
11411204 if (current != nullptr ) {
1142- // If there was an existing context on the thread we recover it
11431205 PI_CHECK_ERROR (cuCtxSetCurrent (current));
11441206 }
1145- } else if (properties
1146- && *properties == PI_CONTEXT_PROPERTIES_CUDA_PRIMARY) {
1147- CUcontext Ctxt;
1148- errcode_ret = PI_CHECK_ERROR (cuDevicePrimaryCtxRetain (
1149- &Ctxt, (*devices)->cuDevice_ ));
1150- piContextPtr = std::unique_ptr<_pi_context>(
1151- new _pi_context{_pi_context::kind::primary, Ctxt, *devices});
1152- errcode_ret = PI_CHECK_ERROR (cuCtxPushCurrent (Ctxt));
1153- } else {
1154- throw pi_result (CL_INVALID_VALUE);
11551207 }
11561208
11571209 *retcontext = piContextPtr.release ();
0 commit comments