diff --git a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaContainerImageTags.cs b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaContainerImageTags.cs index cd7ec34da..19b2ea17c 100644 --- a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaContainerImageTags.cs +++ b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaContainerImageTags.cs @@ -4,9 +4,9 @@ internal static class OllamaContainerImageTags { public const string Registry = "docker.io"; public const string Image = "ollama/ollama"; - public const string Tag = "0.6.8"; + public const string Tag = "0.7.1"; public const string OpenWebUIRegistry = "ghcr.io"; public const string OpenWebUIImage = "open-webui/open-webui"; - public const string OpenWebUITag = "0.5.20"; + public const string OpenWebUITag = "0.6.10"; } diff --git a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs index 0e2f2e5c4..65f9823a6 100644 --- a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs +++ b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs @@ -103,11 +103,23 @@ public static IResourceBuilder WithGPUSupport(this IResourceBuil return vendor switch { OllamaGpuVendor.Nvidia => builder.WithContainerRuntimeArgs("--gpus", "all"), - OllamaGpuVendor.AMD => builder.WithContainerRuntimeArgs("--device", "/dev/kfd"), + OllamaGpuVendor.AMD => builder.WithAMDGPUSupport(), _ => throw new ArgumentException("Invalid GPU vendor", nameof(vendor)) }; } + private static IResourceBuilder WithAMDGPUSupport(this IResourceBuilder builder) + { + if (builder.Resource.TryGetLastAnnotation(out var containerAnnotation)) + { + if (containerAnnotation.Tag?.EndsWith("rocm") == false) + { + containerAnnotation.Tag += "-rocm"; + } + } + return builder.WithContainerRuntimeArgs("--device", "/dev/kfd", "--device", "/dev/dri"); + } + private static OllamaResource AddServerResourceCommand( this OllamaResource ollamaResource, string name, diff --git a/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs b/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs index 236876fe5..50d3f5bf8 100644 --- a/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs +++ b/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs @@ -581,13 +581,40 @@ public void OllamaResourceCommandsUpdateState(string commandType) Assert.Equal(ResourceCommandState.Enabled, state); } - [Theory] - [InlineData(OllamaGpuVendor.Nvidia, "--gpus", "all")] - [InlineData(OllamaGpuVendor.AMD, "--device", "/dev/kfd")] - public async Task WithGPUSupport(OllamaGpuVendor vendor, string expectedArg, string expectedValue) + [Fact] + public async Task WithNvidiaGPUSupport() + { + var builder = DistributedApplication.CreateBuilder(); + _ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.Nvidia); + + using var app = builder.Build(); + + var appModel = app.Services.GetRequiredService(); + + var resource = Assert.Single(appModel.Resources.OfType()); + + Assert.True(resource.TryGetLastAnnotation(out ContainerRuntimeArgsCallbackAnnotation? argsAnnotations)); + ContainerRuntimeArgsCallbackContext context = new([]); + await argsAnnotations.Callback(context); + + Assert.Collection( + context.Args, + arg => + { + Assert.Equal("--gpus", arg); + }, + arg => + { + Assert.Equal("all", arg); + } + ); + } + + [Fact] + public async Task WithAMDGPUSupport() { var builder = DistributedApplication.CreateBuilder(); - _ = builder.AddOllama("ollama").WithGPUSupport(vendor); + _ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.AMD); using var app = builder.Build(); @@ -603,12 +630,24 @@ public async Task WithGPUSupport(OllamaGpuVendor vendor, string expectedArg, str context.Args, arg => { - Assert.Equal(expectedArg, arg); + Assert.Equal("--device", arg); + }, + arg => + { + Assert.Equal("/dev/kfd", arg); }, arg => { - Assert.Equal(expectedValue, arg); + Assert.Equal("--device", arg); + }, + arg => + { + Assert.Equal("/dev/dri", arg); } ); + + Assert.True(resource.TryGetLastAnnotation(out var imageAnnotation)); + Assert.NotNull(imageAnnotation); + Assert.EndsWith("-rocm", imageAnnotation.Tag, StringComparison.OrdinalIgnoreCase); } }