Skip to content

Commit 430065c

Browse files
Rework cache key handling in caching client / generator (#5641)
* Rework cache key handling in caching client / generator - Expose the default cache key helper so that customization doesn't require re-implementing the whole thing. - Make it easy to incorporate additional state into the cache key. - Avoid serializing all of the values for the key into a new byte[], at least on .NET 8+. There, we can serialize directly into a stream that targets an IncrementalHash. - Include Chat/EmbeddingGenerationOptions in the cache key by default. * Update test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs Co-authored-by: Shyam N <[email protected]> --------- Co-authored-by: Shyam N <[email protected]>
1 parent 73962c6 commit 430065c

File tree

6 files changed

+173
-55
lines changed

6 files changed

+173
-55
lines changed

src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs

Lines changed: 101 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,60 +2,129 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Diagnostics;
6+
using System.IO;
57
using System.Security.Cryptography;
68
using System.Text.Json;
7-
using Microsoft.Shared.Diagnostics;
9+
#if NET
10+
using System.Threading;
11+
using System.Threading.Tasks;
12+
#endif
13+
14+
#pragma warning disable S109 // Magic numbers should not be used
15+
#pragma warning disable SA1202 // Elements should be ordered by access
16+
#pragma warning disable SA1502 // Element should not be on a single line
817

918
namespace Microsoft.Extensions.AI;
1019

1120
/// <summary>Provides internal helpers for implementing caching services.</summary>
1221
internal static class CachingHelpers
1322
{
1423
/// <summary>Computes a default cache key for the specified parameters.</summary>
15-
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
16-
/// <param name="value">The data with which to compute the key.</param>
17-
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
18-
/// <returns>A string that will be used as a cache key.</returns>
19-
public static string GetCacheKey<TValue>(TValue value, JsonSerializerOptions serializerOptions)
20-
=> GetCacheKey(value, false, serializerOptions);
21-
22-
/// <summary>Computes a default cache key for the specified parameters.</summary>
23-
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
24-
/// <param name="value">The data with which to compute the key.</param>
25-
/// <param name="flag">Another data item that causes the key to vary.</param>
24+
/// <param name="values">The data with which to compute the key.</param>
2625
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
2726
/// <returns>A string that will be used as a cache key.</returns>
28-
public static string GetCacheKey<TValue>(TValue value, bool flag, JsonSerializerOptions serializerOptions)
27+
public static string GetCacheKey(ReadOnlySpan<object?> values, JsonSerializerOptions serializerOptions)
2928
{
30-
_ = Throw.IfNull(value);
31-
_ = Throw.IfNull(serializerOptions);
32-
serializerOptions.MakeReadOnly();
33-
34-
var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue)));
35-
36-
if (flag && jsonKeyBytes.Length > 0)
37-
{
38-
// Make an arbitrary change to the hash input based on the flag
39-
// The alternative would be including the flag in "value" in the
40-
// first place, but that's likely to require an extra allocation
41-
// or the inclusion of another type in the JsonSerializerContext.
42-
// This is a micro-optimization we can change at any time.
43-
jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]);
44-
}
29+
Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null");
30+
Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only.");
4531

4632
// The complete JSON representation is excessively long for a cache key, duplicating much of the content
4733
// from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes.
4834
// If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information
4935
// disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
5036
// invalidating any existing cache entries that may exist in whatever IDistributedCache was in use.
51-
#if NET8_0_OR_GREATER
37+
38+
#if NET
39+
IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new();
40+
IncrementalHashStream.ThreadStaticInstance = null;
41+
42+
foreach (object? value in values)
43+
{
44+
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
45+
}
46+
5247
Span<byte> hashData = stackalloc byte[SHA256.HashSizeInBytes];
53-
SHA256.HashData(jsonKeyBytes, hashData);
48+
stream.GetHashAndReset(hashData);
49+
IncrementalHashStream.ThreadStaticInstance = stream;
50+
5451
return Convert.ToHexString(hashData);
5552
#else
53+
MemoryStream stream = new();
54+
foreach (object? value in values)
55+
{
56+
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
57+
}
58+
5659
using var sha256 = SHA256.Create();
57-
var hashData = sha256.ComputeHash(jsonKeyBytes);
58-
return BitConverter.ToString(hashData).Replace("-", string.Empty);
60+
stream.Position = 0;
61+
var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);
62+
63+
var chars = new char[hashData.Length * 2];
64+
int destPos = 0;
65+
foreach (byte b in hashData)
66+
{
67+
int div = Math.DivRem(b, 16, out int rem);
68+
chars[destPos++] = ToHexChar(div);
69+
chars[destPos++] = ToHexChar(rem);
70+
71+
static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
72+
}
73+
74+
Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");
75+
76+
return new string(chars);
5977
#endif
6078
}
79+
80+
#if NET
81+
/// <summary>Provides a stream that writes to an <see cref="IncrementalHash"/>.</summary>
82+
private sealed class IncrementalHashStream : Stream
83+
{
84+
/// <summary>A per-thread instance of <see cref="IncrementalHashStream"/>.</summary>
85+
/// <remarks>An instance stored must be in a reset state ready to be used by another consumer.</remarks>
86+
[ThreadStatic]
87+
public static IncrementalHashStream? ThreadStaticInstance;
88+
89+
/// <summary>Gets the current hash and resets.</summary>
90+
public void GetHashAndReset(Span<byte> bytes) => _hash.GetHashAndReset(bytes);
91+
92+
/// <summary>The <see cref="IncrementalHash"/> used by this instance.</summary>
93+
private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);
94+
95+
protected override void Dispose(bool disposing)
96+
{
97+
_hash.Dispose();
98+
base.Dispose(disposing);
99+
}
100+
101+
public override void WriteByte(byte value) => Write(new ReadOnlySpan<byte>(in value));
102+
public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count);
103+
public override void Write(ReadOnlySpan<byte> buffer) => _hash.AppendData(buffer);
104+
105+
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
106+
{
107+
Write(buffer, offset, count);
108+
return Task.CompletedTask;
109+
}
110+
111+
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
112+
{
113+
Write(buffer.Span);
114+
return ValueTask.CompletedTask;
115+
}
116+
117+
public override void Flush() { }
118+
public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
119+
120+
public override bool CanWrite => true;
121+
public override bool CanRead => false;
122+
public override bool CanSeek => false;
123+
public override long Length => throw new NotSupportedException();
124+
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
125+
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
126+
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
127+
public override void SetLength(long value) => throw new NotSupportedException();
128+
}
129+
#endif
61130
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
56
using System.Text.Json;
67
using System.Threading;
@@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI;
1920
/// </remarks>
2021
public class DistributedCachingChatClient : CachingChatClient
2122
{
23+
/// <summary>A boxed <see langword="true"/> value.</summary>
24+
private static readonly object _boxedTrue = true;
25+
26+
/// <summary>A boxed <see langword="false"/> value.</summary>
27+
private static readonly object _boxedFalse = false;
28+
29+
/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
2230
private readonly IDistributedCache _storage;
23-
private JsonSerializerOptions _jsonSerializerOptions;
31+
32+
/// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
33+
private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
2434

2535
/// <summary>Initializes a new instance of the <see cref="DistributedCachingChatClient"/> class.</summary>
2636
/// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>
@@ -29,7 +39,6 @@ public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache s
2939
: base(innerClient)
3040
{
3141
_storage = Throw.IfNull(storage);
32-
_jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
3342
}
3443

3544
/// <summary>Gets or sets JSON serialization options to use when serializing cache data.</summary>
@@ -90,13 +99,16 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
9099
}
91100

92101
/// <inheritdoc />
93-
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
102+
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
103+
GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);
104+
105+
/// <summary>Gets a cache key based on the supplied values.</summary>
106+
/// <param name="values">The values to inform the key.</param>
107+
/// <returns>The computed key.</returns>
108+
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
109+
protected string GetCacheKey(ReadOnlySpan<object?> values)
94110
{
95-
// While it might be desirable to include ChatOptions in the cache key, it's not always possible,
96-
// since ChatOptions can contain types that are not guaranteed to be serializable or have a stable
97-
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
98-
// the chat contents. Developers may subclass and override this to provide custom rules.
99111
_jsonSerializerOptions.MakeReadOnly();
100-
return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions);
112+
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
101113
}
102114
}

src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Text.Json;
56
using System.Text.Json.Serialization.Metadata;
67
using System.Threading;
@@ -74,12 +75,16 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
7475
}
7576

7677
/// <inheritdoc />
77-
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options)
78+
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
79+
GetCacheKey([value, options]);
80+
81+
/// <summary>Gets a cache key based on the supplied values.</summary>
82+
/// <param name="values">The values to inform the key.</param>
83+
/// <returns>The computed key.</returns>
84+
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
85+
protected string GetCacheKey(ReadOnlySpan<object?> values)
7886
{
79-
// While it might be desirable to include options in the cache key, it's not always possible,
80-
// since options can contain types that are not guaranteed to be serializable or have a stable
81-
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
82-
// the value. Developers may subclass and override this to provide custom rules.
83-
return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions);
87+
_jsonSerializerOptions.MakeReadOnly();
88+
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
8489
}
8590
}

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync()
527527
}
528528

529529
[Fact]
530-
public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
530+
public async Task CacheKeyVariesByChatOptionsAsync()
531531
{
532532
// Arrange
533533
var innerCallCount = 0;
@@ -546,20 +546,35 @@ public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
546546
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
547547
};
548548

549-
// Act: Call with two different ChatOptions
549+
// Act: Call with two different ChatOptions that have the same values
550550
var result1 = await outer.CompleteAsync([], new ChatOptions
551551
{
552552
AdditionalProperties = new() { { "someKey", "value 1" } }
553553
});
554554
var result2 = await outer.CompleteAsync([], new ChatOptions
555555
{
556-
AdditionalProperties = new() { { "someKey", "value 2" } }
556+
AdditionalProperties = new() { { "someKey", "value 1" } }
557557
});
558558

559559
// Assert: Same result
560560
Assert.Equal(1, innerCallCount);
561561
Assert.Equal("value 1", result1.Message.Text);
562562
Assert.Equal("value 1", result2.Message.Text);
563+
564+
// Act: Call with two different ChatOptions that have different values
565+
var result3 = await outer.CompleteAsync([], new ChatOptions
566+
{
567+
AdditionalProperties = new() { { "someKey", "value 1" } }
568+
});
569+
var result4 = await outer.CompleteAsync([], new ChatOptions
570+
{
571+
AdditionalProperties = new() { { "someKey", "value 2" } }
572+
});
573+
574+
// Assert: Different results
575+
Assert.Equal(2, innerCallCount);
576+
Assert.Equal("value 1", result3.Message.Text);
577+
Assert.Equal("value 2", result4.Message.Text);
563578
}
564579

565580
[Fact]

test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public async Task DoesNotCacheCanceledResultsAsync()
221221
}
222222

223223
[Fact]
224-
public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
224+
public async Task CacheKeyVariesByEmbeddingOptionsAsync()
225225
{
226226
// Arrange
227227
var innerCallCount = 0;
@@ -232,28 +232,43 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
232232
{
233233
innerCallCount++;
234234
await Task.Yield();
235-
return [_expectedEmbedding];
235+
return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())];
236236
}
237237
};
238238
using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)
239239
{
240240
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
241241
};
242242

243-
// Act: Call with two different options
243+
// Act: Call with two different EmbeddingGenerationOptions that have the same values
244244
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
245245
{
246246
AdditionalProperties = new() { ["someKey"] = "value 1" }
247247
});
248248
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
249249
{
250-
AdditionalProperties = new() { ["someKey"] = "value 2" }
250+
AdditionalProperties = new() { ["someKey"] = "value 1" }
251251
});
252252

253253
// Assert: Same result
254254
Assert.Equal(1, innerCallCount);
255-
AssertEmbeddingsEqual(_expectedEmbedding, result1);
256-
AssertEmbeddingsEqual(_expectedEmbedding, result2);
255+
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1);
256+
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);
257+
258+
// Act: Call with two different EmbeddingGenerationOptions that have different values
259+
var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
260+
{
261+
AdditionalProperties = new() { ["someKey"] = "value 1" }
262+
});
263+
var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
264+
{
265+
AdditionalProperties = new() { ["someKey"] = "value 2" }
266+
});
267+
268+
// Assert: Different result
269+
Assert.Equal(2, innerCallCount);
270+
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3);
271+
AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
257272
}
258273

259274
[Fact]

test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI;
2525
[JsonSerializable(typeof(Dictionary<string, string>))]
2626
[JsonSerializable(typeof(DayOfWeek[]))]
2727
[JsonSerializable(typeof(Guid))]
28+
[JsonSerializable(typeof(ChatOptions))]
29+
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
2830
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;

0 commit comments

Comments
 (0)