diff --git a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs
index 13637dc5226..102fc86b138 100644
--- a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs
@@ -2,9 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Diagnostics;
+using System.IO;
using System.Security.Cryptography;
using System.Text.Json;
-using Microsoft.Shared.Diagnostics;
+#if NET
+using System.Threading;
+using System.Threading.Tasks;
+#endif
+
+#pragma warning disable S109 // Magic numbers should not be used
+#pragma warning disable SA1202 // Elements should be ordered by access
+#pragma warning disable SA1502 // Element should not be on a single line
namespace Microsoft.Extensions.AI;
@@ -12,50 +21,110 @@ namespace Microsoft.Extensions.AI;
internal static class CachingHelpers
{
/// Computes a default cache key for the specified parameters.
- /// Specifies the type of the data being used to compute the key.
- /// The data with which to compute the key.
- /// The .
- /// A string that will be used as a cache key.
- public static string GetCacheKey(TValue value, JsonSerializerOptions serializerOptions)
- => GetCacheKey(value, false, serializerOptions);
-
- /// Computes a default cache key for the specified parameters.
- /// Specifies the type of the data being used to compute the key.
- /// The data with which to compute the key.
- /// Another data item that causes the key to vary.
+ /// The data with which to compute the key.
/// The .
/// A string that will be used as a cache key.
- public static string GetCacheKey(TValue value, bool flag, JsonSerializerOptions serializerOptions)
+ public static string GetCacheKey(ReadOnlySpan values, JsonSerializerOptions serializerOptions)
{
- _ = Throw.IfNull(value);
- _ = Throw.IfNull(serializerOptions);
- serializerOptions.MakeReadOnly();
-
- var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue)));
-
- if (flag && jsonKeyBytes.Length > 0)
- {
- // Make an arbitrary change to the hash input based on the flag
- // The alternative would be including the flag in "value" in the
- // first place, but that's likely to require an extra allocation
- // or the inclusion of another type in the JsonSerializerContext.
- // This is a micro-optimization we can change at any time.
- jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]);
- }
+ Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null");
+ Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only.");
// The complete JSON representation is excessively long for a cache key, duplicating much of the content
// from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes.
// If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information
// disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
// invalidating any existing cache entries that may exist in whatever IDistributedCache was in use.
-#if NET8_0_OR_GREATER
+
+#if NET
+ IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new();
+ IncrementalHashStream.ThreadStaticInstance = null;
+
+ foreach (object? value in values)
+ {
+ JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
+ }
+
Span hashData = stackalloc byte[SHA256.HashSizeInBytes];
- SHA256.HashData(jsonKeyBytes, hashData);
+ stream.GetHashAndReset(hashData);
+ IncrementalHashStream.ThreadStaticInstance = stream;
+
return Convert.ToHexString(hashData);
#else
+ MemoryStream stream = new();
+ foreach (object? value in values)
+ {
+ JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
+ }
+
using var sha256 = SHA256.Create();
- var hashData = sha256.ComputeHash(jsonKeyBytes);
- return BitConverter.ToString(hashData).Replace("-", string.Empty);
+ stream.Position = 0;
+ var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);
+
+ var chars = new char[hashData.Length * 2];
+ int destPos = 0;
+ foreach (byte b in hashData)
+ {
+ int div = Math.DivRem(b, 16, out int rem);
+ chars[destPos++] = ToHexChar(div);
+ chars[destPos++] = ToHexChar(rem);
+
+ static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
+ }
+
+ Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");
+
+ return new string(chars);
#endif
}
+
+#if NET
+ /// Provides a stream that writes to an .
+ private sealed class IncrementalHashStream : Stream
+ {
+ /// A per-thread instance of .
+ /// An instance stored must be in a reset state ready to be used by another consumer.
+ [ThreadStatic]
+ public static IncrementalHashStream? ThreadStaticInstance;
+
+ /// Gets the current hash and resets.
+ public void GetHashAndReset(Span bytes) => _hash.GetHashAndReset(bytes);
+
+ /// The used by this instance.
+ private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);
+
+ protected override void Dispose(bool disposing)
+ {
+ _hash.Dispose();
+ base.Dispose(disposing);
+ }
+
+ public override void WriteByte(byte value) => Write(new ReadOnlySpan(in value));
+ public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count);
+ public override void Write(ReadOnlySpan buffer) => _hash.AppendData(buffer);
+
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ Write(buffer, offset, count);
+ return Task.CompletedTask;
+ }
+
+ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default)
+ {
+ Write(buffer.Span);
+ return ValueTask.CompletedTask;
+ }
+
+ public override void Flush() { }
+ public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+
+ public override bool CanWrite => true;
+ public override bool CanRead => false;
+ public override bool CanSeek => false;
+ public override long Length => throw new NotSupportedException();
+ public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
+ public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+ public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+ public override void SetLength(long value) => throw new NotSupportedException();
+ }
+#endif
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs
index 6ea79f9f738..678e9bd6523 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
@@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI;
///
public class DistributedCachingChatClient : CachingChatClient
{
+ /// A boxed value.
+ private static readonly object _boxedTrue = true;
+
+ /// A boxed value.
+ private static readonly object _boxedFalse = false;
+
+ /// The instance that will be used as the backing store for the cache.
private readonly IDistributedCache _storage;
- private JsonSerializerOptions _jsonSerializerOptions;
+
+ /// The to use when serializing cache data.
+ private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
/// Initializes a new instance of the class.
/// The underlying .
@@ -29,7 +39,6 @@ public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache s
: base(innerClient)
{
_storage = Throw.IfNull(storage);
- _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
}
/// Gets or sets JSON serialization options to use when serializing cache data.
@@ -90,13 +99,16 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
}
///
- protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options)
+ protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) =>
+ GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);
+
+ /// Gets a cache key based on the supplied values.
+ /// The values to inform the key.
+ /// The computed key.
+ /// This provides the default implementation for .
+ protected string GetCacheKey(ReadOnlySpan values)
{
- // While it might be desirable to include ChatOptions in the cache key, it's not always possible,
- // since ChatOptions can contain types that are not guaranteed to be serializable or have a stable
- // hashcode across multiple calls. So the default cache key is simply the JSON representation of
- // the chat contents. Developers may subclass and override this to provide custom rules.
_jsonSerializerOptions.MakeReadOnly();
- return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions);
+ return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs
index ecec409a1b3..6482ed8ed2b 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
@@ -74,12 +75,16 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
}
///
- protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options)
+ protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
+ GetCacheKey([value, options]);
+
+ /// Gets a cache key based on the supplied values.
+ /// The values to inform the key.
+ /// The computed key.
+ /// This provides the default implementation for .
+ protected string GetCacheKey(ReadOnlySpan values)
{
- // While it might be desirable to include options in the cache key, it's not always possible,
- // since options can contain types that are not guaranteed to be serializable or have a stable
- // hashcode across multiple calls. So the default cache key is simply the JSON representation of
- // the value. Developers may subclass and override this to provide custom rules.
- return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions);
+ _jsonSerializerOptions.MakeReadOnly();
+ return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
index 67e23ec495c..772bb9cf7d6 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
@@ -527,7 +527,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync()
}
[Fact]
- public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
+ public async Task CacheKeyVariesByChatOptionsAsync()
{
// Arrange
var innerCallCount = 0;
@@ -546,20 +546,35 @@ public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};
- // Act: Call with two different ChatOptions
+ // Act: Call with two different ChatOptions that have the same values
var result1 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 1" } }
});
var result2 = await outer.CompleteAsync([], new ChatOptions
{
- AdditionalProperties = new() { { "someKey", "value 2" } }
+ AdditionalProperties = new() { { "someKey", "value 1" } }
});
// Assert: Same result
Assert.Equal(1, innerCallCount);
Assert.Equal("value 1", result1.Message.Text);
Assert.Equal("value 1", result2.Message.Text);
+
+ // Act: Call with two different ChatOptions that have different values
+ var result3 = await outer.CompleteAsync([], new ChatOptions
+ {
+ AdditionalProperties = new() { { "someKey", "value 1" } }
+ });
+ var result4 = await outer.CompleteAsync([], new ChatOptions
+ {
+ AdditionalProperties = new() { { "someKey", "value 2" } }
+ });
+
+ // Assert: Different results
+ Assert.Equal(2, innerCallCount);
+ Assert.Equal("value 1", result3.Message.Text);
+ Assert.Equal("value 2", result4.Message.Text);
}
[Fact]
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
index a2818c7c3ed..f9356ef45c9 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
@@ -221,7 +221,7 @@ public async Task DoesNotCacheCanceledResultsAsync()
}
[Fact]
- public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
+ public async Task CacheKeyVariesByEmbeddingOptionsAsync()
{
// Arrange
var innerCallCount = 0;
@@ -232,7 +232,7 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
{
innerCallCount++;
await Task.Yield();
- return [_expectedEmbedding];
+ return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())];
}
};
using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage)
@@ -240,20 +240,35 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};
- // Act: Call with two different options
+ // Act: Call with two different EmbeddingGenerationOptions that have the same values
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
- AdditionalProperties = new() { ["someKey"] = "value 2" }
+ AdditionalProperties = new() { ["someKey"] = "value 1" }
});
// Assert: Same result
Assert.Equal(1, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, result1);
- AssertEmbeddingsEqual(_expectedEmbedding, result2);
+ AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1);
+ AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);
+
+ // Act: Call with two different EmbeddingGenerationOptions that have different values
+ var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
+ {
+ AdditionalProperties = new() { ["someKey"] = "value 1" }
+ });
+ var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
+ {
+ AdditionalProperties = new() { ["someKey"] = "value 2" }
+ });
+
+ // Assert: Different result
+ Assert.Equal(2, innerCallCount);
+ AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3);
+ AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
}
[Fact]
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs
index e376da86dad..b077542c17c 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs
@@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(Dictionary))]
[JsonSerializable(typeof(DayOfWeek[]))]
[JsonSerializable(typeof(Guid))]
+[JsonSerializable(typeof(ChatOptions))]
+[JsonSerializable(typeof(EmbeddingGenerationOptions))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;