Skip to content

Commit 51dac37

Browse files
committed
feat: implement token counting from byte slices and add tokens module
1 parent e7bc4e1 commit 51dac37

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

src/language/language_type.rs

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ impl LanguageType {
108108
// first character in the column, so removing starting whitespace
109109
// could cause a miscount.
110110
let line = if is_fortran { line } else { line.trim() };
111-
let tokens = Self::count_tokens(&String::from_utf8_lossy(line));
111+
let tokens = crate::tokens::count_tokens_from_bytes(line);
112112
if line.trim().is_empty() {
113113
(1, 0, 0, tokens)
114114
} else if is_literate
@@ -134,11 +134,6 @@ impl LanguageType {
134134
}
135135
}
136136

137-
fn count_tokens(text: &str) -> usize {
138-
let bpe = tiktoken_rs::p50k_base().unwrap();
139-
bpe.encode_with_special_tokens(text).len()
140-
}
141-
142137
#[inline]
143138
fn parse_lines(
144139
self,
@@ -218,7 +213,7 @@ impl LanguageType {
218213
}
219214
}
220215

221-
let tokens = Self::count_tokens(&String::from_utf8_lossy(lines));
216+
let tokens = crate::tokens::count_tokens_from_bytes(lines);
222217
stats.tokens += tokens;
223218

224219
stats

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ mod consts;
5454
mod language;
5555
mod sort;
5656
mod stats;
57+
mod tokens;
5758

5859
pub use self::{
5960
config::Config,

src/tokens.rs

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
use once_cell::sync::Lazy;
2+
use tiktoken_rs::CoreBPE;
3+
4+
static TOKENIZER: Lazy<CoreBPE> = Lazy::new(|| tiktoken_rs::p50k_base().unwrap());
5+
6+
pub fn count_tokens(text: &str) -> usize {
7+
TOKENIZER.encode_with_special_tokens(text).len()
8+
}
9+
10+
pub fn count_tokens_from_bytes(bytes: &[u8]) -> usize {
11+
match std::str::from_utf8(bytes) {
12+
Ok(text) => count_tokens(text),
13+
Err(_) => {
14+
let text = String::from_utf8_lossy(bytes);
15+
count_tokens(&text)
16+
}
17+
}
18+
}

0 commit comments

Comments
 (0)