Commit eedc856

Partial sync of codebase
1 parent 5818d56 commit eedc856

4 files changed: +116 −28 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ This is the changelog for the open source version of tiktoken.
 - Update version of `pyo3`
 - Use new Rust edition
 - Fix special token handling in `encode_to_numpy`
+- Better error handling
 - Improvements to private APIs
 
 ## [v0.10.0]

scripts/wheel_download.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+import argparse
+import zipfile
+from pathlib import Path
+
+import requests
+
+
+def download_artifacts(token, owner, repo, run_id, output_dir):
+    headers = {"Authorization": f"token {token}", "Accept": "application/vnd.github.v3+json"}
+
+    # Get list of artifacts
+    artifacts_url = f"https://api.github.com/repos/{owner}/{repo}/actions/runs/{run_id}/artifacts"
+    response = requests.get(artifacts_url, headers=headers)
+    response.raise_for_status()
+    artifacts = response.json()["artifacts"]
+
+    if not artifacts:
+        print(f"No artifacts found for run ID: {run_id}")
+        return
+
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Found {len(artifacts)} artifacts")
+    for artifact in artifacts:
+        name = artifact["name"]
+        download_url = artifact["archive_download_url"]
+
+        print(f"Downloading {name}...")
+
+        response = requests.get(download_url, headers=headers, stream=True)
+        response.raise_for_status()
+
+        temp_zip = output_dir / f"{name}.zip"
+        with open(temp_zip, "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        with zipfile.ZipFile(temp_zip, "r") as zip_ref:
+            zip_ref.extractall(output_dir)
+        temp_zip.unlink()
+        print(f"Downloaded and extracted {name}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download artifacts from a GitHub Actions run")
+    parser.add_argument("--token", required=True, help="GitHub Personal Access Token")
+    parser.add_argument("--owner", required=True, help="Repository owner")
+    parser.add_argument("--repo", required=True, help="Repository name")
+    parser.add_argument("--run-id", required=True, help="Workflow run ID")
+    parser.add_argument(
+        "--output-dir", default="artifacts", help="Output directory for downloaded artifacts"
+    )
+
+    args = parser.parse_args()
+
+    download_artifacts(args.token, args.owner, args.repo, args.run_id, args.output_dir)
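
Note: besides the argparse entry point, the helper can be driven directly from Python. A minimal sketch, assuming the script is importable as `wheel_download`; the token, repository coordinates, and run ID below are placeholders, not real values:

    # Hypothetical direct invocation of the new helper; every argument here is
    # a placeholder, and the token must be a real PAT with access to the repo.
    from wheel_download import download_artifacts

    download_artifacts(
        token="<github-pat>",    # GitHub Personal Access Token (placeholder)
        owner="openai",          # repository owner (example value)
        repo="tiktoken",         # repository name (example value)
        run_id="1234567890",     # workflow run whose artifacts to fetch (placeholder)
        output_dir="artifacts",  # extraction directory, matching the CLI default
    )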

src/lib.rs

Lines changed: 33 additions & 7 deletions

@@ -172,6 +172,19 @@ impl std::fmt::Display for DecodeError {
 
 impl std::error::Error for DecodeError {}
 
+#[derive(Debug, Clone)]
+pub struct EncodeError {
+    pub message: String,
+}
+
+impl std::fmt::Display for EncodeError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "Could not encode string: {}", self.message)
+    }
+}
+
+impl std::error::Error for EncodeError {}
+
 const MAX_NUM_THREADS: usize = 128;
 
 #[cfg_attr(feature = "python", pyclass)]

@@ -231,7 +244,11 @@ impl CoreBPE {
         ret
     }
 
-    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
+    pub fn encode(
+        &self,
+        text: &str,
+        allowed_special: &HashSet<&str>,
+    ) -> Result<(Vec<Rank>, usize), EncodeError> {
         let special_regex = self._get_tl_special_regex();
         let regex = self._get_tl_regex();
         let mut ret = vec![];

@@ -257,8 +274,17 @@ impl CoreBPE {
             let end = next_special.map_or(text.len(), |m| m.start());
 
             // Okay, here we go, compare this logic to encode_ordinary
-            for mat in regex.find_iter(&text[start..end]) {
-                let piece = mat.unwrap().as_str().as_bytes();
+            for mat_res in regex.find_iter(&text[start..end]) {
+                let mat = match mat_res {
+                    Ok(m) => m,
+                    Err(e) => {
+                        return Err(EncodeError {
+                            message: format!("Regex error while tokenizing: {e}"),
+                        });
+                    }
+                };
+
+                let piece = mat.as_str().as_bytes();
                 if let Some(token) = self.encoder.get(piece) {
                     last_piece_token_len = 1;
                     ret.push(*token);

@@ -284,7 +310,7 @@
 
         // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
-        (ret, last_piece_token_len)
+        Ok((ret, last_piece_token_len))
     }
 
     fn _increase_last_piece_token_len(

@@ -331,7 +357,7 @@
         text: &str,
         allowed_special: &HashSet<&str>,
     ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
-        let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
+        let (tokens, last_piece_token_len) = self.encode(text, allowed_special).unwrap();
         if last_piece_token_len == 0 {
             // If last_piece_token_len is zero, the last token was a special token and we have
             // no unstable bytes

@@ -427,7 +453,7 @@
         if unstable_bytes.len() > 1 {
             let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
             if unstable_bytes.len() - last_decoded.1 > 0
-                && last_decoded.0.map_or(false, |c| c.is_whitespace())
+                && last_decoded.0.is_some_and(|c| c.is_whitespace())
             {
                 let mut reencoded = byte_pair_encode(
                     &unstable_bytes[..unstable_bytes.len() - last_decoded.1],

@@ -517,7 +543,7 @@
 
     pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
         let allowed_special = self.special_tokens();
-        self.encode(text, &allowed_special).0
+        self.encode(text, &allowed_special).unwrap().0
     }
 }

src/py.rs

Lines changed: 26 additions & 21 deletions

@@ -1,10 +1,10 @@
 use std::collections::HashSet;
 
 use pyo3::{
-    PyResult, exceptions,
+    IntoPyObjectExt, PyResult, exceptions,
     prelude::*,
     pybacked::PyBackedStr,
-    types::{PyBytes, PyList, PyTuple},
+    types::{PyBytes, PyList},
 };
 use rustc_hash::FxHashMap as HashMap;
 

@@ -37,11 +37,14 @@ impl CoreBPE {
         py: Python,
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
-    ) -> Vec<Rank> {
+    ) -> PyResult<Vec<Rank>> {
         py.allow_threads(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
-            self.encode(text, &allowed_special).0
+            match self.encode(text, &allowed_special) {
+                Ok((tokens, _)) => Ok(tokens),
+                Err(e) => Err(PyErr::new::<exceptions::PyValueError, _>(e.message)),
+            }
         })
     }
 

@@ -50,14 +53,20 @@
         py: Python,
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
-    ) -> Py<PyAny> {
-        let tokens = py.allow_threads(|| {
+    ) -> PyResult<Py<PyAny>> {
+        let tokens_res = py.allow_threads(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
-            self.encode(text, &allowed_special).0
+            self.encode(text, &allowed_special)
         });
+
+        let tokens = match tokens_res {
+            Ok((tokens, _)) => tokens,
+            Err(e) => return Err(PyErr::new::<exceptions::PyValueError, _>(e.message)),
+        };
+
         let buffer = TiktokenBuffer { tokens };
-        buffer.into_py(py)
+        buffer.into_py_any(py)
     }
 
     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {

@@ -69,7 +78,8 @@
                 // Unicode space, so we make our best guess at where we would have splits
                 Err(e) => {
                     let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
-                    let (tokens, last_piece_token_len) = self.encode(text, &HashSet::new());
+                    let (tokens, last_piece_token_len) =
+                        self.encode(text, &HashSet::new()).unwrap();
                     let (mut tokens, last_piece_token_len) =
                         self._increase_last_piece_token_len(tokens, last_piece_token_len);
 

@@ -110,19 +120,14 @@
         py: Python,
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
-    ) -> Py<PyTuple> {
-        let (tokens, completions) = py.allow_threads(|| {
+    ) -> PyResult<(Vec<Rank>, Py<PyList>)> {
+        let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             self._encode_unstable_native(text, &allowed_special)
         });
-        let py_completions = PyList::new_bound(
-            py,
-            completions
-                .iter()
-                .map(|seq| PyList::new_bound(py, &seq[..])),
-        );
-        (tokens, py_completions).into_py(py)
+        let py_completions = PyList::new(py, completions.into_iter())?;
+        Ok((tokens, py_completions.into()))
     }
 
     fn encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {

@@ -151,17 +156,17 @@
     #[pyo3(name = "decode_bytes")]
     fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
         match py.allow_threads(|| self.decode_bytes(&tokens)) {
-            Ok(bytes) => Ok(PyBytes::new_bound(py, &bytes).into()),
+            Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()),
             Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
         }
     }
 
     fn decode_single_token_bytes(&self, py: Python, token: Rank) -> PyResult<Py<PyBytes>> {
         if let Some(bytes) = self.decoder.get(&token) {
-            return Ok(PyBytes::new_bound(py, bytes).into());
+            return Ok(PyBytes::new(py, bytes).into());
         }
         if let Some(bytes) = self.special_tokens_decoder.get(&token) {
-            return Ok(PyBytes::new_bound(py, bytes).into());
+            return Ok(PyBytes::new(py, bytes).into());
         }
         Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
     }
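
Net effect of the src/py.rs changes: an internal encoding failure now surfaces in Python as a ValueError rather than aborting in Rust. A minimal sketch of the caller-visible behavior, assuming the public tiktoken API; the input string here is a hypothetical stand-in, since a concrete input that trips the regex error path is not shown in this commit:

    # Sketch of the Python-visible error path added by this commit.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    text = "..."  # hypothetical input that makes the internal regex engine error
    try:
        tokens = enc.encode(text, allowed_special="all")
    except ValueError as err:
        # The Rust EncodeError is mapped to PyValueError in py_encode
        print(f"encode failed: {err}")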
