
Commit 99d25a7

Refactor token usage and prompt generation
Enhanced token usage logging in AzureOpenAITextGenerator and OpenAITextGenerator with detailed information. Introduced new private fields `_deployment` and `_textModel` in respective classes and updated constructors accordingly. Modified `MemoryAnswer` to handle multiple `TokenUsage` instances. Expanded `TokenUsage` class with additional properties and renamed existing ones for clarity. In `SearchClient`, added `CreatePrompt` method and updated `GenerateAnswer` to use it, ensuring token usage is always populated.
1 parent c4d4cf2 commit 99d25a7

File tree

5 files changed: +100 / -24 lines
extensions/AzureOpenAI/AzureOpenAI/AzureOpenAITextGenerator.cs

Lines changed: 26 additions & 4 deletions
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Net.Http;
@@ -30,6 +31,8 @@ public sealed class AzureOpenAITextGenerator : ITextGenerator
     private readonly ITextTokenizer _textTokenizer;
     private readonly ILogger<AzureOpenAITextGenerator> _log;
 
+    private readonly string _deployment;
+
     /// <inheritdoc/>
     public int MaxTokenTotal { get; }
 
@@ -89,6 +92,7 @@ public AzureOpenAITextGenerator(
     {
         this._client = skClient;
         this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<AzureOpenAITextGenerator>();
+        this._deployment = config.Deployment;
         this.MaxTokenTotal = config.MaxTokenTotal;
 
         if (textTokenizer == null)
@@ -145,10 +149,28 @@ public IReadOnlyList<string> GetTokens(string text)
         IAsyncEnumerable<StreamingTextContent> result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken);
         await foreach (StreamingTextContent x in result)
         {
-            var tokenUsage = x.Metadata?["Usage"] is ChatTokenUsage { } usage
-                ? new TokenUsage { InputTokenCount = usage.InputTokenCount, OutputTokenCount = usage.OutputTokenCount, TotalTokenCount = usage.TotalTokenCount }
-                : null;
-
+            TokenUsage? tokenUsage = null;
+
+            if (x.Metadata?["Usage"] is ChatTokenUsage { } usage)
+            {
+                this._log.LogTrace("Usage report: input tokens {0}, output tokens {1}, output reasoning tokens {2}",
+                    usage?.InputTokenCount, usage?.OutputTokenCount, usage?.OutputTokenDetails.ReasoningTokenCount);
+
+                tokenUsage = new TokenUsage
+                {
+                    Timestamp = DateTime.UtcNow,
+                    ServiceType = "Azure OpenAI",
+                    ModelType = "TextGeneration",
+                    ModelName = this._deployment,
+                    ServiceTokensIn = usage!.InputTokenCount,
+                    ServiceTokensOut = usage.OutputTokenCount,
+                    ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
+                };
+            }
+
+            // NOTE: as stated at https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices,
+            // the Choice can also be empty for the last chunk if we set stream_options: { "include_usage": true } to get token counts,
+            // so we continue only if both x.Text and tokenUsage are null.
             if (x.Text is null && tokenUsage is null) { continue; }
 
             yield return (x.Text, tokenUsage);

extensions/OpenAI/OpenAI/OpenAITextGenerator.cs

Lines changed: 26 additions & 4 deletions
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
 using System.Net.Http;
@@ -30,6 +31,8 @@ public sealed class OpenAITextGenerator : ITextGenerator
     private readonly ITextTokenizer _textTokenizer;
     private readonly ILogger<OpenAITextGenerator> _log;
 
+    private readonly string _textModel;
+
     /// <inheritdoc/>
     public int MaxTokenTotal { get; }
 
@@ -88,6 +91,7 @@ public OpenAITextGenerator(
     {
         this._client = skClient;
         this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<OpenAITextGenerator>();
+        this._textModel = config.TextModel;
        this.MaxTokenTotal = config.TextModelMaxTokenTotal;
 
         if (textTokenizer == null)
@@ -144,10 +148,28 @@ public IReadOnlyList<string> GetTokens(string text)
         IAsyncEnumerable<StreamingTextContent> result = this._client.GetStreamingTextContentsAsync(prompt, skOptions, cancellationToken: cancellationToken);
         await foreach (StreamingTextContent x in result)
         {
-            var tokenUsage = x.Metadata?["Usage"] is ChatTokenUsage { } usage
-                ? new TokenUsage { InputTokenCount = usage.InputTokenCount, OutputTokenCount = usage.OutputTokenCount, TotalTokenCount = usage.TotalTokenCount }
-                : null;
-
+            TokenUsage? tokenUsage = null;
+
+            if (x.Metadata?["Usage"] is ChatTokenUsage { } usage)
+            {
+                this._log.LogTrace("Usage report: input tokens {0}, output tokens {1}, output reasoning tokens {2}",
+                    usage?.InputTokenCount, usage?.OutputTokenCount, usage?.OutputTokenDetails.ReasoningTokenCount);
+
+                tokenUsage = new TokenUsage
+                {
+                    Timestamp = DateTime.UtcNow,
+                    ServiceType = "OpenAI",
+                    ModelType = "TextGeneration",
+                    ModelName = this._textModel,
+                    ServiceTokensIn = usage!.InputTokenCount,
+                    ServiceTokensOut = usage.OutputTokenCount,
+                    ServiceReasoningTokens = usage.OutputTokenDetails?.ReasoningTokenCount
+                };
+            }
+
+            // NOTE: as stated at https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices,
+            // the Choice can also be empty for the last chunk if we set stream_options: { "include_usage": true } to get token counts,
+            // so we continue only if both x.Text and tokenUsage are null.
             if (x.Text is null && tokenUsage is null) { continue; }
 
             yield return (x.Text, tokenUsage);
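
Both generators now yield (Text, TokenUsage) tuples, and because of the include_usage behavior noted in the comment above, the final chunk can carry usage data with no text. Below is a minimal consumer sketch mirroring the accumulation pattern SearchClient uses later in this commit; the ITextGenerator / TextGenerationOptions names, the GenerateTextAsync signature, and the helper method are assumptions for illustration, not part of this commit.

using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.Models;

// Hypothetical helper: drain the stream, keeping the full text and the last usage report seen.
static async Task<(string Text, TokenUsage? Usage)> ReadAllAsync(
    ITextGenerator generator, string prompt, TextGenerationOptions options, CancellationToken ct)
{
    var text = new StringBuilder();
    TokenUsage? tokenUsage = null;

    await foreach (var x in generator.GenerateTextAsync(prompt, options, ct))
    {
        // The last chunk may have no text but still report token usage.
        if (x.Text is not null) { text.Append(x.Text); }
        if (x.TokenUsage is not null) { tokenUsage = x.TokenUsage; }
    }

    return (text.ToString(), tokenUsage);
}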

service/Abstractions/Models/MemoryAnswer.cs

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ public class MemoryAnswer
     /// <remarks>Not all the models and text generators return token usage information.</remarks>
     [JsonPropertyName("tokenUsage")]
     [JsonPropertyOrder(11)]
-    public TokenUsage? TokenUsage { get; set; }
+    public IList<TokenUsage> TokenUsages { get; set; } = [];
 
     /// <summary>
     /// List of the relevant sources used to produce the answer.
service/Abstractions/Models/TokenUsage.cs

Lines changed: 27 additions & 9 deletions
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Text.Json.Serialization;
 
 namespace Microsoft.KernelMemory.Models;
@@ -9,21 +10,38 @@ namespace Microsoft.KernelMemory.Models;
 /// </summary>
 public class TokenUsage
 {
+    public DateTime Timestamp { get; set; }
+
+    public string? ServiceType { get; set; }
+
+    public string? ModelType { get; set; }
+
+    public string? ModelName { get; set; }
+
     /// <summary>
-    /// The number of tokens in the request message input, spanning all message content items.
+    /// The number of tokens in the request message input, spanning all message content items, measured by the tokenizer.
     /// </summary>
-    [JsonPropertyOrder(0)]
-    public int InputTokenCount { get; set; }
+    [JsonPropertyName("tokenizer_tokens_in")]
+    public int TokeninzerTokensIn { get; set; }
 
     /// <summary>
-    /// The combined number of output tokens in the generated completion, as consumed by the model.
+    /// The combined number of output tokens in the generated completion, measured by the tokenizer.
     /// </summary>
-    [JsonPropertyOrder(1)]
-    public int OutputTokenCount { get; set; }
+    [JsonPropertyName("tokenizer_tokens_out")]
+    public int TokeninzerTokensOut { get; set; }
 
     /// <summary>
-    /// The total number of combined input (prompt) and output (completion) tokens used.
+    /// The number of tokens in the request message input, spanning all message content items, measured by the service.
     /// </summary>
-    [JsonPropertyOrder(2)]
-    public int TotalTokenCount { get; set; }
+    [JsonPropertyName("service_tokens_in")]
+    public int? ServiceTokensIn { get; set; }
+
+    /// <summary>
+    /// The combined number of output tokens in the generated completion, as consumed by the model.
+    /// </summary>
+    [JsonPropertyName("service_tokens_out")]
+    public int? ServiceTokensOut { get; set; }
+
+    [JsonPropertyName("service_reasoning_tokens")]
+    public int? ServiceReasoningTokens { get; set; }
 }
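
Given the JsonPropertyName attributes above, here is a quick sketch of how one of these entries serializes with default System.Text.Json options; the sample values are made up, and the service's own serializer settings may differ.

using System;
using System.Text.Json;
using Microsoft.KernelMemory.Models;

var usage = new TokenUsage
{
    Timestamp = DateTime.UtcNow,
    ServiceType = "Azure OpenAI",
    ModelType = "TextGeneration",
    ModelName = "my-deployment",   // made-up deployment name
    TokeninzerTokensIn = 1200,
    TokeninzerTokensOut = 250,
    ServiceTokensIn = 1187,
    ServiceTokensOut = 243
};

// Properties without [JsonPropertyName] keep their C# names under default options, e.g.:
// {"Timestamp":"2024-...","ServiceType":"Azure OpenAI","ModelType":"TextGeneration","ModelName":"my-deployment",
//  "tokenizer_tokens_in":1200,"tokenizer_tokens_out":250,"service_tokens_in":1187,
//  "service_tokens_out":243,"service_reasoning_tokens":null}
Console.WriteLine(JsonSerializer.Serialize(usage));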

service/Core/Search/SearchClient.cs

Lines changed: 20 additions & 6 deletions
@@ -337,14 +337,16 @@ public async Task<MemoryAnswer> AskAsync(
             return noAnswerFound;
         }
 
+        var prompt = this.CreatePrompt(question, facts.ToString(), context);
+
         var text = new StringBuilder();
         TokenUsage? tokenUsage = null;
 
         var charsGenerated = 0;
         var watch = new Stopwatch();
         watch.Restart();
 
-        await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false))
+        await foreach (var x in this.GenerateAnswer(prompt, context, cancellationToken).ConfigureAwait(false))
         {
             if (x.Text is not null)
             {
@@ -363,7 +365,13 @@ public async Task<MemoryAnswer> AskAsync(
         watch.Stop();
 
         answer.Result = text.ToString();
-        answer.TokenUsage = tokenUsage;
+
+        // If the service does not provide Token usage information, we explicitly create it.
+        tokenUsage ??= new TokenUsage { Timestamp = DateTime.UtcNow, ModelType = "TextGeneration" };
+        tokenUsage.TokeninzerTokensIn = this._textGenerator.CountTokens(prompt);
+        tokenUsage.TokeninzerTokensOut = this._textGenerator.CountTokens(answer.Result);
+
+        answer.TokenUsages.Add(tokenUsage);
 
         this._log.LogSensitive("Answer: {0}", answer.Result);
         answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
@@ -392,12 +400,9 @@ public async Task<MemoryAnswer> AskAsync(
         return answer;
     }
 
-    private IAsyncEnumerable<(string? Text, TokenUsage? TokenUsage)> GenerateAnswer(string question, string facts, IContext? context, CancellationToken token)
+    private string CreatePrompt(string question, string facts, IContext? context)
     {
         string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
-        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
-        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
-        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
 
         prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
 
@@ -406,6 +411,15 @@ public async Task<MemoryAnswer> AskAsync(
         prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
         prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
 
+        return prompt;
+    }
+
+    private IAsyncEnumerable<(string? Text, TokenUsage? TokenUsage)> GenerateAnswer(string prompt, IContext? context, CancellationToken token)
+    {
+        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
+        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
+        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
+
         var options = new TextGenerationOptions
         {
             MaxTokens = maxTokens,
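
To make the new CreatePrompt/GenerateAnswer split concrete, here is a standalone sketch of the placeholder substitution CreatePrompt performs. The template string and sample values are made up; the real default RAG prompt is resolved via GetCustomRagPromptOrDefault from configuration or the request context.

using System;

// Hypothetical template for illustration only.
string prompt = "Facts:\n{{$facts}}\n\nQuestion: {{$input}}\nIf the facts are not enough, reply: {{$notFound}}";

prompt = prompt.Replace("{{$facts}}", "The sky is blue.", StringComparison.OrdinalIgnoreCase);
prompt = prompt.Replace("{{$input}}", "What color is the sky?", StringComparison.OrdinalIgnoreCase);
prompt = prompt.Replace("{{$notFound}}", "INFO NOT FOUND", StringComparison.OrdinalIgnoreCase);

Console.WriteLine(prompt);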
