Skip to content

Commit fe978d8

Browse files
anthonypuppo authored and lemillermicrosoft committed
.Net: Expose the token counter in the TextChunker (microsoft#2147)
### Motivation and Context

The TextChunker currently uses a rough estimate of the total number of tokens in a string (input divided by 4). This PR aims to expose the counter logic to be caller-configurable, enabling accurate token counts based on which model will be consuming the chunks.

Fixes microsoft#2146

### Description

Adds an optional token counter function parameter to the TextChunker methods. If not supplied, it falls back to the base token counter logic.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---------

Co-authored-by: Lee Miller <[email protected]>
1 parent 853ad24 commit fe978d8

File tree

4 files changed

+232
-25
lines changed

4 files changed

+232
-25
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Threading.Tasks;
6+
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.Tokenizers;
7+
using Microsoft.SemanticKernel.Text;
8+
9+
// ReSharper disable once InconsistentNaming
10+
/// <summary>
/// Sample showing how to split text into lines and paragraphs with <see cref="TextChunker"/>,
/// first with its default token counting and then with a caller-supplied token counter.
/// </summary>
public static class Example55_TextChunker
{
    // Input text shared by both demonstrations below.
    private const string SampleText = @"The city of Venice, located in the northeastern part of Italy,
is renowned for its unique geographical features. Built on more than 100 small islands in a lagoon in the
Adriatic Sea, it has no roads, just canals including the Grand Canal thoroughfare lined with Renaissance and
Gothic palaces. The central square, Piazza San Marco, contains St. Mark's Basilica, which is tiled with Byzantine
mosaics, and the Campanile bell tower offering views of the city's red roofs.

The Amazon Rainforest, also known as Amazonia, is a moist broadleaf tropical rainforest in the Amazon biome that
covers most of the Amazon basin of South America. This basin encompasses 7 million square kilometers, of which
5.5 million square kilometers are covered by the rainforest. This region includes territory belonging to nine nations
and 3.4 million square kilometers of uncontacted tribes. The Amazon represents over half of the planet's remaining
rainforests and comprises the largest and most biodiverse tract of tropical rainforest in the world.

The Great Barrier Reef is the world's largest coral reef system composed of over 2,900 individual reefs and 900 islands
stretching for over 2,300 kilometers over an area of approximately 344,400 square kilometers. The reef is located in the
Coral Sea, off the coast of Queensland, Australia. The Great Barrier Reef can be seen from outer space and is the world's
biggest single structure made by living organisms. This reef structure is composed of and built by billions of tiny organisms,
known as coral polyps.";

    /// <summary>
    /// Runs both chunking demonstrations synchronously and returns an already-completed task.
    /// </summary>
    /// <returns>A completed <see cref="Task"/>.</returns>
    public static Task RunAsync()
    {
        RunExample();
        RunExampleWithCustomTokenCounter();

        return Task.CompletedTask;
    }

    /// <summary>
    /// Chunks the sample text without passing a token counter, so the chunker's default is used.
    /// </summary>
    private static void RunExample()
    {
        Console.WriteLine("=== Text chunking ===");

        // Split into lines of at most 40 tokens, then group them into paragraphs of at most 120 tokens.
        List<string> chunks = TextChunker.SplitPlainTextParagraphs(
            TextChunker.SplitPlainTextLines(SampleText, 40),
            120);

        WriteParagraphsToConsole(chunks);
    }

    /// <summary>
    /// Chunks the sample text with the same limits, passing <see cref="TokenCounter"/> to both split calls.
    /// </summary>
    private static void RunExampleWithCustomTokenCounter()
    {
        Console.WriteLine("=== Text chunking with a custom token counter ===");

        List<string> splitLines = TextChunker.SplitPlainTextLines(SampleText, 40, TokenCounter);
        List<string> chunks = TextChunker.SplitPlainTextParagraphs(splitLines, 120, tokenCounter: TokenCounter);

        WriteParagraphsToConsole(chunks);
    }

    /// <summary>
    /// Writes each paragraph to the console, with a dashed separator line between consecutive paragraphs.
    /// </summary>
    /// <param name="paragraphs">Paragraphs to print, in order.</param>
    private static void WriteParagraphsToConsole(List<string> paragraphs)
    {
        var isFirst = true;

        foreach (string paragraph in paragraphs)
        {
            // The separator goes between paragraphs only: never before the first or after the last.
            if (!isFirst)
            {
                Console.WriteLine("------------------------");
            }

            Console.WriteLine(paragraph);
            isFirst = false;
        }
    }

    /// <summary>
    /// Counts tokens by encoding the input with <see cref="GPT3Tokenizer"/>.
    /// </summary>
    /// <param name="input">Text to tokenize.</param>
    /// <returns>The number of tokens produced by the encoder.</returns>
    private static int TokenCounter(string input) => GPT3Tokenizer.Encode(input).Count;
}

dotnet/samples/KernelSyntaxExamples/Program.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ public static async Task Main()
7474
await Example52_ApimAuth.RunAsync().SafeWaitAsync(cancelToken);
7575
await Example53_Kusto.RunAsync().SafeWaitAsync(cancelToken);
7676
await Example54_AzureChatCompletionWithData.RunAsync().SafeWaitAsync(cancelToken);
77+
await Example55_TextChunker.RunAsync().SafeWaitAsync(cancelToken);
7778
}
7879

7980
private static void LoadUserSecrets()

dotnet/src/SemanticKernel.UnitTests/Text/TextChunkerTests.cs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,4 +478,121 @@ public void CanSplitVeryLargeDocumentsWithoutStackOverflowing()
478478
Assert.NotEmpty(paragraphs);
479479
#pragma warning restore CA5394
480480
}
481+
482+
/// <summary>
/// Checks <c>SplitPlainTextLines</c> against a fixed expected result when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitPlainTextLinesWithCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    const string Source = "This is a test of the emergency broadcast system. This is only a test.";
    string[] expectedLines =
    {
        "This is a test of the emergency broadcast system.",
        "This is only a test."
    };

    // Act
    List<string> actualLines = TextChunker.SplitPlainTextLines(Source, 60, chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedLines, actualLines);
}
496+
497+
/// <summary>
/// Checks <c>SplitMarkdownParagraphs</c> (no overlap) against a fixed expected result when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitMarkdownParagraphsWithCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    var sourceLines = new List<string>
    {
        "This is a test of the emergency broadcast system. This is only a test.",
        "We repeat, this is only a test. A unit test."
    };
    string[] expectedParagraphs =
    {
        "This is a test of the emergency broadcast system.",
        "This is only a test.",
        "We repeat, this is only a test. A unit test."
    };

    // Act
    List<string> actualParagraphs = TextChunker.SplitMarkdownParagraphs(sourceLines, 52, tokenCounter: chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedParagraphs, actualParagraphs);
}
516+
517+
/// <summary>
/// Checks <c>SplitMarkdownParagraphs</c> with an overlap of 40 against a fixed expected result
/// when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitMarkdownParagraphsWithOverlapAndCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    var sourceLines = new List<string>
    {
        "This is a test of the emergency broadcast system. This is only a test.",
        "We repeat, this is only a test. A unit test."
    };
    string[] expectedParagraphs =
    {
        "This is a test of the emergency broadcast system.",
        "emergency broadcast system. This is only a test.",
        "This is only a test. We repeat, this is only a test.",
        "We repeat, this is only a test. A unit test.",
        "A unit test."
    };

    // Act
    List<string> actualParagraphs = TextChunker.SplitMarkdownParagraphs(
        sourceLines,
        maxTokensPerParagraph: 75,
        overlapTokens: 40,
        tokenCounter: chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedParagraphs, actualParagraphs);
}
539+
540+
/// <summary>
/// Checks <c>SplitPlainTextParagraphs</c> (no overlap) against a fixed expected result when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitTextParagraphsWithCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    var sourceLines = new List<string>
    {
        "This is a test of the emergency broadcast system. This is only a test.",
        "We repeat, this is only a test. A unit test."
    };
    string[] expectedParagraphs =
    {
        "This is a test of the emergency broadcast system.",
        "This is only a test.",
        "We repeat, this is only a test. A unit test."
    };

    // Act
    List<string> actualParagraphs = TextChunker.SplitPlainTextParagraphs(sourceLines, 52, tokenCounter: chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedParagraphs, actualParagraphs);
}
560+
561+
/// <summary>
/// Checks <c>SplitPlainTextParagraphs</c> with an overlap of 40 against a fixed expected result
/// when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitTextParagraphsWithOverlapAndCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    var sourceLines = new List<string>
    {
        "This is a test of the emergency broadcast system. This is only a test.",
        "We repeat, this is only a test. A unit test."
    };
    string[] expectedParagraphs =
    {
        "This is a test of the emergency broadcast system.",
        "emergency broadcast system. This is only a test.",
        "This is only a test. We repeat, this is only a test.",
        "We repeat, this is only a test. A unit test.",
        "A unit test."
    };

    // Act
    List<string> actualParagraphs = TextChunker.SplitPlainTextParagraphs(
        sourceLines,
        maxTokensPerParagraph: 75,
        overlapTokens: 40,
        tokenCounter: chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedParagraphs, actualParagraphs);
}
583+
584+
/// <summary>
/// Checks <c>SplitMarkDownLines</c> against a fixed expected result when given a caller-supplied token counter.
/// </summary>
[Fact]
public void CanSplitMarkDownLinesWithCustomTokenCounter()
{
    // Arrange: the counter reports one token per character.
    const string Source = "This is a test of the emergency broadcast system. This is only a test.";
    string[] expectedLines =
    {
        "This is a test of the emergency broadcast system.",
        "This is only a test."
    };

    // Act
    List<string> actualLines = TextChunker.SplitMarkDownLines(Source, 60, chunk => chunk.Length);

    // Assert
    Assert.Equal(expectedLines, actualLines);
}
481598
}

dotnet/src/SemanticKernel/Text/TextChunker.cs

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace Microsoft.SemanticKernel.Text;
1515
/// </summary>
1616
public static class TextChunker
1717
{
18+
/// <summary>
/// Delegate for a function that counts the tokens in a string.
/// </summary>
/// <param name="input">The string whose tokens are counted.</param>
/// <returns>The number of tokens in <paramref name="input"/>.</returns>
public delegate int TokenCounter(string input);
19+
1820
private static readonly char[] s_spaceChar = new[] { ' ' };
1921
private static readonly string?[] s_plaintextSplitOptions = new[] { "\n\r", ".", "?!", ";", ":", ",", ")]}", " ", "-", null };
2022
private static readonly string?[] s_markdownSplitOptions = new[] { ".", "?!", ";", ":", ",", ")]}", " ", "-", "\n\r", null };
@@ -24,21 +26,27 @@ public static class TextChunker
2426
/// </summary>
2527
/// <param name="text">Text to split</param>
2628
/// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
29+
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
2730
/// <returns>List of lines.</returns>
28-
public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine)
31+
public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
2932
{
30-
return InternalSplitLines(text, maxTokensPerLine, trim: true, s_plaintextSplitOptions);
33+
tokenCounter ??= DefaultTokenCounter;
34+
35+
return InternalSplitLines(text, maxTokensPerLine, trim: true, s_plaintextSplitOptions, tokenCounter);
3136
}
3237

3338
/// <summary>
3439
/// Split markdown text into lines.
3540
/// </summary>
3641
/// <param name="text">Text to split</param>
3742
/// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
43+
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
3844
/// <returns>List of lines.</returns>
39-
public static List<string> SplitMarkDownLines(string text, int maxTokensPerLine)
45+
public static List<string> SplitMarkDownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
4046
{
41-
return InternalSplitLines(text, maxTokensPerLine, trim: true, s_markdownSplitOptions);
47+
tokenCounter ??= DefaultTokenCounter;
48+
49+
return InternalSplitLines(text, maxTokensPerLine, trim: true, s_markdownSplitOptions, tokenCounter);
4250
}
4351

4452
/// <summary>
@@ -47,10 +55,13 @@ public static List<string> SplitMarkDownLines(string text, int maxTokensPerLine)
4755
/// <param name="lines">Lines of text.</param>
4856
/// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
4957
/// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
58+
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
5059
/// <returns>List of paragraphs.</returns>
51-
public static List<string> SplitPlainTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0)
60+
public static List<string> SplitPlainTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, TokenCounter? tokenCounter = null)
5261
{
53-
return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_plaintextSplitOptions));
62+
tokenCounter ??= DefaultTokenCounter;
63+
64+
return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_plaintextSplitOptions, tokenCounter), tokenCounter);
5465
}
5566

5667
/// <summary>
@@ -59,13 +70,16 @@ public static List<string> SplitPlainTextParagraphs(List<string> lines, int maxT
5970
/// <param name="lines">Lines of text.</param>
6071
/// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
6172
/// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
73+
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
6274
/// <returns>List of paragraphs.</returns>
63-
public static List<string> SplitMarkdownParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0)
75+
public static List<string> SplitMarkdownParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, TokenCounter? tokenCounter = null)
6476
{
65-
return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_markdownSplitOptions));
77+
tokenCounter ??= DefaultTokenCounter;
78+
79+
return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, (text, maxTokens) => InternalSplitLines(text, maxTokens, trim: false, s_markdownSplitOptions, tokenCounter), tokenCounter);
6680
}
6781

68-
private static List<string> InternalSplitTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens, Func<string, int, List<string>> longLinesSplitter)
82+
private static List<string> InternalSplitTextParagraphs(List<string> lines, int maxTokensPerParagraph, int overlapTokens, Func<string, int, List<string>> longLinesSplitter, TokenCounter tokenCounter)
6983
{
7084
if (maxTokensPerParagraph <= 0)
7185
{
@@ -87,14 +101,14 @@ private static List<string> InternalSplitTextParagraphs(List<string> lines, int
87101
// Split long lines first
88102
IEnumerable<string> truncatedLines = lines.SelectMany(line => longLinesSplitter(line, adjustedMaxTokensPerParagraph));
89103

90-
var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, longLinesSplitter);
104+
var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, longLinesSplitter, tokenCounter);
91105
// distribute text more evenly in the last paragraphs when the last paragraph is too short.
92106
if (paragraphs.Count > 1)
93107
{
94108
var lastParagraph = paragraphs[paragraphs.Count - 1];
95109
var secondLastParagraph = paragraphs[paragraphs.Count - 2];
96110

97-
if (TokenCount(lastParagraph.Length) < adjustedMaxTokensPerParagraph / 4)
111+
if (tokenCounter(lastParagraph) < adjustedMaxTokensPerParagraph / 4)
98112
{
99113
var lastParagraphTokens = lastParagraph.Split(s_spaceChar, StringSplitOptions.RemoveEmptyEntries);
100114
var secondLastParagraphTokens = secondLastParagraph.Split(s_spaceChar, StringSplitOptions.RemoveEmptyEntries);
@@ -129,14 +143,14 @@ private static List<string> InternalSplitTextParagraphs(List<string> lines, int
129143
return paragraphs;
130144
}
131145

132-
private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph, Func<string, int, List<string>> longLinesSplitter)
146+
private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph, Func<string, int, List<string>> longLinesSplitter, TokenCounter tokenCounter)
133147
{
134148
StringBuilder paragraphBuilder = new();
135149
List<string> paragraphs = new();
136150

137151
foreach (string line in truncatedLines)
138152
{
139-
if (paragraphBuilder.Length > 0 && TokenCount(paragraphBuilder.Length) + TokenCount(line.Length) + 1 >= maxTokensPerParagraph)
153+
if (paragraphBuilder.Length > 0 && tokenCounter(paragraphBuilder.ToString()) + tokenCounter(line) + 1 >= maxTokensPerParagraph)
140154
{
141155
// Complete the paragraph and prepare for the next
142156
paragraphs.Add(paragraphBuilder.ToString().Trim());
@@ -155,7 +169,7 @@ private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, i
155169
return paragraphs;
156170
}
157171

158-
private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions)
172+
private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions, TokenCounter tokenCounter)
159173
{
160174
var result = new List<string>();
161175

@@ -164,7 +178,7 @@ private static List<string> InternalSplitLines(string text, int maxTokensPerLine
164178
for (int i = 0; i < splitOptions.Length; i++)
165179
{
166180
int count = result.Count; // track where the original input left off
167-
var (splits2, inputWasSplit2) = Split(result, maxTokensPerLine, splitOptions[i].AsSpan(), trim);
181+
var (splits2, inputWasSplit2) = Split(result, maxTokensPerLine, splitOptions[i].AsSpan(), trim, tokenCounter);
168182
result.AddRange(splits2);
169183
result.RemoveRange(0, count); // remove the original input
170184
if (!inputWasSplit2)
@@ -175,26 +189,26 @@ private static List<string> InternalSplitLines(string text, int maxTokensPerLine
175189
return result;
176190
}
177191

178-
private static (List<string>, bool) Split(List<string> input, int maxTokens, ReadOnlySpan<char> separators, bool trim)
192+
private static (List<string>, bool) Split(List<string> input, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter tokenCounter)
179193
{
180194
bool inputWasSplit = false;
181195
List<string> result = new();
182196
int count = input.Count;
183197
for (int i = 0; i < count; i++)
184198
{
185-
var (splits, split) = Split(input[i].AsSpan(), input[i], maxTokens, separators, trim);
199+
var (splits, split) = Split(input[i].AsSpan(), input[i], maxTokens, separators, trim, tokenCounter);
186200
result.AddRange(splits);
187201
inputWasSplit |= split;
188202
}
189203
return (result, inputWasSplit);
190204
}
191205

192-
private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim)
206+
private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter tokenCounter)
193207
{
194208
Debug.Assert(inputString is null || input.SequenceEqual(inputString.AsSpan()));
195209
List<string> result = new();
196210
var inputWasSplit = false;
197-
if (TokenCount(input.Length) > maxTokens)
211+
if (tokenCounter(input.ToString()) > maxTokens)
198212
{
199213
inputWasSplit = true;
200214

@@ -238,9 +252,9 @@ private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inpu
238252
}
239253

240254
// Recursion
241-
var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim);
255+
var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter);
242256
result.AddRange(splits1);
243-
var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim);
257+
var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter);
244258
result.AddRange(splits2);
245259

246260
inputWasSplit = split1 || split2;
@@ -259,10 +273,8 @@ private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inpu
259273
return (result, inputWasSplit);
260274
}
261275

262-
private static int TokenCount(int inputLength)
276+
private static int DefaultTokenCounter(string input)
263277
{
264-
// TODO: partitioning methods should be configurable to allow for different tokenization strategies
265-
// depending on the model to be called. For now, we use an extremely rough estimate.
266-
return inputLength / 4;
278+
return input.Length / 4;
267279
}
268280
}

0 commit comments

Comments
 (0)