diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index b1078503e0..a7e33a11be 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -161,7 +161,12 @@ class GPT2CausalLM(Task):
     """
 
-    def __init__(self, backbone, preprocessor=None, **kwargs):
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
         inputs = backbone.input
         x = backbone(inputs)
         # Use token embedding weights to project from the token representation
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
index 7ad2cce988..43b576fb0d 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
@@ -28,7 +28,7 @@ class GPT2CausalLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
-        vocab = {
+        self.vocab = {
             "<|endoftext|>": 0,
             "!": 1,
             "air": 2,
@@ -47,11 +47,12 @@ def setUp(self):
         merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
         merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
         merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.merges = merges
 
         self.preprocessor = GPT2CausalLMPreprocessor(
             tokenizer=GPT2Tokenizer(
-                vocabulary=vocab,
-                merges=merges,
+                vocabulary=self.vocab,
+                merges=self.merges,
             ),
             sequence_length=8,
         )
@@ -77,6 +78,27 @@ def test_tokenize_list_of_strings(self):
         self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0]] * 4)
         self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0]] * 4)
 
+    def test_pad_start_end_token(self):
+        input_data = ["airplane at airport"] * 4
+
+        preprocessor = GPT2CausalLMPreprocessor(
+            tokenizer=GPT2Tokenizer(
+                vocabulary=self.vocab,
+                merges=self.merges,
+            ),
+            sequence_length=8,
+            add_start_token=True,
+            add_end_token=True,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"],
+            [[0, 2, 4, 5, 3, 6, 0]] * 4,
+        )
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1]] * 4)
+        self.assertAllEqual(y, [[2, 4, 5, 3, 6, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0]] * 4)
+
     def test_tokenize_labeled_batch(self):
         x = tf.constant(["airplane at airport"] * 4)
         y = tf.constant([1] * 4)
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py
index 5f6090a2d7..635fd20c71 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -58,6 +58,10 @@ class GPT2Preprocessor(Preprocessor):
     Args:
         tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
        sequence_length: The length of the packed inputs.
+        add_start_token: If true, the preprocessor will prepend the tokenizer
+            start token to each input sequence.
+        add_end_token: If true, the preprocessor will append the tokenizer
+            end token to each input sequence.
 
     Examples:
     ```python
@@ -65,16 +69,16 @@ class GPT2Preprocessor(Preprocessor):
     preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset("gpt2_base_en")
 
     # Tokenize and pack a single sentence.
-    sentence = tf.constant("league of legends")
+    sentence = tf.constant("League of legends")
     preprocessor(sentence)
     # Same output.
-    preprocessor("league of legends")
+    preprocessor("League of legends")
 
     # Tokenize a batch of sentences.
-    sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
+    sentences = tf.constant(["Taco tuesday", "Fish taco!"])
     preprocessor(sentences)
     # Same output.
- preprocessor(["taco tuesday", "gi gi gi gi"]) + preprocessor(["Taco tuesday", "Fish taco!"]) # Map a dataset to preprocess a single sentence. features = tf.constant( @@ -119,6 +123,8 @@ def __init__( self, tokenizer, sequence_length, + add_start_token=False, + add_end_token=False, **kwargs, ): @@ -126,6 +132,8 @@ def __init__( self._tokenizer = tokenizer self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token def get_config(self): config = super().get_config() @@ -149,9 +157,28 @@ def call(self, x, y=None, sample_weight=None): input_is_1d = len(token_ids.shape) == 1 if input_is_1d: token_ids = tf.RaggedTensor.from_tensor([token_ids]) + if self.add_start_token: + start_tokens = tf.fill( + [tf.shape(token_ids)[0], 1], + self.tokenizer.start_token_id, + ) + token_ids = tf.concat([start_tokens, token_ids], axis=1) + if self.add_end_token: + end_tokens = tf.fill( + [tf.shape(token_ids)[0], 1], + self.tokenizer.end_token_id, + ) + token_ids = tf.concat([token_ids, end_tokens], axis=1) mask = tf.ones_like(token_ids, dtype=tf.bool) - mask = mask.to_tensor(shape=(None, self.sequence_length)) - token_ids = token_ids.to_tensor(shape=(None, self.sequence_length)) + shape_after_padding = tf.stack( + [tf.constant(-1), self.sequence_length], + axis=0, + ) + mask = mask.to_tensor(shape=shape_after_padding) + token_ids = token_ids.to_tensor( + shape=shape_after_padding, + default_value=self.tokenizer.pad_token_id, + ) if input_is_1d: # If the input is a single string, we let the output be a 1D tensor. token_ids = tf.squeeze(token_ids, axis=0) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py index d06f8f9f45..2e5b7311d4 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py @@ -26,7 +26,7 @@ class GPT2PreprocessorTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): - vocab = { + self.vocab = { "<|endoftext|>": 0, "!": 1, "air": 2, @@ -45,11 +45,12 @@ def setUp(self): merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] merges += ["Ġt h", "ai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] merges += ["Ġai r", "Ġa i", "pla ne"] + self.merges = merges self.preprocessor = GPT2Preprocessor( tokenizer=GPT2Tokenizer( - vocabulary=vocab, - merges=merges, + vocabulary=self.vocab, + merges=self.merges, ), sequence_length=8, ) @@ -74,6 +75,28 @@ def test_tokenize_list_of_strings(self): output["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4 ) + def test_pad_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = GPT2Preprocessor( + tokenizer=GPT2Tokenizer( + vocabulary=self.vocab, + merges=self.merges, + ), + sequence_length=8, + add_start_token=True, + add_end_token=True, + ) + output = preprocessor(input_data) + self.assertAllEqual( + output["token_ids"], + [[0, 2, 4, 5, 3, 6, 0, 0]] * 4, + ) + + self.assertAllEqual( + output["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4 + ) + def test_tokenize_labeled_batch(self): x = tf.constant(["airplane at airport"] * 4) y = tf.constant([1] * 4) diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 9665391a0a..a739bf6fa2 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -112,6 +112,10 @@ def __init__( ) self.end_token_id = self.token_to_id(end_token) + # GPT2 uses the same start and pad token as end token, i.e., + # "<|endoftext|>". 
+        self.start_token_id = self.end_token_id
+        self.pad_token_id = self.end_token_id
 
     @classproperty
     def presets(cls):
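Not part of the patch above: the sketch below illustrates, in isolation, the padding behavior the new `GPT2Preprocessor.call()` logic is tested for. The token ids are made up for illustration; `0` stands in for `"<|endoftext|>"`, which the tokenizer change reuses as start, end, and pad token, and `sequence_length=8` mirrors the tests.

```python
import tensorflow as tf

# Illustrative ids only: 0 plays the role of "<|endoftext|>" (start/end/pad).
pad_token_id = 0
sequence_length = 8

# Two ragged rows, standing in for tokenized inputs of different lengths.
token_ids = tf.ragged.constant([[2, 4, 5, 3, 6], [2, 4]], dtype=tf.int32)

# Prepend a start token and append an end token to every row, as the patch
# does when add_start_token/add_end_token are enabled.
batch_size = int(token_ids.nrows())
start_tokens = tf.fill([batch_size, 1], pad_token_id)
end_tokens = tf.fill([batch_size, 1], pad_token_id)
token_ids = tf.concat([start_tokens, token_ids, end_tokens], axis=1)

# Densify to a fixed width, padding token ids with the pad token (rather than
# an implicit 0) and the mask with False.
mask = tf.ones_like(token_ids, dtype=tf.bool)
dense_ids = token_ids.to_tensor(
    shape=[None, sequence_length], default_value=pad_token_id
)
dense_mask = mask.to_tensor(shape=[None, sequence_length])

print(dense_ids.numpy())
# [[0 2 4 5 3 6 0 0]
#  [0 2 4 0 0 0 0 0]]
print(dense_mask.numpy())
# [[ True  True  True  True  True  True  True False]
#  [ True  True  True  True False False False False]]
```

The first row reproduces the `[[0, 2, 4, 5, 3, 6, 0, 0]]` / `[[1, 1, 1, 1, 1, 1, 1, 0]]` expectations in `gpt2_preprocessor_test.py` above; the causal LM preprocessor test additionally expects labels equal to the token ids shifted left by one, with sample weights that zero out the final position.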