Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed cuda load when no cuda device available #225

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Whisper.net/Internals/Native/INativeCuda.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Runtime.InteropServices;

namespace Whisper.net.Internals.Native;
internal interface INativeCuda : IDisposable
{
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
public delegate int cudaGetDeviceCount(out int count);

cudaGetDeviceCount CudaGetDeviceCount { get; }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Runtime.InteropServices;

namespace Whisper.net.Internals.Native.Implementations.Cuda;

internal class DllImportNativeCuda_64_12 : INativeCuda
{
public const string LibraryName = "cudart64_12";

[DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int cudaGetDeviceCount(out int count);

public INativeCuda.cudaGetDeviceCount CudaGetDeviceCount => cudaGetDeviceCount;

public void Dispose()
{
}
}

internal class DllImportNativeLibcuda : INativeCuda
{
public const string LibraryName = "libcudart";

[DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int cudaGetDeviceCount(out int count);

public INativeCuda.cudaGetDeviceCount CudaGetDeviceCount => cudaGetDeviceCount;

public void Dispose()
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT
#if NET6_0_OR_GREATER
using System.Runtime.InteropServices;
using static Whisper.net.Internals.Native.INativeCuda;

namespace Whisper.net.Internals.Native.Implementations.Cuda;
internal class NativeLibraryCuda(IntPtr cudaHandle) : INativeCuda
{

public cudaGetDeviceCount CudaGetDeviceCount { get; } = Marshal.GetDelegateForFunctionPointer<cudaGetDeviceCount>(NativeLibrary.GetExport(cudaHandle, nameof(cudaGetDeviceCount)));

public void Dispose()
{
NativeLibrary.Free(cudaHandle);
}
}
#endif
47 changes: 47 additions & 0 deletions Whisper.net/LibraryLoader/CudaHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using System.Runtime.InteropServices;
using Whisper.net.Internals.Native;
using Whisper.net.Internals.Native.Implementations.Cuda;

namespace Whisper.net.LibraryLoader;
internal static class CudaHelper
{
public static bool IsCudaAvailable()
{
INativeCuda? nativeCuda = null;
int cudaDevices = 0;
try
{
#if NET6_0_OR_GREATER
var libName = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
? DllImportNativeCuda_64_12.LibraryName // Only 64-bit Windows is supported for now
: DllImportNativeLibcuda.LibraryName;

if (!NativeLibrary.TryLoad(libName, out var library))
{
return false;
}
nativeCuda = new NativeLibraryCuda(library);
nativeCuda.CudaGetDeviceCount(out cudaDevices);
#else
try
{
nativeCuda = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
? new DllImportNativeCuda_64_12()
: new DllImportNativeLibcuda();
nativeCuda.CudaGetDeviceCount(out cudaDevices);
}
catch
{
return false;
}
#endif
return cudaDevices > 0;
}
finally
{
nativeCuda?.Dispose();
}
}
}
36 changes: 32 additions & 4 deletions Whisper.net/LibraryLoader/NativeLibraryLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,43 @@ private static string GetLibraryPath(string platform, string libraryName, string
Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])
}.Where(it => !string.IsNullOrEmpty(it));

foreach (var library in RuntimeOptions.Instance.RuntimeLibraryOrder)
static bool IsRuntimeSupported(RuntimeLibrary runtime, string platform)
{

#if !NETSTANDARD
if (library == RuntimeLibrary.Cpu && (platform == "win" || platform == "linux") && !Avx.IsSupported && !Avx2.IsSupported)
// If AVX is not supported, we can't use CPU runtime on windows and linux (we should use noavx runtime instead).
if (runtime == RuntimeLibrary.Cpu && (platform == "win" || platform == "linux") && !Avx.IsSupported && !Avx2.IsSupported)
{
continue;
return false;
}
#endif
// If Cuda is not available, we can't use Cuda runtime (unless there is no other runtime available, where CUDA runtime can be used as a fallback to the CPU)
if (runtime == RuntimeLibrary.Cuda && !CudaHelper.IsCudaAvailable())
{
var cudaIndex = RuntimeOptions.Instance.RuntimeLibraryOrder.IndexOf(RuntimeLibrary.Cuda);

if (cudaIndex == RuntimeOptions.Instance.RuntimeLibraryOrder.Count - 1)
{
// We still can use Cuda as a fallback to the CPU if it's the last runtime in the list.

// This scenario can be used to not install 2 runtimes (CPU and Cuda) on the same host,
// + override the default RuntimeLibraryOrder to have only [ Cuda ].
return true;
}

return false;
}

return true;

}

foreach (var library in RuntimeOptions.Instance.RuntimeLibraryOrder)
{
if (!IsRuntimeSupported(library, platform))
{
continue;
}
foreach (var assemblySearchPath in assemblySearchPaths)
{

Expand Down Expand Up @@ -165,5 +194,4 @@ private static string GetLibraryPath(string platform, string libraryName, string

#endif
}

}
6 changes: 5 additions & 1 deletion examples/NvidiaCuda/NvidiaCuda.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Whisper.net" Version="1.7.0" />
<!--<PackageReference Include="Whisper.net" Version="1.7.0" />-->
<PackageReference Include="Whisper.net.Runtime.Cuda" Version="1.7.0" />
</ItemGroup>

Expand All @@ -16,5 +16,9 @@
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Whisper.net\Whisper.net.csproj" />
</ItemGroup>

</Project>