Skip to content

Commit

Permalink
Merge pull request #60 from milbk/main
Browse files Browse the repository at this point in the history
 Update Embedding for /api/embed
  • Loading branch information
awaescher authored Aug 5, 2024
2 parents 724782d + 146e555 commit e5c93c3
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 16 deletions.
2 changes: 1 addition & 1 deletion OllamaApiConsole/Demos/ModelManagerConsole.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private async Task GenerateEmbedding()
var embedContent = ReadInput("Enter a string to to embed:");
Ollama.SelectedModel = embedModel;
var embedResponse = await Ollama.GenerateEmbeddings(embedContent);
AnsiConsole.MarkupLineInterpolated($"[cyan]{string.Join(", ", embedResponse.Embedding)}[/]");
AnsiConsole.MarkupLineInterpolated($"[cyan]{string.Join(", ", embedResponse.Embeddings.First())}[/]");
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/IOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public interface IOllamaApiClient
Task DeleteModel(string model, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/embeddings endpoint to generate embeddings
/// Sends a request to the /api/embed endpoint to generate embeddings
/// </summary>
/// <param name="request">The parameters to generate embeddings for</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
Expand Down
37 changes: 32 additions & 5 deletions src/Models/GenerateEmbedding.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace OllamaSharp.Models;
Expand All @@ -16,8 +17,8 @@ public class GenerateEmbeddingRequest
/// <summary>
/// The list of input texts to generate embeddings for
/// </summary>
[JsonPropertyName("prompt")]
public string Prompt { get; set; } = null!;
[JsonPropertyName("input")]
public List<string> Input { get; set; } = null!;

/// <summary>
/// Additional model parameters listed in the documentation for the Modelfile
Expand All @@ -33,16 +34,42 @@ public class GenerateEmbeddingRequest
[JsonPropertyName("keep_alive")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? KeepAlive { get; set; }

/// <summary>
/// Truncates the end of each input to fit within context length.
/// Returns error if false and context length is exceeded. Defaults to true
/// </summary>
[JsonPropertyName("truncate")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Truncate { get; set; }
}

/// <summary>
/// The response from the /api/embeddings endpoint
/// The response from the /api/embed endpoint
/// </summary>
public class GenerateEmbeddingResponse
{
/// <summary>
/// The embedding vectors generated for the input texts
/// </summary>
[JsonPropertyName("embedding")]
public double[] Embedding { get; set; } = null!;
[JsonPropertyName("embeddings")]
public List<double[]> Embeddings { get; set; } = null!;

/// <summary>
/// The time spent in nanoseconds generating the response
/// </summary>
[JsonPropertyName("total_duration")]
public long? TotalDuration { get; set; }

/// <summary>
/// The time spent in nanoseconds loading the model
/// </summary>
[JsonPropertyName("load_duration")]
public long? LoadDuration { get; set; }

/// <summary>
/// The number of tokens in the input text
/// </summary>
[JsonPropertyName("prompt_eval_count")]
public int? PromptEvalCount { get; set; }
}
2 changes: 1 addition & 1 deletion src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public Task PushModel(PushModelRequest modelRequest, IResponseStreamer<PushModel

/// <inheritdoc />
public Task<GenerateEmbeddingResponse> GenerateEmbeddings(GenerateEmbeddingRequest request, CancellationToken cancellationToken = default)
=> PostAsync<GenerateEmbeddingRequest, GenerateEmbeddingResponse>("api/embeddings", request, cancellationToken);
=> PostAsync<GenerateEmbeddingRequest, GenerateEmbeddingResponse>("api/embed", request, cancellationToken);

/// <inheritdoc />
public Task<ConversationContext> StreamCompletion(GenerateCompletionRequest request, IResponseStreamer<GenerateCompletionResponseStream?> streamer, CancellationToken cancellationToken = default)
Expand Down
8 changes: 4 additions & 4 deletions src/OllamaApiClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,17 @@ public static Task PushModel(this IOllamaApiClient client, string name, IRespons
=> client.PushModel(new PushModelRequest { Model = name, Stream = true }, streamer, cancellationToken);

/// <summary>
/// Sends a request to the /api/embeddings endpoint to generate embeddings for the currently selected model
/// Sends a request to the /api/embed endpoint to generate embeddings for the currently selected model
/// </summary>
/// <param name="client">The client used to execute the command</param>
/// <param name="prompt">The prompt to generate embeddings for</param>
/// <param name="input">The input text to generate embeddings for</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public static Task<GenerateEmbeddingResponse> GenerateEmbeddings(this IOllamaApiClient client, string prompt, CancellationToken cancellationToken = default)
public static Task<GenerateEmbeddingResponse> GenerateEmbeddings(this IOllamaApiClient client, string input, CancellationToken cancellationToken = default)
{
var request = new GenerateEmbeddingRequest
{
Model = client.SelectedModel,
Prompt = prompt
Input = new List<string> { input }
};
return client.GenerateEmbeddings(request, cancellationToken);
}
Expand Down
8 changes: 4 additions & 4 deletions test/OllamaApiClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ public async Task Returns_Deserialized_Models()
_response = new HttpResponseMessage
{
StatusCode = HttpStatusCode.OK,
Content = new StringContent("{\r\n \"embedding\": [\r\n 0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313 ]\r\n}")
Content = new StringContent("{\r\n \"embeddings\": [[\r\n 0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313 ]]\r\n}")
};

var info = await _client.GenerateEmbeddings(new GenerateEmbeddingRequest { Model = "", Prompt = "" }, CancellationToken.None);
var info = await _client.GenerateEmbeddings(new GenerateEmbeddingRequest { Model = "", Input = [""]}, CancellationToken.None);

info.Embedding.Should().HaveCount(5);
info.Embedding.First().Should().BeApproximately(0.567, precision: 0.01);
info.Embeddings.First().Should().HaveCount(5);
info.Embeddings.First().First().Should().BeApproximately(0.567, precision: 0.01);
}
}
}
Expand Down

0 comments on commit e5c93c3

Please sign in to comment.