diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index 33f723e54f80..7648cd1196a8 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -37,6 +37,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Test EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama.Tests", "test\Autogen.Ollama.Tests\Autogen.Ollama.Tests.csproj", "{C24FDE63-952D-4F8E-A807-AF31D43AD675}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -91,6 +95,14 @@ Global {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU + {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.Build.0 = Release|Any CPU + {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.Build.0 = Release|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -116,6 +128,8 @@ Global {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection diff --git a/dotnet/src/Autogen.Ollama/Agent/OllamaAgent.cs b/dotnet/src/Autogen.Ollama/Agent/OllamaAgent.cs new file mode 100644 index 000000000000..6f87e20e2332 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Agent/OllamaAgent.cs @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaAgent.cs + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; + +namespace Autogen.Ollama; + +/// +/// An agent that can interact with ollama models. +/// +public class OllamaAgent : IStreamingAgent +{ + private readonly HttpClient _httpClient; + public string Name { get; } + private readonly string _modelName; + private readonly string _systemMessage; + private readonly OllamaReplyOptions? _replyOptions; + + public OllamaAgent(HttpClient httpClient, string name, string modelName, + string systemMessage = "You are a helpful AI assistant", + OllamaReplyOptions? replyOptions = null) + { + Name = name; + _httpClient = httpClient; + _modelName = modelName; + _systemMessage = systemMessage; + _replyOptions = replyOptions; + } + public async Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default) + { + ChatRequest request = await BuildChatRequest(messages, options); + request.Stream = false; + using (HttpResponseMessage? response = await _httpClient + .SendAsync(BuildRequestMessage(request), HttpCompletionOption.ResponseContentRead, cancellation)) + { + response.EnsureSuccessStatusCode(); + Stream? streamResponse = await response.Content.ReadAsStreamAsync(); + ChatResponse chatResponse = await JsonSerializer.DeserializeAsync(streamResponse, cancellationToken: cancellation) + ?? throw new Exception("Failed to deserialize response"); + var output = new MessageEnvelope(chatResponse, from: Name); + return output; + } + } + public async IAsyncEnumerable GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + ChatRequest request = await BuildChatRequest(messages, options); + request.Stream = true; + HttpRequestMessage message = BuildRequestMessage(request); + using (HttpResponseMessage? response = await _httpClient.SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken)) + { + response.EnsureSuccessStatusCode(); + using Stream? stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); + using var reader = new StreamReader(stream); + + while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested) + { + string? line = await reader.ReadLineAsync(); + if (string.IsNullOrWhiteSpace(line)) continue; + + ChatResponseUpdate? update = JsonSerializer.Deserialize(line); + if (update != null) + { + yield return new MessageEnvelope(update, from: Name); + } + + if (update is { Done: false }) continue; + + ChatResponse? chatMessage = JsonSerializer.Deserialize(line); + if (chatMessage == null) continue; + yield return new MessageEnvelope(chatMessage, from: Name); + } + } + } + private async Task BuildChatRequest(IEnumerable messages, GenerateReplyOptions? options) + { + var request = new ChatRequest + { + Model = _modelName, + Messages = await BuildChatHistory(messages) + }; + + if (options is OllamaReplyOptions replyOptions) + { + BuildChatRequestOptions(replyOptions, request); + return request; + } + + if (_replyOptions != null) + { + BuildChatRequestOptions(_replyOptions, request); + return request; + } + return request; + } + private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request) + { + request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null; + request.Template = replyOptions.Template; + request.KeepAlive = replyOptions.KeepAlive; + + if (replyOptions.Temperature != null + || replyOptions.MaxToken != null + || replyOptions.StopSequence != null + || replyOptions.Seed != null + || replyOptions.MiroStat != null + || replyOptions.MiroStatEta != null + || replyOptions.MiroStatTau != null + || replyOptions.NumCtx != null + || replyOptions.NumGqa != null + || replyOptions.NumGpu != null + || replyOptions.NumThread != null + || replyOptions.RepeatLastN != null + || replyOptions.RepeatPenalty != null + || replyOptions.TopK != null + || replyOptions.TopP != null + || replyOptions.TfsZ != null) + { + request.Options = new ModelReplyOptions + { + Temperature = replyOptions.Temperature, + NumPredict = replyOptions.MaxToken, + Stop = replyOptions.StopSequence?[0], + Seed = replyOptions.Seed, + MiroStat = replyOptions.MiroStat, + MiroStatEta = replyOptions.MiroStatEta, + MiroStatTau = replyOptions.MiroStatTau, + NumCtx = replyOptions.NumCtx, + NumGqa = replyOptions.NumGqa, + NumGpu = replyOptions.NumGpu, + NumThread = replyOptions.NumThread, + RepeatLastN = replyOptions.RepeatLastN, + RepeatPenalty = replyOptions.RepeatPenalty, + TopK = replyOptions.TopK, + TopP = replyOptions.TopP, + TfsZ = replyOptions.TfsZ + }; + } + } + private async Task> BuildChatHistory(IEnumerable messages) + { + if (!messages.Any(m => m.IsSystemMessage())) + { + var systemMessage = new TextMessage(Role.System, _systemMessage, from: Name); + messages = new[] { systemMessage }.Concat(messages); + } + + var collection = new List(); + foreach (IMessage? message in messages) + { + Message item; + switch (message) + { + case TextMessage tm: + item = new Message { Role = tm.Role.ToString(), Value = tm.Content }; + break; + case ImageMessage im: + string base64Image = await ImageUrlToBase64(im.Url!); + item = new Message { Role = im.Role.ToString(), Images = [base64Image] }; + break; + case MultiModalMessage mm: + var textsGroupedByRole = mm.Content.OfType().GroupBy(tm => tm.Role) + .ToDictionary(g => g.Key, g => string.Join(Environment.NewLine, g.Select(tm => tm.Content))); + + string content = string.Join($"{Environment.NewLine}", textsGroupedByRole + .Select(g => $"{g.Key}{Environment.NewLine}:{g.Value}")); + + IEnumerable> imagesConversionTasks = mm.Content + .OfType() + .Select(async im => await ImageUrlToBase64(im.Url!)); + + string[]? imagesBase64 = await Task.WhenAll(imagesConversionTasks); + item = new Message { Role = mm.Role.ToString(), Value = content, Images = imagesBase64 }; + break; + default: + throw new NotSupportedException(); + } + + collection.Add(item); + } + + return collection; + } + private static HttpRequestMessage BuildRequestMessage(ChatRequest request) + { + string serialized = JsonSerializer.Serialize(request); + return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint) + { + Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType) + }; + } + private async Task ImageUrlToBase64(string imageUrl) + { + if (string.IsNullOrWhiteSpace(imageUrl)) + { + throw new ArgumentException("required parameter", nameof(imageUrl)); + } + byte[] imageBytes = await _httpClient.GetByteArrayAsync(imageUrl); + return imageBytes != null + ? Convert.ToBase64String(imageBytes) + : throw new InvalidOperationException("no image byte array"); + } +} diff --git a/dotnet/src/Autogen.Ollama/Autogen.Ollama.csproj b/dotnet/src/Autogen.Ollama/Autogen.Ollama.csproj new file mode 100644 index 000000000000..9a01f95ca8ea --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Autogen.Ollama.csproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + True + + + + + + + diff --git a/dotnet/src/Autogen.Ollama/DTOs/ChatRequest.cs b/dotnet/src/Autogen.Ollama/DTOs/ChatRequest.cs new file mode 100644 index 000000000000..a48fb42cfbfb --- /dev/null +++ b/dotnet/src/Autogen.Ollama/DTOs/ChatRequest.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatRequest.cs + +using System; +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +public class ChatRequest +{ + /// + /// (required) the model name + /// + [JsonPropertyName("model")] + public string Model { get; set; } = string.Empty; + + /// + /// the messages of the chat, this can be used to keep a chat memory + /// + [JsonPropertyName("messages")] + public IList Messages { get; set; } = Array.Empty(); + + /// + /// the format to return a response in. Currently, the only accepted value is json + /// + [JsonPropertyName("format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Format { get; set; } + + /// + /// additional model parameters listed in the documentation for the Modelfile such as temperature + /// + [JsonPropertyName("options")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public ModelReplyOptions? Options { get; set; } + /// + /// the prompt template to use (overrides what is defined in the Modelfile) + /// + [JsonPropertyName("template")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Template { get; set; } + /// + /// if false the response will be returned as a single response object, rather than a stream of objects + /// + [JsonPropertyName("stream")] + public bool Stream { get; set; } + /// + /// controls how long the model will stay loaded into memory following the request (default: 5m) + /// + [JsonPropertyName("keep_alive")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? KeepAlive { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/DTOs/ChatResponse.cs b/dotnet/src/Autogen.Ollama/DTOs/ChatResponse.cs new file mode 100644 index 000000000000..2de150f7235a --- /dev/null +++ b/dotnet/src/Autogen.Ollama/DTOs/ChatResponse.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatResponse.cs + +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +public class ChatResponse : ChatResponseUpdate +{ + /// + /// time spent generating the response + /// + [JsonPropertyName("total_duration")] + public long TotalDuration { get; set; } + + /// + /// time spent in nanoseconds loading the model + /// + [JsonPropertyName("load_duration")] + public long LoadDuration { get; set; } + + /// + /// number of tokens in the prompt + /// + [JsonPropertyName("prompt_eval_count")] + public int PromptEvalCount { get; set; } + + /// + /// time spent in nanoseconds evaluating the prompt + /// + [JsonPropertyName("prompt_eval_duration")] + public long PromptEvalDuration { get; set; } + + /// + /// number of tokens the response + /// + [JsonPropertyName("eval_count")] + public int EvalCount { get; set; } + + /// + /// time in nanoseconds spent generating the response + /// + [JsonPropertyName("eval_duration")] + public long EvalDuration { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/DTOs/ChatResponseUpdate.cs b/dotnet/src/Autogen.Ollama/DTOs/ChatResponseUpdate.cs new file mode 100644 index 000000000000..181dacfc34b5 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/DTOs/ChatResponseUpdate.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatResponseUpdate.cs + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +public class ChatResponseUpdate +{ + [JsonPropertyName("model")] + public string Model { get; set; } = string.Empty; + + [JsonPropertyName("created_at")] + public string CreatedAt { get; set; } = string.Empty; + + [JsonPropertyName("message")] + public Message? Message { get; set; } + + [JsonPropertyName("done")] + public bool Done { get; set; } +} + +public class Message +{ + /// + /// the role of the message, either system, user or assistant + /// + [JsonPropertyName("role")] + public string Role { get; set; } = string.Empty; + /// + /// the content of the message + /// + [JsonPropertyName("content")] + public string Value { get; set; } = string.Empty; + + /// + /// (optional): a list of images to include in the message (for multimodal models such as llava) + /// + [JsonPropertyName("images")] + public IList? Images { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/DTOs/ModelReplyOptions.cs b/dotnet/src/Autogen.Ollama/DTOs/ModelReplyOptions.cs new file mode 100644 index 000000000000..d7854b77b20e --- /dev/null +++ b/dotnet/src/Autogen.Ollama/DTOs/ModelReplyOptions.cs @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ModelReplyOptions.cs + +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +//https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values +public class ModelReplyOptions +{ + /// + /// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + /// + [JsonPropertyName("mirostat")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? MiroStat { get; set; } + + /// + /// Influences how quickly the algorithm responds to feedback from the generated text. + /// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) + /// + [JsonPropertyName("mirostat_eta")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? MiroStatEta { get; set; } + + /// + /// Controls the balance between coherence and diversity of the output. + /// A lower value will result in more focused and coherent text. (Default: 5.0) + /// + [JsonPropertyName("mirostat_tau")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? MiroStatTau { get; set; } + + /// + /// Sets the size of the context window used to generate the next token. (Default: 2048) + /// + [JsonPropertyName("num_ctx")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? NumCtx { get; set; } + + /// + /// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b + /// + [JsonPropertyName("num_gqa")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? NumGqa { get; set; } + + /// + /// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. + /// + [JsonPropertyName("num_gpu")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? NumGpu { get; set; } + + /// + /// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. + /// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). + /// + [JsonPropertyName("num_thread")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? NumThread { get; set; } + + /// + /// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) + /// + [JsonPropertyName("repeat_last_n")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? RepeatLastN { get; set; } + + /// + /// Sets how strongly to penalize repetitions. + /// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) + /// + [JsonPropertyName("repeat_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? RepeatPenalty { get; set; } + + /// + /// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) + /// + [JsonPropertyName("temperature")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? Temperature { get; set; } + + /// + /// Sets the random number seed to use for generation. + /// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) + /// + [JsonPropertyName("seed")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? Seed { get; set; } + + /// + /// Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. + /// Multiple stop patterns may be set by specifying multiple separate stop parameters in a modelfile. + /// + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Stop { get; set; } + + /// + /// Tail free sampling is used to reduce the impact of less probable tokens from the output. + /// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) + /// + [JsonPropertyName("tfs_z")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? TfsZ { get; set; } + + /// + /// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) + /// + [JsonPropertyName("num_predict")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? NumPredict { get; set; } + + /// + /// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) + /// + [JsonPropertyName("top_k")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? TopK { get; set; } + + /// + /// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) + /// + [JsonPropertyName("top_p")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? TopP { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/DTOs/OllamaReplyOptions.cs b/dotnet/src/Autogen.Ollama/DTOs/OllamaReplyOptions.cs new file mode 100644 index 000000000000..97bf57cb10cb --- /dev/null +++ b/dotnet/src/Autogen.Ollama/DTOs/OllamaReplyOptions.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaReplyOptions.cs + +using AutoGen.Core; + +namespace Autogen.Ollama; + +public enum FormatType +{ + None, + Json +} + +public class OllamaReplyOptions : GenerateReplyOptions +{ + /// + /// the format to return a response in. Currently, the only accepted value is json + /// + public FormatType Format { get; set; } = FormatType.None; + + /// + /// the prompt template to use (overrides what is defined in the Modelfile) + /// + public string? Template { get; set; } + + /// + /// The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) + /// + public new float? Temperature { get; set; } + + /// + /// controls how long the model will stay loaded into memory following the request (default: 5m) + /// + public string? KeepAlive { get; set; } + + /// + /// Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + /// + public int? MiroStat { get; set; } + + /// + /// Influences how quickly the algorithm responds to feedback from the generated text. + /// A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) + /// + public float? MiroStatEta { get; set; } + + /// + /// Controls the balance between coherence and diversity of the output. + /// A lower value will result in more focused and coherent text. (Default: 5.0) + /// + public float? MiroStatTau { get; set; } + + /// + /// Sets the size of the context window used to generate the next token. (Default: 2048) + /// + public int? NumCtx { get; set; } + + /// + /// The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b + /// + public int? NumGqa { get; set; } + + /// + /// The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. + /// + public int? NumGpu { get; set; } + + /// + /// Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. + /// It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). + /// + public int? NumThread { get; set; } + + /// + /// Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) + /// + public int? RepeatLastN { get; set; } + + /// + /// Sets how strongly to penalize repetitions. + /// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) + /// + public float? RepeatPenalty { get; set; } + + /// + /// Sets the random number seed to use for generation. + /// Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) + /// + public int? Seed { get; set; } + + /// + /// Tail free sampling is used to reduce the impact of less probable tokens from the output. + /// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) + /// + public float? TfsZ { get; set; } + + /// + /// Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) + /// + public new int? MaxToken { get; set; } + + /// + /// Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) + /// + public int? TopK { get; set; } + + /// + /// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) + /// + public int? TopP { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/Embeddings/ITextEmbeddingService.cs b/dotnet/src/Autogen.Ollama/Embeddings/ITextEmbeddingService.cs new file mode 100644 index 000000000000..f1ea1b8406c6 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Embeddings/ITextEmbeddingService.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITextEmbeddingService.cs + +using System.Threading; +using System.Threading.Tasks; + +namespace Autogen.Ollama; + +public interface ITextEmbeddingService +{ + public Task GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken); +} diff --git a/dotnet/src/Autogen.Ollama/Embeddings/OllamaTextEmbeddingService.cs b/dotnet/src/Autogen.Ollama/Embeddings/OllamaTextEmbeddingService.cs new file mode 100644 index 000000000000..db913377a5f5 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Embeddings/OllamaTextEmbeddingService.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaTextEmbeddingService.cs + +using System; +using System.IO; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Autogen.Ollama; + +public class OllamaTextEmbeddingService : ITextEmbeddingService +{ + private readonly HttpClient _client; + + public OllamaTextEmbeddingService(HttpClient client) + { + _client = client; + } + public async Task GenerateAsync(TextEmbeddingsRequest request, CancellationToken cancellationToken = default) + { + using (HttpResponseMessage? response = await _client + .SendAsync(BuildPostRequest(request), HttpCompletionOption.ResponseContentRead, cancellationToken)) + { + response.EnsureSuccessStatusCode(); + + Stream? streamResponse = await response.Content.ReadAsStreamAsync(); + TextEmbeddingsResponse output = await JsonSerializer + .DeserializeAsync(streamResponse, cancellationToken: cancellationToken) + ?? throw new Exception("Failed to deserialize response"); + return output; + } + } + private static HttpRequestMessage BuildPostRequest(TextEmbeddingsRequest request) + { + string serialized = JsonSerializer.Serialize(request); + return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.EmbeddingsEndpoint) + { + Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType) + }; + } +} diff --git a/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsRequest.cs b/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsRequest.cs new file mode 100644 index 000000000000..1577dc536433 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsRequest.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TextEmbeddingsRequest.cs + +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +public class TextEmbeddingsRequest +{ + /// + /// name of model to generate embeddings from + /// + [JsonPropertyName("model")] + public string Model { get; set; } = string.Empty; + /// + /// text to generate embeddings for + /// + [JsonPropertyName("prompt")] + public string Prompt { get; set; } = string.Empty; + /// + /// additional model parameters listed in the documentation for the Modelfile such as temperature + /// + [JsonPropertyName("options")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public ModelReplyOptions? Options { get; set; } + /// + /// controls how long the model will stay loaded into memory following the request (default: 5m) + /// + [JsonPropertyName("keep_alive")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? KeepAlive { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsResponse.cs b/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsResponse.cs new file mode 100644 index 000000000000..eb46359fdb2d --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Embeddings/TextEmbeddingsResponse.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TextEmbeddingsResponse.cs + +using System.Text.Json.Serialization; + +namespace Autogen.Ollama; + +public class TextEmbeddingsResponse +{ + [JsonPropertyName("embedding")] + public double[]? Embedding { get; set; } +} diff --git a/dotnet/src/Autogen.Ollama/Middlewares/OllamaMessageConnector.cs b/dotnet/src/Autogen.Ollama/Middlewares/OllamaMessageConnector.cs new file mode 100644 index 000000000000..6defedbe02a4 --- /dev/null +++ b/dotnet/src/Autogen.Ollama/Middlewares/OllamaMessageConnector.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core; + +namespace Autogen.Ollama; + +public class OllamaMessageConnector : IMiddleware, IStreamingMiddleware +{ + public string Name => nameof(OllamaMessageConnector); + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, + CancellationToken cancellationToken = default) + { + IEnumerable messages = context.Messages; + IMessage reply = await agent.GenerateReplyAsync(messages, context.Options, cancellationToken); + switch (reply) + { + case IMessage messageEnvelope: + Message? message = messageEnvelope.Content.Message; + return new TextMessage(Role.Assistant, message != null ? message.Value : "EMPTY_CONTENT", messageEnvelope.From); + default: + throw new NotSupportedException(); + } + } + + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (IStreamingMessage? update in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + { + switch (update) + { + case IMessage complete: + { + string? textContent = complete.Content.Message?.Value; + yield return new TextMessage(Role.Assistant, textContent!, complete.From); + break; + } + case IMessage updatedMessage: + { + string? textContent = updatedMessage.Content.Message?.Value; + yield return new TextMessageUpdate(Role.Assistant, textContent, updatedMessage.From); + break; + } + default: + throw new InvalidOperationException("Message type not supported."); + } + } + } +} diff --git a/dotnet/src/Autogen.Ollama/OllamaConsts.cs b/dotnet/src/Autogen.Ollama/OllamaConsts.cs new file mode 100644 index 000000000000..49e91ebc318a --- /dev/null +++ b/dotnet/src/Autogen.Ollama/OllamaConsts.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaConsts.cs + +namespace Autogen.Ollama; + +public class OllamaConsts +{ + public const string JsonFormatType = "json"; + public const string JsonMediaType = "application/json"; + public const string ChatCompletionEndpoint = "/api/chat"; + public const string EmbeddingsEndpoint = "/api/embeddings"; +} diff --git a/dotnet/test/Autogen.Ollama.Tests/Autogen.Ollama.Tests.csproj b/dotnet/test/Autogen.Ollama.Tests/Autogen.Ollama.Tests.csproj new file mode 100644 index 000000000000..a10ce496ae7e --- /dev/null +++ b/dotnet/test/Autogen.Ollama.Tests/Autogen.Ollama.Tests.csproj @@ -0,0 +1,33 @@ + + + + net8.0 + enable + enable + + false + true + True + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + diff --git a/dotnet/test/Autogen.Ollama.Tests/OllamaAgentTests.cs b/dotnet/test/Autogen.Ollama.Tests/OllamaAgentTests.cs new file mode 100644 index 000000000000..b22432fdad69 --- /dev/null +++ b/dotnet/test/Autogen.Ollama.Tests/OllamaAgentTests.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaAgentTests.cs + +using System.Text.Json; +using AutoGen.Core; +using AutoGen.Tests; +using FluentAssertions; + +namespace Autogen.Ollama.Tests; + +public class OllamaAgentTests +{ + + [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")] + public async Task GenerateReplyAsync_ReturnsValidMessage_WhenCalled() + { + string host = Environment.GetEnvironmentVariable("OLLAMA_HOST") + ?? throw new InvalidOperationException("OLLAMA_HOST is not set."); + string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME") + ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set."); + OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName); + + var messages = new IMessage[] { new TextMessage(Role.User, "Hello, how are you") }; + IMessage result = await ollamaAgent.GenerateReplyAsync(messages); + + result.Should().NotBeNull(); + result.Should().BeOfType>(); + result.From.Should().Be(ollamaAgent.Name); + } + + [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")] + public async Task GenerateReplyAsync_ReturnsValidJsonMessageContent_WhenCalled() + { + string host = Environment.GetEnvironmentVariable("OLLAMA_HOST") + ?? throw new InvalidOperationException("OLLAMA_HOST is not set."); + string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME") + ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set."); + OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName); + + var messages = new IMessage[] { new TextMessage(Role.User, "Hello, how are you") }; + IMessage result = await ollamaAgent.GenerateReplyAsync(messages, new OllamaReplyOptions + { + Format = FormatType.Json + }); + + result.Should().NotBeNull(); + result.Should().BeOfType>(); + result.From.Should().Be(ollamaAgent.Name); + + string jsonContent = ((MessageEnvelope)result).Content.Message!.Value; + bool isValidJson = IsValidJsonMessage(jsonContent); + isValidJson.Should().BeTrue(); + } + + [ApiKeyFact("OLLAMA_HOST", "OLLAMA_MODEL_NAME")] + public async Task GenerateStreamingReplyAsync_ReturnsValidMessages_WhenCalled() + { + string host = Environment.GetEnvironmentVariable("OLLAMA_HOST") + ?? throw new InvalidOperationException("OLLAMA_HOST is not set."); + string modelName = Environment.GetEnvironmentVariable("OLLAMA_MODEL_NAME") + ?? throw new InvalidOperationException("OLLAMA_MODEL_NAME is not set."); + OllamaAgent ollamaAgent = BuildOllamaAgent(host, modelName); + + var messages = new IMessage[] { new TextMessage(Role.User, "Hello how are you") }; + IStreamingMessage? finalReply = default; + await foreach (IStreamingMessage message in ollamaAgent.GenerateStreamingReplyAsync(messages)) + { + message.Should().NotBeNull(); + message.From.Should().Be(ollamaAgent.Name); + finalReply = message; + } + + finalReply.Should().BeOfType>(); + } + + private static bool IsValidJsonMessage(string input) + { + try + { + JsonDocument.Parse(input); + return true; + } + catch (JsonException) + { + return false; + } + catch (Exception ex) + { + Console.WriteLine("An unexpected exception occurred: " + ex.Message); + return false; + } + } + + private static OllamaAgent BuildOllamaAgent(string host, string modelName) + { + var httpClient = new HttpClient + { + BaseAddress = new Uri(host) + }; + return new OllamaAgent(httpClient, "TestAgent", modelName); + } +} diff --git a/dotnet/test/Autogen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs b/dotnet/test/Autogen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs new file mode 100644 index 000000000000..7f2d94e147db --- /dev/null +++ b/dotnet/test/Autogen.Ollama.Tests/OllamaTextEmbeddingServiceTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OllamaTextEmbeddingServiceTests.cs + +using AutoGen.Tests; +using FluentAssertions; + +namespace Autogen.Ollama.Tests; + +public class OllamaTextEmbeddingServiceTests +{ + [ApiKeyFact("OLLAMA_HOST", "OLLAMA_EMBEDDING_MODEL_NAME")] + public async Task GenerateAsync_ReturnsEmbeddings_WhenApiResponseIsSuccessful() + { + string host = Environment.GetEnvironmentVariable("OLLAMA_HOST") + ?? throw new InvalidOperationException("OLLAMA_HOST is not set."); + string embeddingModelName = Environment.GetEnvironmentVariable("OLLAMA_EMBEDDING_MODEL_NAME") + ?? throw new InvalidOperationException("OLLAMA_EMBEDDING_MODEL_NAME is not set."); + var httpClient = new HttpClient + { + BaseAddress = new Uri(host) + }; + var request = new TextEmbeddingsRequest { Model = embeddingModelName, Prompt = "Llamas are members of the camelid family", }; + var service = new OllamaTextEmbeddingService(httpClient); + TextEmbeddingsResponse response = await service.GenerateAsync(request); + response.Should().NotBeNull(); + } +}