Skip to content

Commit 9ead91f

Browse files
Refactor CUB util_device
1 parent 20cd6ce commit 9ead91f

File tree

1 file changed

+13
-31
lines changed

1 file changed

+13
-31
lines changed

cub/cub/util_device.cuh

+13-31
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,10 @@ CUB_RUNTIME_FUNCTION inline int DeviceCountUncached()
155155

156156
/**
157157
* \brief Cache for an arbitrary value produced by a nullary function.
158+
* deprecated [Since 2.6.0]
158159
*/
159160
template <typename T, T (*Function)()>
160-
struct ValueCache
161+
struct CUB_DEPRECATED ValueCache
161162
{
162163
T const value;
163164

@@ -170,13 +171,11 @@ struct ValueCache
170171
{}
171172
};
172173

173-
// Host code, only safely usable in C++11 or newer, where thread-safe
174-
// initialization of static locals is guaranteed. This is a separate function
175-
// to avoid defining a local static in a host/device function.
174+
// Host code. This is a separate function to avoid defining a local static in a host/device function.
176175
_CCCL_HOST inline int DeviceCountCachedValue()
177176
{
178-
static ValueCache<int, DeviceCountUncached> cache;
179-
return cache.value;
177+
static int count = DeviceCountUncached();
178+
return count;
180179
}
181180

182181
/**
@@ -211,7 +210,7 @@ struct PerDeviceAttributeCache
211210
// Each entry starts in the `DeviceEntryEmpty` state, then proceeds to the
212211
// `DeviceEntryInitializing` state, and then proceeds to the
213212
// `DeviceEntryReady` state. These are the only state transitions allowed;
214-
// e.g. a linear sequence of transitions.
213+
// i.e. a linear sequence of transitions.
215214
enum DeviceEntryStatus
216215
{
217216
DeviceEntryEmpty = 0,
@@ -372,7 +371,6 @@ _CCCL_HOST inline cudaError_t PtxVersionUncached(int& ptx_version, int device)
372371
template <typename Tag>
373372
_CCCL_HOST inline PerDeviceAttributeCache& GetPerDeviceAttributeCache()
374373
{
375-
// C++11 guarantees that initialization of static locals is thread safe.
376374
static PerDeviceAttributeCache cache;
377375
return cache;
378376
}
@@ -392,8 +390,7 @@ struct SmVersionCacheTag
392390
_CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
393391
{
394392
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
395-
// If this call fails, then we get the error code back in the payload,
396-
// which we check with `CubDebug` below.
393+
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
397394
[=](int& pv) {
398395
return PtxVersionUncached(pv, device);
399396
},
@@ -417,23 +414,10 @@ _CCCL_HOST inline cudaError_t PtxVersion(int& ptx_version, int device)
417414
CUB_RUNTIME_FUNCTION inline cudaError_t PtxVersion(int& ptx_version)
418415
{
419416
cudaError_t result = cudaErrorUnknown;
420-
NV_IF_TARGET(
421-
NV_IS_HOST,
422-
(auto const device = CurrentDevice();
423-
auto const payload = GetPerDeviceAttributeCache<PtxVersionCacheTag>()(
424-
// If this call fails, then we get the error code back in the payload,
425-
// which we check with `CubDebug` below.
426-
[=](int& pv) {
427-
return PtxVersionUncached(pv, device);
428-
},
429-
device);
430-
431-
if (!CubDebug(payload.error)) { ptx_version = payload.attribute; }
432-
433-
result = payload.error;),
434-
( // NV_IS_DEVICE:
435-
result = PtxVersionUncached(ptx_version);));
436-
417+
NV_IF_TARGET(NV_IS_HOST,
418+
(result = PtxVersion(ptx_version, CurrentDevice());),
419+
( // NV_IS_DEVICE:
420+
result = PtxVersionUncached(ptx_version);));
437421
return result;
438422
}
439423

@@ -477,8 +461,7 @@ CUB_RUNTIME_FUNCTION inline cudaError_t SmVersion(int& sm_version, int device =
477461
NV_IF_TARGET(
478462
NV_IS_HOST,
479463
(auto const payload = GetPerDeviceAttributeCache<SmVersionCacheTag>()(
480-
// If this call fails, then we get the error code back in
481-
// the payload, which we check with `CubDebug` below.
464+
// If this call fails, then we get the error code back in the payload, which we check with `CubDebug` below.
482465
[=](int& pv) {
483466
return SmVersionUncached(pv, device);
484467
},
@@ -565,9 +548,8 @@ CUB_RUNTIME_FUNCTION inline cudaError_t DebugSyncStream(cudaStream_t stream)
565548
CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
566549
{
567550
has_uva = false;
568-
cudaError_t error = cudaSuccess;
569551
int device = -1;
570-
error = CubDebug(cudaGetDevice(&device));
552+
cudaError_t error = CubDebug(cudaGetDevice(&device));
571553
if (cudaSuccess != error)
572554
{
573555
return error;

0 commit comments

Comments
 (0)