
Commit 9485498

Partial sync of codebase (#451)
1 parent 00ff187 commit 9485498


5 files changed: 38 additions & 28 deletions


Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -14,13 +14,13 @@ python = [
 ]

 [dependencies]
-pyo3 = { version = "0.26", default-features = false, features = [
+pyo3 = { version = "0.26.0", default-features = false, features = [
   "extension-module",
   "macros",
 ], optional = true }

 # tiktoken dependencies
-fancy-regex = "0.16"
+fancy-regex = "0.13.0"
 regex = "1.10.3"
 rustc-hash = "2"
 bstr = "1.5.0"

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ skip = [
     "*-manylinux_i686",
     "*-musllinux_i686",
     "*-win32",
-    "*-musllinux_aarch64",
 ]
 macos.archs = ["x86_64", "arm64"]
 # When cross-compiling on Intel, it is not possible to test arm64 wheels.

src/py.rs

Lines changed: 6 additions & 6 deletions
@@ -28,7 +28,7 @@ impl CoreBPE {

     #[pyo3(name = "encode_ordinary")]
     fn py_encode_ordinary(&self, py: Python, text: &str) -> Vec<Rank> {
-        py.allow_threads(|| self.encode_ordinary(text))
+        py.detach(|| self.encode_ordinary(text))
     }

     #[pyo3(name = "encode")]
@@ -38,7 +38,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<Vec<Rank>> {
-        py.allow_threads(|| {
+        py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             match self.encode(text, &allowed_special) {
@@ -54,7 +54,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<Py<PyAny>> {
-        let tokens_res = py.allow_threads(|| {
+        let tokens_res = py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             self.encode(text, &allowed_special)
@@ -70,7 +70,7 @@ impl CoreBPE {
     }

     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<Rank> {
-        py.allow_threads(|| {
+        py.detach(|| {
             match std::str::from_utf8(bytes) {
                 // Straightforward case
                 Ok(text) => self.encode_ordinary(text),
@@ -121,7 +121,7 @@ impl CoreBPE {
         text: &str,
         allowed_special: HashSet<PyBackedStr>,
     ) -> PyResult<(Vec<Rank>, Py<PyList>)> {
-        let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.allow_threads(|| {
+        let (tokens, completions): (Vec<Rank>, HashSet<Vec<Rank>>) = py.detach(|| {
             let allowed_special: HashSet<&str> =
                 allowed_special.iter().map(|s| s.as_ref()).collect();
             self._encode_unstable_native(text, &allowed_special)
@@ -155,7 +155,7 @@ impl CoreBPE {

     #[pyo3(name = "decode_bytes")]
     fn py_decode_bytes(&self, py: Python, tokens: Vec<Rank>) -> Result<Py<PyBytes>, PyErr> {
-        match py.allow_threads(|| self.decode_bytes(&tokens)) {
+        match py.detach(|| self.decode_bytes(&tokens)) {
             Ok(bytes) => Ok(PyBytes::new(py, &bytes).into()),
             Err(e) => Err(pyo3::exceptions::PyKeyError::new_err(format!("{}", e))),
         }
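For context (not part of the diff itself): the `py.allow_threads` → `py.detach` rename tracks pyo3's newer API for releasing the GIL while Rust code runs, which is what lets these encode/decode paths execute concurrently from Python threads. A minimal Python-side sketch of why that matters, assuming a locally installed tiktoken build and the standard `cl100k_base` encoding being fetchable:

# Sketch: because the Rust encode paths release the GIL while they run,
# several threads can encode in parallel instead of serializing on the GIL.
from concurrent.futures import ThreadPoolExecutor

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
texts = ["hello world"] * 8  # illustrative inputs

with ThreadPoolExecutor(max_workers=4) as pool:
    token_lists = list(pool.map(enc.encode_ordinary, texts))

print([len(tokens) for tokens in token_lists])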

tiktoken/core.py

Lines changed: 12 additions & 5 deletions
@@ -4,11 +4,11 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence

-import regex
-
 from tiktoken import _tiktoken

 if TYPE_CHECKING:
+    import re
+
     import numpy as np
     import numpy.typing as npt

@@ -391,6 +391,9 @@ def _encode_single_piece(self, text_or_bytes: str | bytes) -> list[int]:

     def _encode_only_native_bpe(self, text: str) -> list[int]:
         """Encodes a string into tokens, but do regex splitting in Python."""
+        # We need specifically `regex` in order to compile pat_str due to e.g. \p
+        import regex
+
         _unused_pat = regex.compile(self._pat_str)
         ret = []
         for piece in regex.findall(_unused_pat, text):
@@ -423,9 +426,13 @@ def __setstate__(self, value: object) -> None:


 @functools.lru_cache(maxsize=128)
-def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]":
-    inner = "|".join(regex.escape(token) for token in tokens)
-    return regex.compile(f"({inner})")
+def _special_token_regex(tokens: frozenset[str]) -> re.Pattern[str]:
+    try:
+        import regex as re
+    except ImportError:
+        import re
+    inner = "|".join(re.escape(token) for token in tokens)
+    return re.compile(f"({inner})")


 def raise_disallowed_special_token(token: str) -> NoReturn:
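The fallback in `_special_token_regex` works because the third-party `regex` module is a drop-in superset of the stdlib `re` for the calls used here (`escape`, `compile`). A standalone sketch of the same pattern, with a hypothetical helper name:

# Sketch of the import-fallback pattern: prefer `regex`, fall back to stdlib `re`.
# `compile_alternation` is an illustrative name, not part of tiktoken.
def compile_alternation(tokens: frozenset[str]):
    try:
        import regex as re  # optional third-party dependency
    except ImportError:
        import re  # stdlib fallback with a compatible escape()/compile() API
    inner = "|".join(re.escape(token) for token in sorted(tokens))
    return re.compile(f"({inner})")

pat = compile_alternation(frozenset({"<|endoftext|>", "<|fim_prefix|>"}))
print(bool(pat.search("some text <|endoftext|>")))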

tiktoken/load.py

Lines changed: 18 additions & 14 deletions
@@ -6,22 +6,26 @@


 def read_file(blobpath: str) -> bytes:
-    if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
-        try:
-            import blobfile
-        except ImportError as e:
-            raise ImportError(
-                "blobfile is not installed. Please install it by running `pip install blobfile`."
-            ) from e
-        with blobfile.BlobFile(blobpath, "rb") as f:
+    if "://" not in blobpath:
+        with open(blobpath, "rb", buffering=0) as f:
             return f.read()

-    # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
-    import requests
+    if blobpath.startswith(("http://", "https://")):
+        # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
+        import requests
+
+        resp = requests.get(blobpath)
+        resp.raise_for_status()
+        return resp.content

-    resp = requests.get(blobpath)
-    resp.raise_for_status()
-    return resp.content
+    try:
+        import blobfile
+    except ImportError as e:
+        raise ImportError(
+            "blobfile is not installed. Please install it by running `pip install blobfile`."
+        ) from e
+    with blobfile.BlobFile(blobpath, "rb") as f:
+        return f.read()


 def check_hash(data: bytes, expected_hash: str) -> bool:
@@ -49,7 +53,7 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:

     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
-        with open(cache_path, "rb") as f:
+        with open(cache_path, "rb", buffering=0) as f:
             data = f.read()
         if expected_hash is None or check_hash(data, expected_hash):
             return data
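The rewritten `read_file` now dispatches on the path: anything without a "://" scheme is read straight from the local filesystem, `http://`/`https://` URLs go through `requests`, and any other scheme falls through to `blobfile`. A usage sketch with hypothetical paths:

# Illustrative calls against the new read_file dispatch (paths are made up).
from tiktoken.load import read_file

local = read_file("/tmp/cl100k_base.tiktoken")            # no "://" -> plain open()
public = read_file("https://example.com/enc.tiktoken")    # http(s) -> requests
private = read_file("gs://my-bucket/enc.tiktoken")        # other schemes -> blobfile, if installed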
