7 changes: 6 additions & 1 deletion keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -161,7 +161,12 @@ class GPT2CausalLM(Task):

"""

def __init__(self, backbone, preprocessor=None, **kwargs):
def __init__(
self,
backbone,
preprocessor=None,
**kwargs,
):
inputs = backbone.input
x = backbone(inputs)
# Use token embedding weights to project from the token representation
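The comment above refers to GPT-2's tied output projection: logits are produced by multiplying the hidden states against the transpose of the token embedding matrix rather than a separate dense head. A minimal sketch of that projection (the shapes and variable names here are illustrative assumptions, not the library's exact call):

```python
import tensorflow as tf

# Illustrative shapes only: hidden states [batch, seq_len, hidden_dim] and a
# token embedding matrix [vocab_size, hidden_dim].
hidden_states = tf.random.normal([2, 8, 64])
token_embedding = tf.random.normal([100, 64])

# Weight tying: reuse the embedding matrix as the output projection, so every
# position gets a score over the full vocabulary.
logits = tf.matmul(hidden_states, token_embedding, transpose_b=True)
print(logits.shape)  # (2, 8, 100)
```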
28 changes: 25 additions & 3 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
@@ -28,7 +28,7 @@

class GPT2CausalLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
vocab = {
self.vocab = {
"<|endoftext|>": 0,
"!": 1,
"air": 2,
@@ -47,11 +47,12 @@ def setUp(self):
merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
merges += ["Ġai r", "Ġa i", "pla ne"]
self.merges = merges

self.preprocessor = GPT2CausalLMPreprocessor(
tokenizer=GPT2Tokenizer(
vocabulary=vocab,
merges=merges,
vocabulary=self.vocab,
merges=self.merges,
),
sequence_length=8,
)
@@ -77,6 +78,27 @@ def test_tokenize_list_of_strings(self):
self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4)

def test_pad_start_end_token(self):
input_data = ["airplane at airport"] * 4

preprocessor = GPT2CausalLMPreprocessor(
tokenizer=GPT2Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
),
sequence_length=8,
add_start_token=True,
add_end_token=True,
)
x, y, sw = preprocessor(input_data)
self.assertAllEqual(
x["token_ids"],
[[0, 2, 4, 5, 3, 6, 0]] * 4,
)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1]] * 4)
self.assertAllEqual(y, [[2, 4, 5, 3, 6, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0]] * 4)

def test_tokenize_labeled_batch(self):
x = tf.constant(["airplane at airport"] * 4)
y = tf.constant([1] * 4)
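For reference, the expected values in `test_pad_start_end_token` above follow from slicing the packed sequence: the features drop the last position, the labels are the same ids shifted left by one, and the sample weights mirror the shifted mask. A small sketch reproducing those numbers (the packed ids below are copied from the test, not computed by the library here):

```python
import tensorflow as tf

# Packed ids for "airplane at airport" with sequence_length=8 and
# add_start_token=True, add_end_token=True; id 0 is "<|endoftext|>", which
# GPT-2 reuses as start, end, and pad token.
packed = tf.constant([[0, 2, 4, 5, 3, 6, 0, 0]])
mask = tf.constant([[1, 1, 1, 1, 1, 1, 1, 0]])

x_token_ids = packed[:, :-1]   # [[0, 2, 4, 5, 3, 6, 0]]
x_padding_mask = mask[:, :-1]  # [[1, 1, 1, 1, 1, 1, 1]]
y = packed[:, 1:]              # [[2, 4, 5, 3, 6, 0, 0]]
sample_weight = mask[:, 1:]    # [[1, 1, 1, 1, 1, 1, 0]]
```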
39 changes: 33 additions & 6 deletions keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -58,23 +58,27 @@ class GPT2Preprocessor(Preprocessor):
Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
sequence_length: The length of the packed inputs.
add_start_token: If `True`, the preprocessor will prepend the tokenizer
    start token to each input sequence.
add_end_token: If `True`, the preprocessor will append the tokenizer
    end token to each input sequence.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset("gpt2_base_en")

# Tokenize and pack a single sentence.
sentence = tf.constant("league of legends")
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("league of legends")
preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
sentences = tf.constant(["Taco tuesday", "Fish taco!"])
preprocessor(sentences)
# Same output.
preprocessor(["taco tuesday", "gi gi gi gi"])
preprocessor(["Taco tuesday", "Fish taco!"])

# Map a dataset to preprocess a single sentence.
features = tf.constant(
@@ -119,13 +123,17 @@ def __init__(
self,
tokenizer,
sequence_length,
add_start_token=False,
add_end_token=False,
**kwargs,
):

super().__init__(**kwargs)

self._tokenizer = tokenizer
self.sequence_length = sequence_length
self.add_start_token = add_start_token
self.add_end_token = add_end_token

def get_config(self):
config = super().get_config()
@@ -149,9 +157,28 @@ def call(self, x, y=None, sample_weight=None):
input_is_1d = len(token_ids.shape) == 1
if input_is_1d:
token_ids = tf.RaggedTensor.from_tensor([token_ids])
if self.add_start_token:
start_tokens = tf.fill(
[tf.shape(token_ids)[0], 1],
self.tokenizer.start_token_id,
)
token_ids = tf.concat([start_tokens, token_ids], axis=1)
if self.add_end_token:
end_tokens = tf.fill(
[tf.shape(token_ids)[0], 1],
self.tokenizer.end_token_id,
)
token_ids = tf.concat([token_ids, end_tokens], axis=1)
mask = tf.ones_like(token_ids, dtype=tf.bool)
mask = mask.to_tensor(shape=(None, self.sequence_length))
token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
shape_after_padding = tf.stack(
[tf.constant(-1), self.sequence_length],
axis=0,
)
mask = mask.to_tensor(shape=shape_after_padding)
token_ids = token_ids.to_tensor(
shape=shape_after_padding,
default_value=self.tokenizer.pad_token_id,
)
if input_is_1d:
# If the input is a single string, we let the output be a 1D tensor.
token_ids = tf.squeeze(token_ids, axis=0)
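The new `call()` flow above can be reproduced in isolation: build a ragged batch, prepend/append the special tokens, then densify both the mask and the ids to `sequence_length`. A minimal sketch with stand-in constants (`START_ID`, `END_ID`, `PAD_ID`, and `SEQUENCE_LENGTH` are assumptions mirroring the tokenizer attributes and preprocessor setting, not values from a real preset):

```python
import tensorflow as tf

# Stand-ins for tokenizer.start_token_id / end_token_id / pad_token_id and the
# preprocessor's sequence_length; for GPT-2 all three ids are "<|endoftext|>".
START_ID, END_ID, PAD_ID = 0, 0, 0
SEQUENCE_LENGTH = 8

# A ragged batch of already-tokenized inputs with uneven row lengths.
token_ids = tf.ragged.constant([[2, 4, 5, 3, 6], [2, 4]], dtype=tf.int32)

# Prepend the start token and append the end token to every row.
batch_size = tf.shape(token_ids)[0]
start_tokens = tf.fill([batch_size, 1], tf.constant(START_ID, tf.int32))
end_tokens = tf.fill([batch_size, 1], tf.constant(END_ID, tf.int32))
token_ids = tf.concat([start_tokens, token_ids, end_tokens], axis=1)

# Densify: pad (or truncate) every row to SEQUENCE_LENGTH; the mask marks
# which positions hold real tokens rather than padding.
mask = tf.ones_like(token_ids, dtype=tf.bool)
mask = mask.to_tensor(shape=[None, SEQUENCE_LENGTH])
token_ids = token_ids.to_tensor(
    shape=[None, SEQUENCE_LENGTH], default_value=PAD_ID
)

print(token_ids.numpy())
# [[0 2 4 5 3 6 0 0]
#  [0 2 4 0 0 0 0 0]]
```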
29 changes: 26 additions & 3 deletions keras_nlp/models/gpt2/gpt2_preprocessor_test.py
@@ -26,7 +26,7 @@

class GPT2PreprocessorTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
vocab = {
self.vocab = {
"<|endoftext|>": 0,
"!": 1,
"air": 2,
@@ -45,11 +45,12 @@ def setUp(self):
merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
merges += ["Ġai r", "Ġa i", "pla ne"]
self.merges = merges

self.preprocessor = GPT2Preprocessor(
tokenizer=GPT2Tokenizer(
vocabulary=vocab,
merges=merges,
vocabulary=self.vocab,
merges=self.merges,
),
sequence_length=8,
)
@@ -74,6 +75,28 @@ def test_tokenize_list_of_strings(self):
output["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4
)

def test_pad_start_end_token(self):
input_data = ["airplane at airport"] * 4

preprocessor = GPT2Preprocessor(
tokenizer=GPT2Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
),
sequence_length=8,
add_start_token=True,
add_end_token=True,
)
output = preprocessor(input_data)
self.assertAllEqual(
output["token_ids"],
[[0, 2, 4, 5, 3, 6, 0, 0]] * 4,
)

self.assertAllEqual(
output["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4
)

def test_tokenize_labeled_batch(self):
x = tf.constant(["airplane at airport"] * 4)
y = tf.constant([1] * 4)
4 changes: 4 additions & 0 deletions keras_nlp/models/gpt2/gpt2_tokenizer.py
@@ -112,6 +112,10 @@ def __init__(
)

self.end_token_id = self.token_to_id(end_token)
# GPT2 reuses its end token, "<|endoftext|>", as both the start and pad
# token, so those ids simply alias `end_token_id`.
self.start_token_id = self.end_token_id
self.pad_token_id = self.end_token_id

@classproperty
def presets(cls):
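Since GPT-2's vocabulary defines only "<|endoftext|>" as a special token, the new `start_token_id` and `pad_token_id` attributes alias `end_token_id`. A quick check with a toy vocabulary (the `vocab` and `merges` below are made up solely so the tokenizer can be constructed):

```python
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer

# Toy vocabulary and merge rules, just enough to instantiate the tokenizer.
vocab = {"<|endoftext|>": 0, "air": 1, "Ġair": 2}
merges = ["a i", "ai r", "Ġ a", "Ġa i", "Ġai r"]

tokenizer = GPT2Tokenizer(vocabulary=vocab, merges=merges)
# GPT-2 has no dedicated start or pad token, so all three ids point at
# "<|endoftext|>".
print(tokenizer.start_token_id, tokenizer.end_token_id, tokenizer.pad_token_id)
# 0 0 0
```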