Support llama3 tokenizer (#67)
* Support llama3 tokenizer

* Add tiktoken to requirements

* Add blobfile to requirements

* Fix unit tests

* Fix linting issues

* Fix pytype errors

* Move llama3 tokenizer to third_party directory

* Fix pytype error

* Update pytype command
bhavya01 authored and jwyang-google committed May 6, 2024
1 parent 9093565 commit 22a8a24
Showing 11 changed files with 128,676 additions and 53 deletions.
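Taken together, the changes below route prompt handling through the engine's tokenizer objects instead of raw SentencePiece vocabularies. A minimal sketch of the resulting flow, modeled on the updated mock-engine test further down (the helper name is illustrative, and `engine` stands for any JetStream Engine that exposes get_tokenizer()/build_tokenizer()):

def round_trip(engine, text: str) -> str:
  """Encodes `text`, then detokenizes it with the same tokenizer."""
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)  # SentencePieceTokenizer or TikToken
  tokens, true_length = tokenizer.encode(text, is_bos=True)
  # Strip the padding added by encode() before detokenizing.
  return tokenizer.decode_str([int(t) for t in tokens[:true_length]])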
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
@@ -50,7 +50,7 @@ jobs:
pip install -r benchmarks/requirements.in
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/
- name: Analysing the code with pylint
run: |
pylint jetstream/ benchmarks/
148 changes: 130 additions & 18 deletions jetstream/engine/token_utils.py
@@ -27,6 +27,7 @@
from jetstream.engine import mock_utils
from jetstream.engine import tokenizer_api
from jetstream.engine import tokenizer_pb2
from jetstream.third_party.llama3 import llama3_tokenizer

# ResultToken class to store tokens ids.
ResultTokens = Any
@@ -40,9 +41,10 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
return lengths[pos]


def tokenize_and_pad(
s: str,
vocab: Vocabulary,
def pad_tokens(
tokens: np.ndarray,
bos_id: int,
pad_id: int,
is_bos: bool = True,
prefill_lengths: Optional[List[int]] = None,
max_prefill_length: Optional[int] = None,
@@ -84,14 +86,13 @@ def tokenize_and_pad(
] + [
max_prefill_length,
]
tokens = np.array(vocab.encode_tf(s)) # [Length]
# Add a beginning of sequence token if this is the beginning.
if is_bos:
tokens = np.concatenate(
[
np.array(
[
vocab.bos_id,
bos_id,
]
),
tokens,
@@ -101,13 +102,12 @@
true_length = tokens.shape[-1]
padded_length = take_nearest_length(prefill_lengths, true_length)
padding = padded_length - true_length
assert vocab.pad_id == 0, "Further logic required if pad_id not 0."
if padding < 0:
logging.warning("Provided sequence longer than available.")
# Take the last N tokens if we have too many.
padded_tokens = tokens[-padded_length:]
else:
padded_tokens = np.pad(tokens, (0, padding))
padded_tokens = np.pad(tokens, (0, padding), constant_values=(pad_id,))
if jax_padding:
padded_tokens = jnp.array(padded_tokens)
return padded_tokens, true_length
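The refactor above splits tokenization from padding: callers tokenize first, then pass explicit bos_id/pad_id values instead of a whole vocabulary, which is what lets the tiktoken-based llama3 tokenizer reuse the same helper. A minimal sketch of calling it directly (the token ids and special-token ids are made up for illustration):

import numpy as np
from jetstream.engine import token_utils

prompt_ids = np.array([4093, 17, 256])  # hypothetical tokenizer output
padded, true_length = token_utils.pad_tokens(
    prompt_ids,
    bos_id=1,            # assumed BOS id for the model's vocab
    pad_id=0,            # assumed pad id
    is_bos=True,
    max_prefill_length=16,
    jax_padding=False,   # keep a NumPy array for this sketch
)
# `padded` is padded (or truncated) to the nearest supported prefill bucket;
# `true_length` counts only the real BOS + prompt tokens.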
@@ -117,7 +117,8 @@ def process_result_tokens(
slot: int,
slot_max_length: int,
result_tokens: ResultTokens,
vocab: Vocabulary,
eos_id: int,
pad_id: int,
complete: np.ndarray,
debug: bool = False,
) -> Tuple[List[List[int]], np.ndarray]:
@@ -128,7 +129,8 @@
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
vocab: For the detokenizer.
eos_id: Id for EOS token.
pad_id: Id for pad token.
complete: Array representing the completion status of each sample in the
slot.
debug: Whether to log step by step detokenisation.
@@ -143,7 +145,7 @@
slot_valid = slot_data.valid
slot_lengths = slot_data.lengths
samples, speculations = slot_tokens.shape
stop_tokens = [vocab.eos_id, vocab.pad_id]
stop_tokens = [eos_id, pad_id]
# Stop anything which has reached its max length.
complete = complete | (slot_lengths > slot_max_length)
if debug:
@@ -212,11 +214,9 @@ def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.
Args:
s: String to tokenize.
**kwargs: Additional keyword arguments
Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
@@ -225,13 +225,18 @@ def encode(
is_bos = kwargs.pop("is_bos", True)
prefill_lengths = kwargs.pop("prefill_lengths", None)
max_prefill_length = kwargs.pop("max_prefill_length", None)
jax_padding = kwargs.pop("jax_padding", True)

tokens = np.array(self.vocab.encode_tf(s))

tokens, true_length = tokenize_and_pad(
s,
self.vocab,
tokens, true_length = pad_tokens(
tokens,
self.bos_id,
self.pad_id,
is_bos=is_bos,
prefill_lengths=prefill_lengths,
max_prefill_length=max_prefill_length,
jax_padding=jax_padding,
)
return tokens, true_length

@@ -245,15 +250,13 @@ def decode(
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.
Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
kwargs: Additional keyword arguments.
Returns:
sample_return: List of strings, one per sample.
complete: Updated complete.
@@ -263,12 +266,22 @@
slot=slot,
slot_max_length=slot_max_length,
result_tokens=result_tokens,
vocab=self.vocab,
eos_id=self.eos_id,
pad_id=self.pad_id,
complete=complete,
debug=debug,
)
return results, complete

def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""
return self.vocab.tokenizer.decode(token_ids)

@property
def pad_id(self) -> int:
"""ID of the pad token."""
@@ -278,3 +291,102 @@ def pad_id(self) -> int:
def eos_id(self) -> int:
"""ID of EOS token."""
return self.vocab.eos_id

@property
def bos_id(self) -> int:
"""ID of the BOS token."""
return self.vocab.bos_id


class TikToken(tokenizer_api.Tokenizer):
"""Tokenizer to convert strings to token ids and vice-versa."""

def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
self.tokenizer = llama3_tokenizer.Tokenizer(metadata.path)

def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.
Args:
s: String to tokenize.
**kwargs: Additional keyword arguments
Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
if padding is used.
"""
is_bos = kwargs.pop("is_bos", True)
prefill_lengths = kwargs.pop("prefill_lengths", None)
max_prefill_length = kwargs.pop("max_prefill_length", None)
jax_padding = kwargs.pop("jax_padding", True)

tokens = np.array(self.tokenizer.encode(s, bos=False, eos=False))

tokens, true_length = pad_tokens(
tokens,
self.bos_id,
self.pad_id,
is_bos=is_bos,
prefill_lengths=prefill_lengths,
max_prefill_length=max_prefill_length,
jax_padding=jax_padding,
)
return tokens, true_length

def decode(
self,
slot: int,
slot_max_length: int,
result_tokens: ResultTokens,
complete: np.ndarray,
**kwargs,
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.
Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
kwargs: Additional keyword arguments.
Returns:
sample_return: List of strings, one per sample.
complete: Updated complete.
"""
debug = kwargs.pop("debug", False)
results, complete = process_result_tokens(
slot=slot,
slot_max_length=slot_max_length,
result_tokens=result_tokens,
eos_id=self.eos_id,
pad_id=self.pad_id,
complete=complete,
debug=debug,
)
return results, complete

def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""
return self.tokenizer.decode(token_ids)

@property
def pad_id(self) -> int:
"""ID of the pad token."""
return self.tokenizer.pad_id

@property
def eos_id(self) -> int:
"""ID of EOS token."""
return self.tokenizer.eos_id

@property
def bos_id(self) -> int:
"""ID of the BOS token."""
return self.tokenizer.bos_id
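Because TikToken mirrors the SentencePiece wrapper, it can also be constructed directly from TokenizerParameters. A hedged sketch (the model path is a placeholder for a real llama3 tiktoken file):

from jetstream.engine import token_utils, tokenizer_pb2

metadata = tokenizer_pb2.TokenizerParameters(path="/path/to/llama3/tokenizer.model")
tokenizer = token_utils.TikToken(metadata)

tokens, true_length = tokenizer.encode("hello world", is_bos=True)
print(tokenizer.bos_id, tokenizer.eos_id, tokenizer.pad_id)
print(tokenizer.decode_str([int(t) for t in tokens[:true_length]]))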
21 changes: 15 additions & 6 deletions jetstream/engine/tokenizer_api.py
@@ -32,11 +32,9 @@ def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.
Args:
s: String to tokenize.
**kwargs: Additional keyword arguments
Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
@@ -54,20 +52,26 @@ def decode(
) -> Tuple[list[list[int]], np.ndarray]:
"""Processes a result tokens into a list of token ids, handling multiple
samples.
Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
**kwargs: Additional keyword arguments.
Returns:
sample_return: List of strings, one per sample.
sample_return: List of token_ids, one per sample.
complete: Updated complete.
"""
# TODO(bbahl): Add an option to return str from decode.

@abc.abstractmethod
def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""

@property
@abc.abstractmethod
@@ -78,3 +82,8 @@ def pad_id(self) -> int:
@abc.abstractmethod
def eos_id(self) -> int:
"""ID of EOS token."""

@property
@abc.abstractmethod
def bos_id(self) -> int:
"""ID of BOS token."""
10 changes: 4 additions & 6 deletions jetstream/tests/engine/test_mock_engine.py
@@ -52,8 +52,8 @@ def _prefill(self):
# A 2 will be pre-pended as 'bos' token from the vocab.
text = "AB"
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
tokenizer = engine.build_tokenizer(metadata)
tokens, true_length = tokenizer.encode(text, is_bos=True)
prefill_result = engine.prefill(
params=params, padded_tokens=tokens, true_length=3
)
@@ -65,10 +65,8 @@ def _prefill_np(self):
# A 2 will be pre-pended as 'bos' token from the vocab.
text = "AB"
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokens, true_length = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, jax_padding=False
)
tokenizer = engine.build_tokenizer(metadata)
tokens, true_length = tokenizer.encode(text, is_bos=True, jax_padding=False)
prefill_result = engine.prefill(
params=params, padded_tokens=tokens, true_length=3
)
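For downstream code that still uses the removed helpers, this test change doubles as a migration recipe: replace load_vocab plus tokenize_and_pad with build_tokenizer plus encode, passing jax_padding=False wherever NumPy arrays are needed (as in _prefill_np). A hedged sketch, with `engine` standing for any JetStream Engine:

def encode_prompt(engine, text: str, use_numpy: bool = False):
  # Previously:
  #   vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
  #   tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)
  return tokenizer.encode(text, is_bos=True, jax_padding=not use_numpy)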
(Diffs for the remaining 7 changed files, including the vendored llama3 tokenizer under jetstream/third_party/llama3, are not expanded here.)