
Commit b964b52

fix: Match OpenAI Token Counting Strategy 🪙 (danny-avila#945)
* wip token fix
* fix: complete token count refactor to match OpenAI example
* chore: add back sendPayload method (accidentally deleted)
* chore: revise JSDoc for getTokenCountForMessage
1 parent 98d6481 commit b964b52

10 files changed: +109 additions, -75 deletions


app/clients/AnthropicClient.js

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// const { Agent, ProxyAgent } = require('undici');
22
const BaseClient = require('./BaseClient');
3-
const {
4-
encoding_for_model: encodingForModel,
5-
get_encoding: getEncoding,
6-
} = require('@dqbd/tiktoken');
3+
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
74
const Anthropic = require('@anthropic-ai/sdk');
85

96
const HUMAN_PROMPT = '\n\nHuman:';
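The same one-line swap recurs in ChatGPTClient.js, GoogleClient.js, and OpenAIClient.js below: the @dqbd/tiktoken package is now published on npm as plain tiktoken with the same named exports, so only the require specifier changes. A minimal sanity check of the renamed package, assuming the WASM API documented in its README:

const { encoding_for_model: encodingForModel } = require('tiktoken');

const encoder = encodingForModel('gpt-3.5-turbo');
console.log(encoder.encode('Hello, world!').length); // number of BPE tokens
encoder.free(); // WASM-backed encoders must be freed explicitly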

app/clients/BaseClient.js

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ class BaseClient {
272272
* @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
273273
*/
274274
async getMessagesWithinTokenLimit(messages) {
275-
let currentTokenCount = 0;
275+
// Every reply is primed with <|start|>assistant<|message|>, so we
276+
// start with 3 tokens for the label after all messages have been counted.
277+
let currentTokenCount = 3;
276278
let context = [];
277279
let messagesToRefine = [];
278280
let refineIndex = -1;
@@ -562,44 +564,29 @@ class BaseClient {
562564
* Algorithm adapted from "6. Counting tokens for chat API calls" of
563565
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
564566
*
565-
* An additional 2 tokens need to be added for metadata after all messages have been counted.
567+
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
566568
*
567-
* @param {*} message
569+
* @param {Object} message
568570
*/
569571
getTokenCountForMessage(message) {
570-
let tokensPerMessage;
571-
let nameAdjustment;
572-
if (this.modelOptions.model.startsWith('gpt-4')) {
573-
tokensPerMessage = 3;
574-
nameAdjustment = 1;
575-
} else {
576-
tokensPerMessage = 4;
577-
nameAdjustment = -1;
578-
}
572+
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
573+
let tokensPerMessage = 3;
574+
let tokensPerName = 1;
579575

580-
if (this.options.debug) {
581-
console.debug('getTokenCountForMessage', message);
576+
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
577+
tokensPerMessage = 4;
578+
tokensPerName = -1;
582579
}
583580

584-
// Map each property of the message to the number of tokens it contains
585-
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
586-
if (key === 'tokenCount' || typeof value !== 'string') {
587-
return 0;
581+
let numTokens = tokensPerMessage;
582+
for (let [key, value] of Object.entries(message)) {
583+
numTokens += this.getTokenCount(value);
584+
if (key === 'name') {
585+
numTokens += tokensPerName;
588586
}
589-
// Count the number of tokens in the property value
590-
const numTokens = this.getTokenCount(value);
591-
592-
// Adjust by `nameAdjustment` tokens if the property key is 'name'
593-
const adjustment = key === 'name' ? nameAdjustment : 0;
594-
return numTokens + adjustment;
595-
});
596-
597-
if (this.options.debug) {
598-
console.debug('propertyTokenCounts', propertyTokenCounts);
599587
}
600588

601-
// Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
602-
return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
589+
return numTokens;
603590
}
604591

605592
async sendPayload(payload, opts = {}) {
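The new getTokenCountForMessage mirrors the cookbook's num_tokens_from_messages line for line, with the final 3 priming tokens moved into getMessagesWithinTokenLimit. As a self-contained sketch of the whole strategy (the function name and the direct tiktoken usage are illustrative, not part of this codebase):

const { encoding_for_model: encodingForModel } = require('tiktoken');

function numTokensFromMessages(messages, model = 'gpt-3.5-turbo') {
  const encoder = encodingForModel(model);
  // gpt-3.5-turbo-0301 charges 4 tokens per message and -1 per name;
  // later snapshots and unknown models use the 3 / +1 defaults.
  const isLegacy = model === 'gpt-3.5-turbo-0301';
  const tokensPerMessage = isLegacy ? 4 : 3;
  const tokensPerName = isLegacy ? -1 : 1;

  let numTokens = 0;
  for (const message of messages) {
    numTokens += tokensPerMessage;
    for (const [key, value] of Object.entries(message)) {
      numTokens += encoder.encode(value).length;
      if (key === 'name') {
        numTokens += tokensPerName;
      }
    }
  }
  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
  encoder.free();
  return numTokens;
}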

app/clients/ChatGPTClient.js

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
const crypto = require('crypto');
22
const Keyv = require('keyv');
3-
const {
4-
encoding_for_model: encodingForModel,
5-
get_encoding: getEncoding,
6-
} = require('@dqbd/tiktoken');
3+
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
74
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
85
const { Agent, ProxyAgent } = require('undici');
96
const BaseClient = require('./BaseClient');
@@ -526,8 +523,8 @@ ${botMessage.message}
526523
const prompt = `${promptBody}${promptSuffix}`;
527524
if (isChatGptModel) {
528525
messagePayload.content = prompt;
529-
// Add 2 tokens for metadata after all messages have been counted.
530-
currentTokenCount += 2;
526+
// Add 3 tokens for Assistant Label priming after all messages have been counted.
527+
currentTokenCount += 3;
531528
}
532529

533530
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
@@ -554,33 +551,29 @@ ${botMessage.message}
554551
* Algorithm adapted from "6. Counting tokens for chat API calls" of
555552
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
556553
*
557-
* An additional 2 tokens need to be added for metadata after all messages have been counted.
554+
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
558555
*
559-
* @param {*} message
556+
* @param {Object} message
560557
*/
561558
getTokenCountForMessage(message) {
562-
let tokensPerMessage;
563-
let nameAdjustment;
564-
if (this.modelOptions.model.startsWith('gpt-4')) {
565-
tokensPerMessage = 3;
566-
nameAdjustment = 1;
567-
} else {
559+
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
560+
let tokensPerMessage = 3;
561+
let tokensPerName = 1;
562+
563+
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
568564
tokensPerMessage = 4;
569-
nameAdjustment = -1;
565+
tokensPerName = -1;
570566
}
571567

572-
// Map each property of the message to the number of tokens it contains
573-
const propertyTokenCounts = Object.entries(message).map(([key, value]) => {
574-
// Count the number of tokens in the property value
575-
const numTokens = this.getTokenCount(value);
576-
577-
// Adjust by `nameAdjustment` tokens if the property key is 'name'
578-
const adjustment = key === 'name' ? nameAdjustment : 0;
579-
return numTokens + adjustment;
580-
});
568+
let numTokens = tokensPerMessage;
569+
for (let [key, value] of Object.entries(message)) {
570+
numTokens += this.getTokenCount(value);
571+
if (key === 'name') {
572+
numTokens += tokensPerName;
573+
}
574+
}
581575

582-
// Sum the number of tokens in all properties and add `tokensPerMessage` for metadata
583-
return propertyTokenCounts.reduce((a, b) => a + b, tokensPerMessage);
576+
return numTokens;
584577
}
585578
}
586579

app/clients/GoogleClient.js

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
const BaseClient = require('./BaseClient');
22
const { google } = require('googleapis');
33
const { Agent, ProxyAgent } = require('undici');
4-
const {
5-
encoding_for_model: encodingForModel,
6-
get_encoding: getEncoding,
7-
} = require('@dqbd/tiktoken');
4+
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
85

96
const tokenizersCache = {};
107

app/clients/OpenAIClient.js

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
const BaseClient = require('./BaseClient');
22
const ChatGPTClient = require('./ChatGPTClient');
3-
const {
4-
encoding_for_model: encodingForModel,
5-
get_encoding: getEncoding,
6-
} = require('@dqbd/tiktoken');
3+
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
74
const { maxTokensMap, genAzureChatCompletion } = require('../../utils');
85
const { runTitleChain } = require('./chains');
96
const { createLLM } = require('./llm');

app/clients/specs/BaseClient.test.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ describe('BaseClient', () => {
138138
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
139139
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
140140
];
141-
const expectedRemainingContextTokens = 58; // 100 - 5 - 19 - 18
141+
// Subtract 3 tokens for Assistant Label priming after all messages have been counted.
142+
const expectedRemainingContextTokens = 58 - 3; // (100 - 5 - 19 - 18) - 3
142143
const expectedMessagesToRefine = [];
143144

144145
const result = await TestClient.getMessagesWithinTokenLimit(messages);
@@ -168,7 +169,9 @@ describe('BaseClient', () => {
168169
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
169170
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
170171
];
171-
const expectedRemainingContextTokens = 8; // 50 - 18 - 19 - 5
172+
173+
// Subtract 3 tokens for Assistant Label priming after all messages have been counted.
174+
const expectedRemainingContextTokens = 8 - 3; // (50 - 18 - 19 - 5) - 3
172175
const expectedMessagesToRefine = [
173176
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 30 },
174177
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 30 },

app/clients/specs/OpenAIClient.test.js

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,63 @@ describe('OpenAIClient', () => {
213213
expect(result.prompt).toEqual([]);
214214
});
215215
});
216+
217+
describe('getTokenCountForMessage', () => {
218+
const example_messages = [
219+
{
220+
role: 'system',
221+
content:
222+
'You are a helpful, pattern-following assistant that translates corporate jargon into plain English.',
223+
},
224+
{
225+
role: 'system',
226+
name: 'example_user',
227+
content: 'New synergies will help drive top-line growth.',
228+
},
229+
{
230+
role: 'system',
231+
name: 'example_assistant',
232+
content: 'Things working well together will increase revenue.',
233+
},
234+
{
235+
role: 'system',
236+
name: 'example_user',
237+
content:
238+
'Let\'s circle back when we have more bandwidth to touch base on opportunities for increased leverage.',
239+
},
240+
{
241+
role: 'system',
242+
name: 'example_assistant',
243+
content: 'Let\'s talk later when we\'re less busy about how to do better.',
244+
},
245+
{
246+
role: 'user',
247+
content:
248+
'This late pivot means we don\'t have time to boil the ocean for the client deliverable.',
249+
},
250+
];
251+
252+
const testCases = [
253+
{ model: 'gpt-3.5-turbo-0301', expected: 127 },
254+
{ model: 'gpt-3.5-turbo-0613', expected: 129 },
255+
{ model: 'gpt-3.5-turbo', expected: 129 },
256+
{ model: 'gpt-4-0314', expected: 129 },
257+
{ model: 'gpt-4-0613', expected: 129 },
258+
{ model: 'gpt-4', expected: 129 },
259+
{ model: 'unknown', expected: 129 },
260+
];
261+
262+
testCases.forEach((testCase) => {
263+
it(`should return ${testCase.expected} tokens for model ${testCase.model}`, () => {
264+
client.modelOptions.model = testCase.model;
265+
client.selectTokenizer();
266+
// 3 tokens for assistant label
267+
let totalTokens = 3;
268+
for (let message of example_messages) {
269+
totalTokens += client.getTokenCountForMessage(message);
270+
}
271+
expect(totalTokens).toBe(testCase.expected);
272+
});
273+
});
274+
});
216275
});
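These expectations follow from the counting rules above: under the default rates the six messages cost 6 x 3 = 18 fixed tokens plus 4 x (+1) = 4 for the four name fields, so 129 = 104 content tokens + 18 + 4 + 3 priming tokens. Under gpt-3.5-turbo-0301 the fixed cost is 6 x 4 - 4 = 20, two tokens less, which yields 127.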

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
"dependencies": {
2424
"@anthropic-ai/sdk": "^0.5.4",
2525
"@azure/search-documents": "^11.3.2",
26-
"@dqbd/tiktoken": "^1.0.7",
2726
"@keyv/mongo": "^2.1.8",
2827
"@waylaidwanderer/chatgpt-api": "^1.37.2",
2928
"axios": "^1.3.4",
@@ -60,6 +59,7 @@
6059
"passport-local": "^1.0.0",
6160
"pino": "^8.12.1",
6261
"sharp": "^0.32.5",
62+
"tiktoken": "^1.0.10",
6363
"ua-parser-js": "^1.0.36",
6464
"zod": "^3.22.2"
6565
},

server/routes/tokenizer.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
const express = require('express');
22
const router = express.Router();
3-
const { Tiktoken } = require('@dqbd/tiktoken/lite');
4-
const { load } = require('@dqbd/tiktoken/load');
5-
const registry = require('@dqbd/tiktoken/registry.json');
6-
const models = require('@dqbd/tiktoken/model_to_encoding.json');
3+
const { Tiktoken } = require('tiktoken/lite');
4+
const { load } = require('tiktoken/load');
5+
const registry = require('tiktoken/registry.json');
6+
const models = require('tiktoken/model_to_encoding.json');
77
const requireJwtAuth = require('../middleware/requireJwtAuth');
88

99
router.post('/', requireJwtAuth, async (req, res) => {
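The route body is unchanged; only the require specifiers move. For reference, the lite build is typically wired up as in this sketch (following the tiktoken README; countTokens is an illustrative helper, not the actual route code):

const { Tiktoken } = require('tiktoken/lite');
const { load } = require('tiktoken/load');
const registry = require('tiktoken/registry.json');
const models = require('tiktoken/model_to_encoding.json');

async function countTokens(text, modelName = 'gpt-3.5-turbo') {
  // Resolve the model to its encoding (e.g. cl100k_base) and load the BPE ranks.
  const spec = await load(registry[models[modelName]]);
  const encoder = new Tiktoken(spec.bpe_ranks, spec.special_tokens, spec.pat_str);
  const count = encoder.encode(text).length;
  encoder.free(); // lite encoders are WASM-backed and must be freed
  return count;
}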

utils/tokens.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ const maxTokensMap = {
4141
'gpt-4': 8191,
4242
'gpt-4-0613': 8191,
4343
'gpt-4-32k': 32767,
44+
'gpt-4-32k-0314': 32767,
4445
'gpt-4-32k-0613': 32767,
4546
'gpt-3.5-turbo': 4095,
4647
'gpt-3.5-turbo-0613': 4095,
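Without this entry, a gpt-4-32k-0314 lookup would miss and fall back to the caller's default context size. A hypothetical illustration of the lookup (the require path and fallback value here are assumed, not taken from this diff):

const { maxTokensMap } = require('../../utils');

const maxContextTokens = maxTokensMap['gpt-4-32k-0314'] ?? 4095; // now 32767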
