@@ -465,21 +465,34 @@ Waifu2x::eWaifu2xcuDNNError Waifu2x::can_use_cuDNN()
465
465
typedef cudnnStatus_t (CUDNNWINAPI* cudnnCreateType)(cudnnHandle_t *);
466
466
typedef cudnnStatus_t (CUDNNWINAPI* cudnnDestroyType)(cudnnHandle_t);
467
467
typedef size_t (CUDNNWINAPI* cudnnGetVersionType)();
468
+ typedef size_t (CUDNNWINAPI* cudnnGetCudartVersionType)();
468
469
469
470
cudnnCreateType cudnnCreateFunc = (cudnnCreateType)GetProcAddress (hModule, " cudnnCreate" );
470
471
cudnnDestroyType cudnnDestroyFunc = (cudnnDestroyType)GetProcAddress (hModule, " cudnnDestroy" );
471
472
cudnnGetVersionType cudnnGetVersionFunc = (cudnnGetVersionType)GetProcAddress (hModule, " cudnnGetVersion" );
472
- if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr )
473
+ cudnnGetCudartVersionType cudnnGetCudartVersionFunc = (cudnnGetCudartVersionType)GetProcAddress (hModule, " cudnnGetCudartVersion" );
474
+ if (cudnnCreateFunc != nullptr && cudnnDestroyFunc != nullptr && cudnnGetVersionFunc != nullptr && cudnnGetCudartVersionFunc != nullptr )
473
475
{
474
476
if (cudnnGetVersionFunc () >= CUDNN_REQUIRE_VERION)
475
477
{
476
- cudnnHandle_t h ;
477
- if (cudnnCreateFunc (&h ) == CUDNN_STATUS_SUCCESS )
478
+ int runtimeVersion ;
479
+ if (cudaRuntimeGetVersion (&runtimeVersion ) == cudaSuccess )
478
480
{
479
- if (cudnnDestroyFunc (h) == CUDNN_STATUS_SUCCESS)
480
- cuDNNFlag = eWaifu2xcuDNNError_OK;
481
+ if (cudnnGetCudartVersionFunc () >= runtimeVersion)
482
+ {
483
+ cudnnHandle_t h;
484
+ if (cudnnCreateFunc (&h) == CUDNN_STATUS_SUCCESS)
485
+ {
486
+ if (cudnnDestroyFunc (h) == CUDNN_STATUS_SUCCESS)
487
+ cuDNNFlag = eWaifu2xcuDNNError_OK;
488
+ else
489
+ cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
490
+ }
491
+ else
492
+ cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
493
+ }
481
494
else
482
- cuDNNFlag = eWaifu2xcuDNNError_CannotCreate ;
495
+ cuDNNFlag = eWaifu2xcuDNNError_OldCudaVersion ;
483
496
}
484
497
else
485
498
cuDNNFlag = eWaifu2xcuDNNError_CannotCreate;
0 commit comments