forked from microsoft/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[.Net] feature: Ollama integration (microsoft#2693)
* [.Net] feature: Ollama integration with * [.Net] ollama agent improvements and reorganization * added ollama fact logic * [.Net] added ollama embeddings service * [.Net] Ollama embeddings integration * cleaned the agent and connector code * [.Net] cleaned ollama agent tests * [.Net] standardize api key fact ollama host variable * [.Net] fixed solution issue --------- Co-authored-by: Xiaoyun Zhang <[email protected]>
- Loading branch information
1 parent
aa16968
commit 23e8d27
Showing
17 changed files
with
953 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// OllamaAgent.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Net.Http; | ||
using System.Runtime.CompilerServices; | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using AutoGen.Core; | ||
|
||
namespace Autogen.Ollama; | ||
|
||
/// <summary>
/// An agent that can interact with ollama models.
/// </summary>
public class OllamaAgent : IStreamingAgent
{
    private readonly HttpClient _httpClient;
    public string Name { get; }
    private readonly string _modelName;
    private readonly string _systemMessage;
    private readonly OllamaReplyOptions? _replyOptions;

    /// <summary>
    /// Creates an agent that talks to an Ollama server over HTTP.
    /// </summary>
    /// <param name="httpClient">Client used for all requests; its BaseAddress must point at the Ollama host. Not disposed by this agent.</param>
    /// <param name="name">Agent name, used as the sender of every produced message.</param>
    /// <param name="modelName">Name of the Ollama model to invoke.</param>
    /// <param name="systemMessage">System prompt injected when the incoming history contains no system message.</param>
    /// <param name="replyOptions">Default reply options applied when the caller passes none.</param>
    public OllamaAgent(HttpClient httpClient, string name, string modelName,
        string systemMessage = "You are a helpful AI assistant",
        OllamaReplyOptions? replyOptions = null)
    {
        Name = name;
        _httpClient = httpClient;
        _modelName = modelName;
        _systemMessage = systemMessage;
        _replyOptions = replyOptions;
    }

    /// <summary>
    /// Sends the chat history to the model and returns the complete, non-streamed response.
    /// </summary>
    /// <param name="messages">Conversation history; a system message is prepended if none is present.</param>
    /// <param name="options">Per-call options; when an <see cref="OllamaReplyOptions"/>, it overrides the constructor defaults.</param>
    /// <param name="cancellation">Token observed by the HTTP call and deserialization.</param>
    /// <returns>A <see cref="MessageEnvelope{T}"/> wrapping the deserialized <see cref="ChatResponse"/>.</returns>
    /// <exception cref="HttpRequestException">The server returned a non-success status code.</exception>
    /// <exception cref="Exception">The response body could not be deserialized.</exception>
    public async Task<IMessage> GenerateReplyAsync(
        IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellation = default)
    {
        ChatRequest request = await BuildChatRequest(messages, options).ConfigureAwait(false);
        request.Stream = false;
        // Dispose the request message as well as the response (the original leaked it).
        using (HttpRequestMessage requestMessage = BuildRequestMessage(request))
        using (HttpResponseMessage response = await _httpClient
            .SendAsync(requestMessage, HttpCompletionOption.ResponseContentRead, cancellation).ConfigureAwait(false))
        {
            response.EnsureSuccessStatusCode();
            Stream streamResponse = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
            ChatResponse chatResponse = await JsonSerializer.DeserializeAsync<ChatResponse>(streamResponse, cancellationToken: cancellation).ConfigureAwait(false)
                ?? throw new Exception("Failed to deserialize response");
            var output = new MessageEnvelope<ChatResponse>(chatResponse, from: Name);
            return output;
        }
    }

    /// <summary>
    /// Streams the model's reply as incremental <see cref="ChatResponseUpdate"/> messages,
    /// ending with the final <see cref="ChatResponse"/> (which carries timing/token statistics).
    /// </summary>
    public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
        IEnumerable<IMessage> messages,
        GenerateReplyOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        ChatRequest request = await BuildChatRequest(messages, options).ConfigureAwait(false);
        request.Stream = true;
        using (HttpRequestMessage message = BuildRequestMessage(request))
        using (HttpResponseMessage response = await _httpClient
            .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false))
        {
            response.EnsureSuccessStatusCode();
            using Stream stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
            using var reader = new StreamReader(stream);

            // Ollama streams one JSON object per line; the last one has "done": true.
            while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
            {
                string? line = await reader.ReadLineAsync().ConfigureAwait(false);
                if (string.IsNullOrWhiteSpace(line)) continue;

                ChatResponseUpdate? update = JsonSerializer.Deserialize<ChatResponseUpdate>(line);
                if (update != null)
                {
                    yield return new MessageEnvelope<ChatResponseUpdate>(update, from: Name);
                }

                if (update is { Done: false }) continue;

                // Final line: re-parse as the full response so callers get the stats fields.
                ChatResponse? chatMessage = JsonSerializer.Deserialize<ChatResponse>(line);
                if (chatMessage == null) continue;
                yield return new MessageEnvelope<ChatResponse>(chatMessage, from: Name);
            }
        }
    }

    // Builds the request payload: model + history, plus options from either the
    // per-call OllamaReplyOptions (takes precedence) or the constructor defaults.
    private async Task<ChatRequest> BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
    {
        var request = new ChatRequest
        {
            Model = _modelName,
            Messages = await BuildChatHistory(messages).ConfigureAwait(false)
        };

        if (options is OllamaReplyOptions replyOptions)
        {
            BuildChatRequestOptions(replyOptions, request);
            return request;
        }

        if (_replyOptions != null)
        {
            BuildChatRequestOptions(_replyOptions, request);
            return request;
        }
        return request;
    }

    // Copies reply options onto the request; ModelReplyOptions is only allocated
    // when at least one model-level option is set, keeping it out of the JSON otherwise.
    private void BuildChatRequestOptions(OllamaReplyOptions replyOptions, ChatRequest request)
    {
        request.Format = replyOptions.Format == FormatType.Json ? OllamaConsts.JsonFormatType : null;
        request.Template = replyOptions.Template;
        request.KeepAlive = replyOptions.KeepAlive;

        if (replyOptions.Temperature != null
            || replyOptions.MaxToken != null
            || replyOptions.StopSequence != null
            || replyOptions.Seed != null
            || replyOptions.MiroStat != null
            || replyOptions.MiroStatEta != null
            || replyOptions.MiroStatTau != null
            || replyOptions.NumCtx != null
            || replyOptions.NumGqa != null
            || replyOptions.NumGpu != null
            || replyOptions.NumThread != null
            || replyOptions.RepeatLastN != null
            || replyOptions.RepeatPenalty != null
            || replyOptions.TopK != null
            || replyOptions.TopP != null
            || replyOptions.TfsZ != null)
        {
            request.Options = new ModelReplyOptions
            {
                Temperature = replyOptions.Temperature,
                NumPredict = replyOptions.MaxToken,
                // FirstOrDefault avoids the IndexOutOfRangeException the original
                // `StopSequence?[0]` threw for an empty (non-null) collection.
                // NOTE(review): only the first stop sequence is forwarded — confirm
                // whether ModelReplyOptions.Stop can carry more than one.
                Stop = replyOptions.StopSequence?.FirstOrDefault(),
                Seed = replyOptions.Seed,
                MiroStat = replyOptions.MiroStat,
                MiroStatEta = replyOptions.MiroStatEta,
                MiroStatTau = replyOptions.MiroStatTau,
                NumCtx = replyOptions.NumCtx,
                NumGqa = replyOptions.NumGqa,
                NumGpu = replyOptions.NumGpu,
                NumThread = replyOptions.NumThread,
                RepeatLastN = replyOptions.RepeatLastN,
                RepeatPenalty = replyOptions.RepeatPenalty,
                TopK = replyOptions.TopK,
                TopP = replyOptions.TopP,
                TfsZ = replyOptions.TfsZ
            };
        }
    }

    // Converts AutoGen messages into Ollama wire messages, prepending the agent's
    // system prompt when the history has no system message. Image URLs are
    // downloaded and inlined as base64, as the Ollama chat API requires.
    private async Task<List<Message>> BuildChatHistory(IEnumerable<IMessage> messages)
    {
        if (!messages.Any(m => m.IsSystemMessage()))
        {
            var systemMessage = new TextMessage(Role.System, _systemMessage, from: Name);
            messages = new[] { systemMessage }.Concat(messages);
        }

        var collection = new List<Message>();
        foreach (IMessage? message in messages)
        {
            Message item;
            switch (message)
            {
                case TextMessage tm:
                    item = new Message { Role = tm.Role.ToString(), Value = tm.Content };
                    break;
                case ImageMessage im:
                    string base64Image = await ImageUrlToBase64(im.Url!).ConfigureAwait(false);
                    item = new Message { Role = im.Role.ToString(), Images = [base64Image] };
                    break;
                case MultiModalMessage mm:
                    // Collapse the text parts into one string, grouped by role.
                    var textsGroupedByRole = mm.Content.OfType<TextMessage>().GroupBy(tm => tm.Role)
                        .ToDictionary(g => g.Key, g => string.Join(Environment.NewLine, g.Select(tm => tm.Content)));

                    string content = string.Join($"{Environment.NewLine}", textsGroupedByRole
                        .Select(g => $"{g.Key}{Environment.NewLine}:{g.Value}"));

                    IEnumerable<Task<string>> imagesConversionTasks = mm.Content
                        .OfType<ImageMessage>()
                        .Select(async im => await ImageUrlToBase64(im.Url!).ConfigureAwait(false));

                    string[]? imagesBase64 = await Task.WhenAll(imagesConversionTasks).ConfigureAwait(false);
                    item = new Message { Role = mm.Role.ToString(), Value = content, Images = imagesBase64 };
                    break;
                default:
                    throw new NotSupportedException();
            }

            collection.Add(item);
        }

        return collection;
    }

    // Serializes the request as a POST to the Ollama chat-completion endpoint.
    private static HttpRequestMessage BuildRequestMessage(ChatRequest request)
    {
        string serialized = JsonSerializer.Serialize(request);
        return new HttpRequestMessage(HttpMethod.Post, OllamaConsts.ChatCompletionEndpoint)
        {
            Content = new StringContent(serialized, Encoding.UTF8, OllamaConsts.JsonMediaType)
        };
    }

    // Downloads an image and returns it base64-encoded.
    // GetByteArrayAsync throws on failure and never returns null, so no null check is needed.
    private async Task<string> ImageUrlToBase64(string imageUrl)
    {
        if (string.IsNullOrWhiteSpace(imageUrl))
        {
            throw new ArgumentException("required parameter", nameof(imageUrl));
        }
        byte[] imageBytes = await _httpClient.GetByteArrayAsync(imageUrl).ConfigureAwait(false);
        return Convert.ToBase64String(imageBytes);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <!-- netstandard2.0 so the connector can be consumed from both .NET Framework and .NET (Core) hosts. -->
    <TargetFramework>netstandard2.0</TargetFramework>
    <!-- Emit the XML doc file so the /// comments ship with the package. -->
    <GenerateDocumentationFile>True</GenerateDocumentationFile>
  </PropertyGroup>

  <ItemGroup>
    <!-- Core abstractions (IStreamingAgent, IMessage, MessageEnvelope, ...) implemented by this connector. -->
    <ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
  </ItemGroup>

</Project>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// ChatRequest.cs | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Text.Json.Serialization; | ||
|
||
namespace Autogen.Ollama; | ||
|
||
/// <summary>
/// Request payload for the Ollama chat-completion endpoint.
/// Property names map to the Ollama API JSON fields via <see cref="JsonPropertyName"/>.
/// </summary>
public class ChatRequest
{
    /// <summary>
    /// (required) the model name
    /// </summary>
    [JsonPropertyName("model")]
    public string Model { get; set; } = string.Empty;

    /// <summary>
    /// the messages of the chat, this can be used to keep a chat memory
    /// </summary>
    /// <remarks>
    /// Defaults to a List rather than Array.Empty: an array-typed IList is fixed-size,
    /// so calling Add on a default-constructed request would throw NotSupportedException.
    /// </remarks>
    [JsonPropertyName("messages")]
    public IList<Message> Messages { get; set; } = new List<Message>();

    /// <summary>
    /// the format to return a response in. Currently, the only accepted value is json
    /// </summary>
    [JsonPropertyName("format")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public string? Format { get; set; }

    /// <summary>
    /// additional model parameters listed in the documentation for the Modelfile such as temperature
    /// </summary>
    [JsonPropertyName("options")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public ModelReplyOptions? Options { get; set; }

    /// <summary>
    /// the prompt template to use (overrides what is defined in the Modelfile)
    /// </summary>
    [JsonPropertyName("template")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public string? Template { get; set; }

    /// <summary>
    /// if false the response will be returned as a single response object, rather than a stream of objects
    /// </summary>
    [JsonPropertyName("stream")]
    public bool Stream { get; set; }

    /// <summary>
    /// controls how long the model will stay loaded into memory following the request (default: 5m)
    /// </summary>
    [JsonPropertyName("keep_alive")]
    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public string? KeepAlive { get; set; }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// ChatResponse.cs | ||
|
||
using System.Text.Json.Serialization; | ||
|
||
namespace Autogen.Ollama; | ||
|
||
/// <summary>
/// Final (done) response from the Ollama chat endpoint: the streamed fields of
/// <see cref="ChatResponseUpdate"/> plus timing and token statistics.
/// Durations are reported by Ollama in nanoseconds.
/// </summary>
public class ChatResponse : ChatResponseUpdate
{
    /// <summary>
    /// total time spent generating the response, in nanoseconds
    /// </summary>
    [JsonPropertyName("total_duration")]
    public long TotalDuration { get; set; }

    /// <summary>
    /// time spent in nanoseconds loading the model
    /// </summary>
    [JsonPropertyName("load_duration")]
    public long LoadDuration { get; set; }

    /// <summary>
    /// number of tokens in the prompt
    /// </summary>
    [JsonPropertyName("prompt_eval_count")]
    public int PromptEvalCount { get; set; }

    /// <summary>
    /// time spent in nanoseconds evaluating the prompt
    /// </summary>
    [JsonPropertyName("prompt_eval_duration")]
    public long PromptEvalDuration { get; set; }

    /// <summary>
    /// number of tokens in the response
    /// </summary>
    [JsonPropertyName("eval_count")]
    public int EvalCount { get; set; }

    /// <summary>
    /// time in nanoseconds spent generating the response
    /// </summary>
    [JsonPropertyName("eval_duration")]
    public long EvalDuration { get; set; }
}
Oops, something went wrong.