Skip to content

Commit

Permalink
Merge pull request #53 from RogerBarreto/features/chat-completion-metadata
Browse files Browse the repository at this point in the history

Add Chat With Metadata Endpoint to OllamaSharp client
  • Loading branch information
awaescher authored Jul 12, 2024
2 parents 55d3626 + d17fe58 commit 0785218
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 10 deletions.
7 changes: 6 additions & 1 deletion OllamaSharp.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaSharp", "src\OllamaSh
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tests", "test\Tests.csproj", "{1527F300-40C7-49EB-A6FD-D21B20BA5BC1}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OllamaApiConsole", "OllamaApiConsole\OllamaApiConsole.csproj", "{81DCD129-2666-4846-8D3F-8682F6276289}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaApiConsole", "OllamaApiConsole\OllamaApiConsole.csproj", "{81DCD129-2666-4846-8D3F-8682F6276289}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{7B88E967-DA41-4A3F-80B9-AAAE0A7B471D}"
ProjectSection(SolutionItems) = preProject
.editorconfig = .editorconfig
EndProjectSection
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down
10 changes: 10 additions & 0 deletions src/IOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ public interface IOllamaApiClient
/// </summary>
string SelectedModel { get; set; }

/// <summary>
/// Sends a request to the /api/chat endpoint and waits for the complete answer.
/// Streaming is disabled for this call; the returned response carries the final
/// message together with metadata such as token counts and durations.
/// </summary>
/// <param name="chatRequest">The request to send to Ollama</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>The returned message with metadata</returns>
Task<ChatResponse> Chat(
ChatRequest chatRequest,
CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/chat endpoint
/// </summary>
Expand Down
72 changes: 72 additions & 0 deletions src/Models/Chat/ChatResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System.Text.Json.Serialization;

namespace OllamaSharp.Models.Chat;

/// <summary>
/// The response returned by the non-streamed /api/chat endpoint: the final
/// assistant message plus generation metadata (durations and token counts).
/// </summary>
public class ChatResponse
{
/// <summary>
/// The model that generated the response
/// </summary>
[JsonPropertyName("model")]
public string Model { get; set; } = null!;

/// <summary>
/// The time the response was generated, as the raw timestamp string from the API
/// </summary>
[JsonPropertyName("created_at")]
public string CreatedAt { get; set; } = null!;

/// <summary>
/// The message returned by the model
/// </summary>
[JsonPropertyName("message")]
public Message Message { get; set; } = null!;

/// <summary>
/// Whether the response is complete
/// </summary>
[JsonPropertyName("done")]
public bool Done { get; set; }

/// <summary>
/// The reason for the completion of the chat
/// </summary>
[JsonPropertyName("done_reason")]
public string? DoneReason { get; set; }

/// <summary>
/// The time spent generating the response
/// (presumably in nanoseconds, like the other duration fields — TODO confirm against the Ollama API docs)
/// </summary>
[JsonPropertyName("total_duration")]
public long TotalDuration { get; set; }

/// <summary>
/// The time spent in nanoseconds loading the model
/// </summary>
[JsonPropertyName("load_duration")]
public long LoadDuration { get; set; }

/// <summary>
/// The number of tokens in the prompt
/// </summary>
[JsonPropertyName("prompt_eval_count")]
public int PromptEvalCount { get; set; }

/// <summary>
/// The time spent in nanoseconds evaluating the prompt
/// </summary>
[JsonPropertyName("prompt_eval_duration")]
public long PromptEvalDuration { get; set; }

/// <summary>
/// The number of tokens in the response
/// </summary>
[JsonPropertyName("eval_count")]
public int EvalCount { get; set; }

/// <summary>
/// The time in nanoseconds spent generating the response
/// </summary>
[JsonPropertyName("eval_duration")]
public long EvalDuration { get; set; }
}
44 changes: 35 additions & 9 deletions src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,29 @@ public async Task<IEnumerable<Message>> SendChat(
chatRequest, response, streamer, cancellationToken);
}

/// <summary>
/// Sends a request to the /api/chat endpoint and returns the complete,
/// non-streamed response including metadata (durations, token counts).
/// </summary>
/// <param name="chatRequest">The request to send to Ollama; its Stream flag is forced to false</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>The returned message with metadata</returns>
public async Task<ChatResponse> Chat(
	ChatRequest chatRequest,
	CancellationToken cancellationToken = default)
{
	// This endpoint returns one JSON document, so streaming must be off.
	chatRequest.Stream = false;

	var request = new HttpRequestMessage(HttpMethod.Post, "api/chat")
	{
		Content = new StringContent(
			JsonSerializer.Serialize(chatRequest),
			Encoding.UTF8,
			"application/json")
	};

	// Stream was just forced to false above, so the previous conditional on
	// chatRequest.Stream could never select ResponseHeadersRead — buffer the
	// whole body unconditionally before deserializing it.
	using var response = await _client.SendAsync(
		request, HttpCompletionOption.ResponseContentRead, cancellationToken);
	response.EnsureSuccessStatusCode();

	return await ProcessChatResponseAsync(response);
}

public async IAsyncEnumerable<ChatResponseStream?> StreamChat(
ChatRequest chatRequest,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -372,8 +395,7 @@ private async Task StreamPostAsync<TRequest, TResponse>(
await ProcessStreamedResponseAsync(response, streamer, cancellationToken);
}

private async IAsyncEnumerable<TResponse?>
StreamPostAsync<TRequest, TResponse>(
private async IAsyncEnumerable<TResponse?> StreamPostAsync<TRequest, TResponse>(
string endpoint,
TRequest requestModel,
[EnumeratorCancellation] CancellationToken cancellationToken)
Expand All @@ -399,7 +421,6 @@ private async IAsyncEnumerable<TResponse?>
yield return result;
}


private static async Task ProcessStreamedResponseAsync<TLine>(
HttpResponseMessage response,
IResponseStreamer<TLine> streamer,
Expand Down Expand Up @@ -454,8 +475,7 @@ private static async Task<ConversationContext> ProcessStreamedCompletionResponse
return new ConversationContext(Array.Empty<long>());
}

private static async IAsyncEnumerable<GenerateCompletionResponseStream?>
ProcessStreamedCompletionResponseAsync(
private static async IAsyncEnumerable<GenerateCompletionResponseStream?> ProcessStreamedCompletionResponseAsync(
HttpResponseMessage response,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
Expand All @@ -473,8 +493,15 @@ private static async IAsyncEnumerable<GenerateCompletionResponseStream?>
}
}

private static async Task<IEnumerable<Message>>
ProcessStreamedChatResponseAsync(
/// <summary>
/// Reads and deserializes the body of a non-streamed chat response.
/// </summary>
/// <param name="response">The HTTP response returned by the /api/chat endpoint</param>
/// <returns>The deserialized chat response with metadata</returns>
/// <exception cref="InvalidOperationException">Thrown when the body deserializes to null</exception>
private static async Task<ChatResponse> ProcessChatResponseAsync(HttpResponseMessage response)
{
	var responseBody = await response.Content.ReadAsStringAsync();

	// Deserialize returns null for the JSON literal "null"; fail loudly here
	// instead of handing callers a null reference hidden by the ! operator.
	return JsonSerializer.Deserialize<ChatResponse>(responseBody)
		?? throw new InvalidOperationException("Unable to deserialize the chat response.");
}

private static async Task<IEnumerable<Message>> ProcessStreamedChatResponseAsync(
ChatRequest chatRequest,
HttpResponseMessage response,
IResponseStreamer<ChatResponseStream?> streamer,
Expand Down Expand Up @@ -510,8 +537,7 @@ private static async Task<IEnumerable<Message>>
return Array.Empty<Message>();
}

private static async IAsyncEnumerable<ChatResponseStream?>
ProcessStreamedChatResponseAsync(
private static async IAsyncEnumerable<ChatResponseStream?> ProcessStreamedChatResponseAsync(
HttpResponseMessage response,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
Expand Down
60 changes: 60 additions & 0 deletions test/OllamaApiClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,66 @@ public async Task Streams_Response_Message_Chunks()
}
}

public class ChatMethod : OllamaApiClientTests
{
	[Test]
	public async Task Receives_Response_MessageWithMetadata()
	{
		// Arrange: a buffered, non-streamed chat response body.
		const string responseJson =
			"""
			{
			"model": "llama2",
			"created_at": "2024-07-12T12:34:39.63897616Z",
			"message": {
			"role": "assistant",
			"content": "Test content."
			},
			"done_reason": "stop",
			"done": true,
			"total_duration": 137729492272,
			"load_duration": 133071702768,
			"prompt_eval_count": 26,
			"prompt_eval_duration": 35137000,
			"eval_count": 323,
			"eval_duration": 4575154000
			}
			""";

		await using var responseStream = new MemoryStream();

		_response = new HttpResponseMessage
		{
			StatusCode = HttpStatusCode.OK,
			Content = new StreamContent(responseStream)
		};

		await using var bodyWriter = new StreamWriter(responseStream, leaveOpen: true);
		bodyWriter.AutoFlush = true;
		await bodyWriter.WriteAsync(responseJson);
		responseStream.Seek(0, SeekOrigin.Begin);

		var request = new ChatRequest
		{
			Model = "model",
			Messages =
			[
				new(ChatRole.User, "Why?"),
				new(ChatRole.Assistant, "Because!"),
				new(ChatRole.User, "And where?")
			]
		};

		// Act
		var answer = await _client.Chat(request, CancellationToken.None);

		// Assert: the message and every metadata field survive deserialization.
		answer.Message.Role.Should().Be(ChatRole.Assistant);
		answer.Message.Content.Should().Be("Test content.");
		answer.Done.Should().BeTrue();
		answer.DoneReason.Should().Be("stop");
		answer.TotalDuration.Should().Be(137729492272);
		answer.LoadDuration.Should().Be(133071702768);
		answer.PromptEvalCount.Should().Be(26);
		answer.PromptEvalDuration.Should().Be(35137000);
		answer.EvalCount.Should().Be(323);
		answer.EvalDuration.Should().Be(4575154000);
	}
}

public class StreamChatMethod : OllamaApiClientTests
{
[Test]
Expand Down
5 changes: 5 additions & 0 deletions test/TestOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,9 @@ public Task<Version> GetVersion(CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

/// <summary>
/// Not supported by this test double; present only to satisfy the interface.
/// </summary>
public Task<ChatResponse> Chat(ChatRequest chatRequest, CancellationToken cancellationToken = default)
	=> throw new NotImplementedException();
}

0 comments on commit 0785218

Please sign in to comment.