|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// OllamaAgent.cs |
| 3 | + |
| 4 | +using System; |
| 5 | +using System.Collections.Generic; |
| 6 | +using System.IO; |
| 7 | +using System.Linq; |
| 8 | +using System.Net.Http; |
| 9 | +using System.Runtime.CompilerServices; |
| 10 | +using System.Text; |
| 11 | +using System.Text.Json; |
| 12 | +using System.Threading; |
| 13 | +using System.Threading.Tasks; |
| 14 | +using AutoGen.Core; |
| 15 | + |
| 16 | +namespace Autogen.Ollama; |
| 17 | + |
/// <summary>
/// An agent that can interact with ollama models.
/// </summary>
public class OllamaAgent : IStreamingAgent
{
    private readonly HttpClient _httpClient;
    private readonly string _modelName;
    private readonly string _systemMessage;
    private readonly OllamaReplyOptions? _replyOptions;

    /// <summary>Name of this agent; stamped as the "from" field on every produced message.</summary>
    public string Name { get; }

    /// <summary>
    /// Creates an agent that talks to an Ollama chat endpoint.
    /// </summary>
    /// <param name="httpClient">Client used for all requests (chat and image download). Its BaseAddress
    /// must point at the Ollama server; the agent does not dispose it.</param>
    /// <param name="name">Agent name.</param>
    /// <param name="modelName">Ollama model to invoke (e.g. "llama3").</param>
    /// <param name="systemMessage">System prompt injected when the supplied history contains no system message.</param>
    /// <param name="replyOptions">Default reply options applied when a call does not supply its own.</param>
    public OllamaAgent(HttpClient httpClient, string name, string modelName,
        string systemMessage = "You are a helpful AI assistant",
        OllamaReplyOptions? replyOptions = null)
    {
        Name = name;
        _httpClient = httpClient;
        _modelName = modelName;
        _systemMessage = systemMessage;
        _replyOptions = replyOptions;
    }

    /// <summary>
    /// Sends the chat history to the model and returns the complete (non-streamed) response
    /// wrapped in a <see cref="MessageEnvelope{T}"/>.
    /// </summary>
    /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
    /// <exception cref="InvalidOperationException">The response body could not be deserialized.</exception>
    public async Task<IMessage> GenerateReplyAsync(
        IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default)
    {
        ChatRequest request = await BuildChatRequest(messages, options).ConfigureAwait(false);
        request.Stream = false;

        // Dispose the request message as well as the response (original leaked it).
        using HttpRequestMessage requestMessage = BuildRequestMessage(request);
        using HttpResponseMessage response = await _httpClient
            .SendAsync(requestMessage, HttpCompletionOption.ResponseContentRead, cancellation)
            .ConfigureAwait(false);

        response.EnsureSuccessStatusCode();
        Stream streamResponse = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
        ChatResponse chatResponse = await JsonSerializer
            .DeserializeAsync<ChatResponse>(streamResponse, cancellationToken: cancellation)
            .ConfigureAwait(false)
            ?? throw new InvalidOperationException("Failed to deserialize response");
        return new MessageEnvelope<ChatResponse>(chatResponse, from: Name);
    }

    /// <summary>
    /// Streams the model's reply line by line. Each JSON line is yielded as a
    /// <see cref="ChatResponseUpdate"/> envelope; the final line (Done == true) is additionally
    /// re-parsed and yielded as a full <see cref="ChatResponse"/> envelope.
    /// </summary>
    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        ChatRequest request = await BuildChatRequest(messages, options).ConfigureAwait(false);
        request.Stream = true;

        using HttpRequestMessage message = BuildRequestMessage(request);
        using HttpResponseMessage response = await _httpClient
            .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken)
            .ConfigureAwait(false);

        response.EnsureSuccessStatusCode();
        using Stream stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
        using var reader = new StreamReader(stream);

        while (!cancellationToken.IsCancellationRequested)
        {
            // Terminate on ReadLineAsync() returning null rather than polling reader.EndOfStream:
            // EndOfStream performs a *blocking* synchronous read on a network stream when its
            // buffer is empty, stalling the thread between server chunks.
            string? line = await reader.ReadLineAsync().ConfigureAwait(false);
            if (line is null)
            {
                break;
            }

            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            ChatResponseUpdate? update = JsonSerializer.Deserialize<ChatResponseUpdate>(line);
            if (update != null)
            {
                yield return new MessageEnvelope<ChatResponseUpdate>(update, from: Name);
            }

            // Intermediate chunks (Done == false) carry only the delta; keep reading.
            // A null update falls through so the raw line still gets a ChatResponse attempt,
            // matching the original control flow.
            if (update is { Done: false })
            {
                continue;
            }

            ChatResponse? chatMessage = JsonSerializer.Deserialize<ChatResponse>(line);
            if (chatMessage == null)
            {
                continue;
            }

            yield return new MessageEnvelope<ChatResponse>(chatMessage, from: Name);
        }
    }

    /// <summary>
    /// Builds the request payload: model name, converted history, and reply options.
    /// Per-call <paramref name="options"/> (when they are <see cref="OllamaReplyOptions"/>)
    /// take precedence over the constructor-supplied defaults.
    /// </summary>
    private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
    {
        var request = new ChatRequest
        {
            Model = _modelName,
            Messages = await BuildChatHistory(messages).ConfigureAwait(false)
        };

        OllamaReplyOptions? effectiveOptions = options as OllamaReplyOptions ?? _replyOptions;
        if (effectiveOptions != null)
        {
            BuildChatRequestOptions(effectiveOptions, request);
        }

        return request;
    }

    /// <summary>
    /// Copies reply options onto the request. Model-level options are attached only when at
    /// least one is set, so the serialized request omits the "options" object otherwise.
    /// </summary>
    private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request)
    {
        request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null;
        request.Template = replyOptions.Template;
        request.KeepAlive = replyOptions.KeepAlive;

        if (replyOptions.Temperature != null
            || replyOptions.MaxToken != null
            || replyOptions.StopSequence != null
            || replyOptions.Seed != null
            || replyOptions.MiroStat != null
            || replyOptions.MiroStatEta != null
            || replyOptions.MiroStatTau != null
            || replyOptions.NumCtx != null
            || replyOptions.NumGqa != null
            || replyOptions.NumGpu != null
            || replyOptions.NumThread != null
            || replyOptions.RepeatLastN != null
            || replyOptions.RepeatPenalty != null
            || replyOptions.TopK != null
            || replyOptions.TopP != null
            || replyOptions.TfsZ != null)
        {
            request.Options = new ModelReplyOptions
            {
                Temperature = replyOptions.Temperature,
                NumPredict = replyOptions.MaxToken,
                // FirstOrDefault instead of [0]: a non-null but *empty* StopSequence no longer
                // throws. NOTE(review): only the first stop sequence is forwarded — the request
                // model appears to accept a single stop string; confirm against ModelReplyOptions.
                Stop = replyOptions.StopSequence?.FirstOrDefault(),
                Seed = replyOptions.Seed,
                MiroStat = replyOptions.MiroStat,
                MiroStatEta = replyOptions.MiroStatEta,
                MiroStatTau = replyOptions.MiroStatTau,
                NumCtx = replyOptions.NumCtx,
                NumGqa = replyOptions.NumGqa,
                NumGpu = replyOptions.NumGpu,
                NumThread = replyOptions.NumThread,
                RepeatLastN = replyOptions.RepeatLastN,
                RepeatPenalty = replyOptions.RepeatPenalty,
                TopK = replyOptions.TopK,
                TopP = replyOptions.TopP,
                TfsZ = replyOptions.TfsZ
            };
        }
    }

    /// <summary>
    /// Converts the AutoGen message history into Ollama wire messages, prepending the default
    /// system message when the history has none. Image URLs are downloaded and inlined as base64.
    /// </summary>
    /// <exception cref="NotSupportedException">A message type other than Text/Image/MultiModal was supplied.</exception>
    private async Task<List<Message>> BuildChatHistory(IEnumerable<IMessage> messages)
    {
        if (!messages.Any(m => m.IsSystemMessage()))
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: Name);
            messages = new[] { systemMessage }.Concat(messages);
        }

        var collection = new List<Message>();
        foreach (IMessage? message in messages)
        {
            Message item;
            switch (message)
            {
                case TextMessage tm:
                    item = new Message { Role = tm.Role.ToString(), Value = tm.Content };
                    break;
                case ImageMessage im:
                    string base64Image = await ImageUrlToBase64(im.Url!).ConfigureAwait(false);
                    item = new Message { Role = im.Role.ToString(), Images = [base64Image] };
                    break;
                case MultiModalMessage mm:
                    // Collapse text parts per role so a multi-part message becomes one wire message.
                    var textsGroupedByRole = mm.Content.OfType<TextMessage>().GroupBy(tm => tm.Role)
                        .ToDictionary(g => g.Key, g => string.Join(Environment.NewLine, g.Select(tm => tm.Content)));

                    string content = string.Join($"{Environment.NewLine}", textsGroupedByRole
                        .Select(g => $"{g.Key}{Environment.NewLine}:{g.Value}"));

                    IEnumerable<Task<string>> imagesConversionTasks = mm.Content
                        .OfType<ImageMessage>()
                        .Select(im => ImageUrlToBase64(im.Url!));

                    string[] imagesBase64 = await Task.WhenAll(imagesConversionTasks).ConfigureAwait(false);
                    item = new Message { Role = mm.Role.ToString(), Value = content, Images = imagesBase64 };
                    break;
                default:
                    throw new NotSupportedException($"Message type {message.GetType().Name} is not supported");
            }

            collection.Add(item);
        }

        return collection;
    }

    /// <summary>Serializes the request as JSON and wraps it in a POST to the chat endpoint.</summary>
    private static HttpRequestMessage BuildRequestMessage(ChatRequest request)
    {
        string serialized = JsonSerializer.Serialize(request);
        return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint)
        {
            Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
        };
    }

    /// <summary>
    /// Downloads the image at <paramref name="imageUrl"/> and returns it base64-encoded.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="imageUrl"/> is null, empty, or whitespace.</exception>
    /// <exception cref="HttpRequestException">The download failed.</exception>
    private async Task<string> ImageUrlToBase64(string imageUrl)
    {
        if (string.IsNullOrWhiteSpace(imageUrl))
        {
            throw new ArgumentException("required parameter", nameof(imageUrl));
        }

        // GetByteArrayAsync never returns null (it throws on failure), so the original
        // null-check was dead code.
        byte[] imageBytes = await _httpClient.GetByteArrayAsync(imageUrl).ConfigureAwait(false);
        return Convert.ToBase64String(imageBytes);
    }
}
0 commit comments