@@ -15,6 +15,8 @@ namespace Microsoft.SemanticKernel.Text;
1515/// </summary>
1616public static class TextChunker
1717{
18+ public delegate int TokenCounter ( string input ) ;
19+
1820 private static readonly char [ ] s_spaceChar = new [ ] { ' ' } ;
1921 private static readonly string ? [ ] s_plaintextSplitOptions = new [ ] { "\n \r " , "." , "?!" , ";" , ":" , "," , ")]}" , " " , "-" , null } ;
2022 private static readonly string ? [ ] s_markdownSplitOptions = new [ ] { "." , "?!" , ";" , ":" , "," , ")]}" , " " , "-" , "\n \r " , null } ;
@@ -24,21 +26,27 @@ public static class TextChunker
2426 /// </summary>
2527 /// <param name="text">Text to split</param>
2628 /// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
29+ /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
2730 /// <returns>List of lines.</returns>
28- public static List < string > SplitPlainTextLines ( string text , int maxTokensPerLine )
31+ public static List < string > SplitPlainTextLines ( string text , int maxTokensPerLine , TokenCounter ? tokenCounter = null )
2932 {
30- return InternalSplitLines ( text , maxTokensPerLine , trim : true , s_plaintextSplitOptions ) ;
33+ tokenCounter ??= DefaultTokenCounter ;
34+
35+ return InternalSplitLines ( text , maxTokensPerLine , trim : true , s_plaintextSplitOptions , tokenCounter ) ;
3136 }
3237
3338 /// <summary>
3439 /// Split markdown text into lines.
3540 /// </summary>
3641 /// <param name="text">Text to split</param>
3742 /// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
43+ /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
3844 /// <returns>List of lines.</returns>
39- public static List < string > SplitMarkDownLines ( string text , int maxTokensPerLine )
45+ public static List < string > SplitMarkDownLines ( string text , int maxTokensPerLine , TokenCounter ? tokenCounter = null )
4046 {
41- return InternalSplitLines ( text , maxTokensPerLine , trim : true , s_markdownSplitOptions ) ;
47+ tokenCounter ??= DefaultTokenCounter ;
48+
49+ return InternalSplitLines ( text , maxTokensPerLine , trim : true , s_markdownSplitOptions , tokenCounter ) ;
4250 }
4351
4452 /// <summary>
@@ -47,10 +55,13 @@ public static List<string> SplitMarkDownLines(string text, int maxTokensPerLine)
4755 /// <param name="lines">Lines of text.</param>
4856 /// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
4957 /// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
58+ /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
5059 /// <returns>List of paragraphs.</returns>
51- public static List < string > SplitPlainTextParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens = 0 )
60+ public static List < string > SplitPlainTextParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens = 0 , TokenCounter ? tokenCounter = null )
5261 {
53- return InternalSplitTextParagraphs ( lines , maxTokensPerParagraph , overlapTokens , ( text , maxTokens ) => InternalSplitLines ( text , maxTokens , trim : false , s_plaintextSplitOptions ) ) ;
62+ tokenCounter ??= DefaultTokenCounter ;
63+
64+ return InternalSplitTextParagraphs ( lines , maxTokensPerParagraph , overlapTokens , ( text , maxTokens ) => InternalSplitLines ( text , maxTokens , trim : false , s_plaintextSplitOptions , tokenCounter ) , tokenCounter ) ;
5465 }
5566
5667 /// <summary>
@@ -59,13 +70,16 @@ public static List<string> SplitPlainTextParagraphs(List<string> lines, int maxT
5970 /// <param name="lines">Lines of text.</param>
6071 /// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
6172 /// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
73+ /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
6274 /// <returns>List of paragraphs.</returns>
63- public static List < string > SplitMarkdownParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens = 0 )
75+ public static List < string > SplitMarkdownParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens = 0 , TokenCounter ? tokenCounter = null )
6476 {
65- return InternalSplitTextParagraphs ( lines , maxTokensPerParagraph , overlapTokens , ( text , maxTokens ) => InternalSplitLines ( text , maxTokens , trim : false , s_markdownSplitOptions ) ) ;
77+ tokenCounter ??= DefaultTokenCounter ;
78+
79+ return InternalSplitTextParagraphs ( lines , maxTokensPerParagraph , overlapTokens , ( text , maxTokens ) => InternalSplitLines ( text , maxTokens , trim : false , s_markdownSplitOptions , tokenCounter ) , tokenCounter ) ;
6680 }
6781
68- private static List < string > InternalSplitTextParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens , Func < string , int , List < string > > longLinesSplitter )
82+ private static List < string > InternalSplitTextParagraphs ( List < string > lines , int maxTokensPerParagraph , int overlapTokens , Func < string , int , List < string > > longLinesSplitter , TokenCounter tokenCounter )
6983 {
7084 if ( maxTokensPerParagraph <= 0 )
7185 {
@@ -87,14 +101,14 @@ private static List<string> InternalSplitTextParagraphs(List<string> lines, int
87101 // Split long lines first
88102 IEnumerable < string > truncatedLines = lines . SelectMany ( line => longLinesSplitter ( line , adjustedMaxTokensPerParagraph ) ) ;
89103
90- var paragraphs = BuildParagraph ( truncatedLines , adjustedMaxTokensPerParagraph , longLinesSplitter ) ;
104+ var paragraphs = BuildParagraph ( truncatedLines , adjustedMaxTokensPerParagraph , longLinesSplitter , tokenCounter ) ;
91105 // distribute text more evenly in the last paragraphs when the last paragraph is too short.
92106 if ( paragraphs . Count > 1 )
93107 {
94108 var lastParagraph = paragraphs [ paragraphs . Count - 1 ] ;
95109 var secondLastParagraph = paragraphs [ paragraphs . Count - 2 ] ;
96110
97- if ( TokenCount ( lastParagraph . Length ) < adjustedMaxTokensPerParagraph / 4 )
111+ if ( tokenCounter ( lastParagraph ) < adjustedMaxTokensPerParagraph / 4 )
98112 {
99113 var lastParagraphTokens = lastParagraph . Split ( s_spaceChar , StringSplitOptions . RemoveEmptyEntries ) ;
100114 var secondLastParagraphTokens = secondLastParagraph . Split ( s_spaceChar , StringSplitOptions . RemoveEmptyEntries ) ;
@@ -129,14 +143,14 @@ private static List<string> InternalSplitTextParagraphs(List<string> lines, int
129143 return paragraphs ;
130144 }
131145
132- private static List < string > BuildParagraph ( IEnumerable < string > truncatedLines , int maxTokensPerParagraph , Func < string , int , List < string > > longLinesSplitter )
146+ private static List < string > BuildParagraph ( IEnumerable < string > truncatedLines , int maxTokensPerParagraph , Func < string , int , List < string > > longLinesSplitter , TokenCounter tokenCounter )
133147 {
134148 StringBuilder paragraphBuilder = new ( ) ;
135149 List < string > paragraphs = new ( ) ;
136150
137151 foreach ( string line in truncatedLines )
138152 {
139- if ( paragraphBuilder . Length > 0 && TokenCount ( paragraphBuilder . Length ) + TokenCount ( line . Length ) + 1 >= maxTokensPerParagraph )
153+ if ( paragraphBuilder . Length > 0 && tokenCounter ( paragraphBuilder . ToString ( ) ) + tokenCounter ( line ) + 1 >= maxTokensPerParagraph )
140154 {
141155 // Complete the paragraph and prepare for the next
142156 paragraphs . Add ( paragraphBuilder . ToString ( ) . Trim ( ) ) ;
@@ -155,7 +169,7 @@ private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, i
155169 return paragraphs ;
156170 }
157171
158- private static List < string > InternalSplitLines ( string text , int maxTokensPerLine , bool trim , string ? [ ] splitOptions )
172+ private static List < string > InternalSplitLines ( string text , int maxTokensPerLine , bool trim , string ? [ ] splitOptions , TokenCounter tokenCounter )
159173 {
160174 var result = new List < string > ( ) ;
161175
@@ -164,7 +178,7 @@ private static List<string> InternalSplitLines(string text, int maxTokensPerLine
164178 for ( int i = 0 ; i < splitOptions . Length ; i ++ )
165179 {
166180 int count = result . Count ; // track where the original input left off
167- var ( splits2 , inputWasSplit2 ) = Split ( result , maxTokensPerLine , splitOptions [ i ] . AsSpan ( ) , trim ) ;
181+ var ( splits2 , inputWasSplit2 ) = Split ( result , maxTokensPerLine , splitOptions [ i ] . AsSpan ( ) , trim , tokenCounter ) ;
168182 result . AddRange ( splits2 ) ;
169183 result . RemoveRange ( 0 , count ) ; // remove the original input
170184 if ( ! inputWasSplit2 )
@@ -175,26 +189,26 @@ private static List<string> InternalSplitLines(string text, int maxTokensPerLine
175189 return result ;
176190 }
177191
178- private static ( List < string > , bool ) Split ( List < string > input , int maxTokens , ReadOnlySpan < char > separators , bool trim )
192+ private static ( List < string > , bool ) Split ( List < string > input , int maxTokens , ReadOnlySpan < char > separators , bool trim , TokenCounter tokenCounter )
179193 {
180194 bool inputWasSplit = false ;
181195 List < string > result = new ( ) ;
182196 int count = input . Count ;
183197 for ( int i = 0 ; i < count ; i ++ )
184198 {
185- var ( splits , split ) = Split ( input [ i ] . AsSpan ( ) , input [ i ] , maxTokens , separators , trim ) ;
199+ var ( splits , split ) = Split ( input [ i ] . AsSpan ( ) , input [ i ] , maxTokens , separators , trim , tokenCounter ) ;
186200 result . AddRange ( splits ) ;
187201 inputWasSplit |= split ;
188202 }
189203 return ( result , inputWasSplit ) ;
190204 }
191205
192- private static ( List < string > , bool ) Split ( ReadOnlySpan < char > input , string ? inputString , int maxTokens , ReadOnlySpan < char > separators , bool trim )
206+ private static ( List < string > , bool ) Split ( ReadOnlySpan < char > input , string ? inputString , int maxTokens , ReadOnlySpan < char > separators , bool trim , TokenCounter tokenCounter )
193207 {
194208 Debug . Assert ( inputString is null || input . SequenceEqual ( inputString . AsSpan ( ) ) ) ;
195209 List < string > result = new ( ) ;
196210 var inputWasSplit = false ;
197- if ( TokenCount ( input . Length ) > maxTokens )
211+ if ( tokenCounter ( input . ToString ( ) ) > maxTokens )
198212 {
199213 inputWasSplit = true ;
200214
@@ -238,9 +252,9 @@ private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inpu
238252 }
239253
240254 // Recursion
241- var ( splits1 , split1 ) = Split ( firstHalf , null , maxTokens , separators , trim ) ;
255+ var ( splits1 , split1 ) = Split ( firstHalf , null , maxTokens , separators , trim , tokenCounter ) ;
242256 result . AddRange ( splits1 ) ;
243- var ( splits2 , split2 ) = Split ( secondHalf , null , maxTokens , separators , trim ) ;
257+ var ( splits2 , split2 ) = Split ( secondHalf , null , maxTokens , separators , trim , tokenCounter ) ;
244258 result . AddRange ( splits2 ) ;
245259
246260 inputWasSplit = split1 || split2 ;
@@ -259,10 +273,8 @@ private static (List<string>, bool) Split(ReadOnlySpan<char> input, string? inpu
259273 return ( result , inputWasSplit ) ;
260274 }
261275
262- private static int TokenCount ( int inputLength )
276+ private static int DefaultTokenCounter ( string input )
263277 {
264- // TODO: partitioning methods should be configurable to allow for different tokenization strategies
265- // depending on the model to be called. For now, we use an extremely rough estimate.
266- return inputLength / 4 ;
278+ return input . Length / 4 ;
267279 }
268280}
0 commit comments