Support llama3 tokenizer #67

Merged
merged 9 commits on May 1, 2024

2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
@@ -50,7 +50,7 @@ jobs:
pip install -r benchmarks/requirements.in
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/
- name: Analysing the code with pylint
run: |
pylint jetstream/ benchmarks/
148 changes: 130 additions & 18 deletions jetstream/engine/token_utils.py
@@ -27,6 +27,7 @@
from jetstream.engine import mock_utils
from jetstream.engine import tokenizer_api
from jetstream.engine import tokenizer_pb2
from jetstream.third_party.llama3 import llama3_tokenizer

# ResultTokens class to store token ids.
ResultTokens = Any
@@ -40,9 +41,10 @@ def take_nearest_length(lengths: list[int], length: int) -> int:
return lengths[pos]


def tokenize_and_pad(
s: str,
vocab: Vocabulary,
def pad_tokens(
tokens: np.ndarray,
bos_id: int,
pad_id: int,
is_bos: bool = True,
prefill_lengths: Optional[List[int]] = None,
max_prefill_length: Optional[int] = None,
@@ -84,14 +86,13 @@ def tokenize_and_pad(
] + [
max_prefill_length,
]
tokens = np.array(vocab.encode_tf(s)) # [Length]
# Add a beginning of sequence token if this is the beginning.
if is_bos:
tokens = np.concatenate(
[
np.array(
[
vocab.bos_id,
bos_id,
]
),
tokens,
@@ -101,13 +102,12 @@
true_length = tokens.shape[-1]
padded_length = take_nearest_length(prefill_lengths, true_length)
padding = padded_length - true_length
assert vocab.pad_id == 0, "Further logic required if pad_id not 0."
if padding < 0:
logging.warning("Provided sequence longer than available.")
# Take the last N tokens if we have too many.
padded_tokens = tokens[-padded_length:]
else:
padded_tokens = np.pad(tokens, (0, padding))
padded_tokens = np.pad(tokens, (0, padding), constant_values=(pad_id,))
if jax_padding:
padded_tokens = jnp.array(padded_tokens)
return padded_tokens, true_length
@@ -117,7 +117,8 @@ def process_result_tokens(
slot: int,
slot_max_length: int,
result_tokens: ResultTokens,
vocab: Vocabulary,
eos_id: int,
pad_id: int,
complete: np.ndarray,
debug: bool = False,
) -> Tuple[List[List[int]], np.ndarray]:
@@ -128,7 +129,8 @@ def process_result_tokens(
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
vocab: For the detokenizer.
eos_id: Id for EOS token.
pad_id: Id for pad token.
complete: Array representing the completion status of each sample in the
slot.
debug: Whether to log step by step detokenisation.
@@ -143,7 +145,7 @@ def process_result_tokens(
slot_valid = slot_data.valid
slot_lengths = slot_data.lengths
samples, speculations = slot_tokens.shape
stop_tokens = [vocab.eos_id, vocab.pad_id]
stop_tokens = [eos_id, pad_id]
# Stop anything which has reached its max length.
complete = complete | (slot_lengths > slot_max_length)
if debug:
@@ -212,11 +214,9 @@ def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.

Args:
s: String to tokenize.
**kwargs: Additional keyword arguments

Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
@@ -225,13 +225,18 @@ def encode(
is_bos = kwargs.pop("is_bos", True)
prefill_lengths = kwargs.pop("prefill_lengths", None)
max_prefill_length = kwargs.pop("max_prefill_length", None)
jax_padding = kwargs.pop("jax_padding", True)

tokens = np.array(self.vocab.encode_tf(s))

tokens, true_length = tokenize_and_pad(
s,
self.vocab,
tokens, true_length = pad_tokens(
tokens,
self.bos_id,
self.pad_id,
is_bos=is_bos,
prefill_lengths=prefill_lengths,
max_prefill_length=max_prefill_length,
jax_padding=jax_padding,
)
return tokens, true_length

@@ -245,15 +250,13 @@ def decode(
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.

Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
kwargs: Additional keyword arguments.

Returns:
sample_return: List of strings, one per sample.
complete: Updated complete.
@@ -263,12 +266,22 @@ def decode(
slot=slot,
slot_max_length=slot_max_length,
result_tokens=result_tokens,
vocab=self.vocab,
eos_id=self.eos_id,
pad_id=self.pad_id,
complete=complete,
debug=debug,
)
return results, complete

def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""
return self.vocab.tokenizer.decode(token_ids)

@property
def pad_id(self) -> int:
"""ID of the pad token."""
@@ -278,3 +291,102 @@ def pad_id(self) -> int:
def eos_id(self) -> int:
"""ID of EOS token."""
return self.vocab.eos_id

@property
def bos_id(self) -> int:
"""ID of the BOS token."""
return self.vocab.bos_id


class TikToken(tokenizer_api.Tokenizer):
"""Tokenizer to convert strings to token ids and vice-versa."""

def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
self.tokenizer = llama3_tokenizer.Tokenizer(metadata.path)

def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.
Args:
s: String to tokenize.
**kwargs: Additional keyword arguments
Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
if padding is used.
"""
is_bos = kwargs.pop("is_bos", True)
prefill_lengths = kwargs.pop("prefill_lengths", None)
max_prefill_length = kwargs.pop("max_prefill_length", None)
jax_padding = kwargs.pop("jax_padding", True)

tokens = np.array(self.tokenizer.encode(s, bos=False, eos=False))

tokens, true_length = pad_tokens(
tokens,
self.bos_id,
self.pad_id,
is_bos=is_bos,
prefill_lengths=prefill_lengths,
max_prefill_length=max_prefill_length,
jax_padding=jax_padding,
)
return tokens, true_length

def decode(
self,
slot: int,
slot_max_length: int,
result_tokens: ResultTokens,
complete: np.ndarray,
**kwargs,
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.
Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
kwargs: Additional keyword arguments.
Returns:
sample_return: List of strings, one per sample.
complete: Updated complete.
"""
debug = kwargs.pop("debug", False)
results, complete = process_result_tokens(
slot=slot,
slot_max_length=slot_max_length,
result_tokens=result_tokens,
eos_id=self.eos_id,
pad_id=self.pad_id,
complete=complete,
debug=debug,
)
return results, complete

def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""
return self.tokenizer.decode(token_ids)

@property
def pad_id(self) -> int:
"""ID of the pad token."""
return self.tokenizer.pad_id

@property
def eos_id(self) -> int:
"""ID of EOS token."""
return self.tokenizer.eos_id

@property
def bos_id(self) -> int:
"""ID of the BOS token."""
return self.tokenizer.bos_id
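
Taken together, this file's changes split the old tokenize_and_pad into a vocabulary-agnostic pad_tokens helper and add a TikToken class that wraps the bundled llama3 tokenizer. Below is a minimal sketch of calling the new helper directly; the token ids, special-token ids, and prefill buckets are illustrative values for this example, not anything defined in the PR.

```python
import numpy as np

from jetstream.engine import token_utils

# Illustrative token ids; in practice these come from a tokenizer's encode().
raw_tokens = np.array([17, 42, 99])

padded, true_length = token_utils.pad_tokens(
    raw_tokens,
    bos_id=1,                 # assumed BOS id for this sketch
    pad_id=0,                 # assumed pad id for this sketch
    is_bos=True,              # prepend BOS, so the unpadded length becomes 4
    prefill_lengths=[8, 16],  # pad up to the nearest prefill bucket
    jax_padding=False,        # keep the result as a NumPy array
)
# Expected under these assumptions: true_length == 4 and padded.shape == (8,).
```
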
21 changes: 15 additions & 6 deletions jetstream/engine/tokenizer_api.py
@@ -32,11 +32,9 @@ def encode(
self, s: str, **kwargs
) -> Tuple[Union[jax.Array, np.ndarray], int]:
"""Tokenize a string.

Args:
s: String to tokenize.
**kwargs: Additional keyword arguments

Returns:
tokens: Tokenized into integers.
true_length: Actual length of the non-padded sequence
@@ -54,20 +52,26 @@ def decode(
) -> Tuple[list[list[int]], np.ndarray]:
"""Processes a result tokens into a list of token ids, handling multiple
samples.

Args:
slot: The slot at which to draw tokens from.
slot_max_length: Max length for a sample in the slot.
result_tokens: The tokens to access by slot.
complete: Array representing the completion status of each sample in the
slot.
**kwargs: Additional keyword arguments.

Returns:
sample_return: List of strings, one per sample.
sample_return: List of token_ids, one per sample.
complete: Updated complete.
"""
# TODO(bbahl): Add an option to return str from decode.

@abc.abstractmethod
def decode_str(self, token_ids: list[int]) -> str:
"""Processess input token ids to generate a string.
Args:
token_ids: List of token ids.
Returns:
str: String generated from the token ids.
"""

@property
@abc.abstractmethod
@@ -78,3 +82,8 @@ def pad_id(self) -> int:
@abc.abstractmethod
def eos_id(self) -> int:
"""ID of EOS token."""

@property
@abc.abstractmethod
def bos_id(self) -> int:
"""ID of BOS token."""
10 changes: 4 additions & 6 deletions jetstream/tests/engine/test_mock_engine.py
@@ -52,8 +52,8 @@ def _prefill(self):
# A 2 will be pre-pended as 'bos' token from the vocab.
text = "AB"
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True)
tokenizer = engine.build_tokenizer(metadata)
tokens, true_length = tokenizer.encode(text, is_bos=True)
prefill_result = engine.prefill(
params=params, padded_tokens=tokens, true_length=3
)
@@ -65,10 +65,8 @@ def _prefill_np(self):
# A 2 will be pre-pended as 'bos' token from the vocab.
text = "AB"
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokens, true_length = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, jax_padding=False
)
tokenizer = engine.build_tokenizer(metadata)
tokens, true_length = tokenizer.encode(text, is_bos=True, jax_padding=False)
prefill_result = engine.prefill(
params=params, padded_tokens=tokens, true_length=3
)
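
For completeness, a rough sketch of exercising the new tokenizer path end to end, reusing only calls that appear in this diff; `engine` and `params` are assumed to be set up as in the tests above, and `generated_ids` is a placeholder list of ids.

```python
# Sketch only: `engine` and `params` come from the mock-engine setup above.
metadata = engine.get_tokenizer()
tokenizer = engine.build_tokenizer(metadata)

# jax_padding=False keeps the padded prompt as a NumPy array, matching the
# _prefill_np path in the test.
tokens, true_length = tokenizer.encode("AB", is_bos=True, jax_padding=False)
prefill_result = engine.prefill(
    params=params, padded_tokens=tokens, true_length=true_length
)

# The new decode_str() turns generated ids back into text, regardless of
# whether the backend is SentencePiece or the llama3 TikToken wrapper.
generated_ids = [4, 5, 6]  # placeholder ids for illustration
text = tokenizer.decode_str(generated_ids)
```
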