Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras_nlp/models/f_net/f_net_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

"""FNet backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.f_net_encoder import FNetEncoder
from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.f_net.f_net_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty


def f_net_kernel_initializer(stddev=0.02):
Expand Down Expand Up @@ -209,3 +213,7 @@ def get_config(self):
"name": self.name,
"trainable": self.trainable,
}

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
7 changes: 7 additions & 0 deletions keras_nlp/models/f_net/f_net_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.
"""FNet preprocessor layer."""

import copy

from tensorflow import keras

from keras_nlp.layers.multi_segment_packer import MultiSegmentPacker
from keras_nlp.models.f_net.f_net_presets import backbone_presets
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
Expand Down Expand Up @@ -177,3 +180,7 @@ def call(self, x, y=None, sample_weight=None):
@classproperty
def tokenizer_cls(cls):
return FNetTokenizer

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
55 changes: 55 additions & 0 deletions keras_nlp/models/f_net/f_net_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FNet model preset configurations."""

backbone_presets = {
"f_net_base_en": {
"config": {
"vocabulary_size": 32000,
"num_layers": 12,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.1,
"max_sequence_length": 512,
"num_segments": 4,
},
"preprocessor_config": {},
"description": (
"Base size of FNet. Trained on the C4 dataset (English)."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/f_net_base_en/v1/model.h5",
"weights_hash": "35db90842b85a985a0e54c86c00746fe",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/f_net_base_en/v1/vocab.spm",
"spm_proto_hash": "71c5f4610bef1daf116998a113a01f3d",
},
"f_net_large_en": {
"config": {
"vocabulary_size": 32000,
"num_layers": 24,
"hidden_dim": 1024,
"intermediate_dim": 4096,
"dropout": 0.1,
"max_sequence_length": 512,
"num_segments": 4,
},
"preprocessor_config": {},
"description": (
"Large size of FNet. Trained on the C4 dataset (English)."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/f_net_large_en/v1/model.h5",
"weights_hash": "7ae4a3faa67ff054f8cecffb5619f779",
"spm_proto_url": "https://storage.googleapis.com/keras-nlp/models/f_net_large_en/v1/vocab.spm",
"spm_proto_hash": "71c5f4610bef1daf116998a113a01f3d",
},
}
123 changes: 123 additions & 0 deletions keras_nlp/models/f_net/f_net_presets_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for loading pretrained model presets."""

import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer


@pytest.mark.large
class FNetPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
"""
A smoke test for FNet presets we run continuously.

This only tests the smallest weights we have available. Run with:
`pytest keras_nlp/models/f_net/f_net_presets_test.py --run_large`
"""

def test_tokenizer_output(self):
tokenizer = FNetTokenizer.from_preset(
"f_net_base_en",
)
outputs = tokenizer("The quick brown fox.")
expected_outputs = [97, 1467, 5187, 26, 2521, 16678]
self.assertAllEqual(outputs, expected_outputs)

def test_preprocessor_output(self):
preprocessor = FNetPreprocessor.from_preset(
"f_net_base_en",
sequence_length=4,
)
outputs = preprocessor("The quick brown fox.")["token_ids"]
expected_outputs = [4, 97, 1467, 5]
self.assertAllEqual(outputs, expected_outputs)

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_backbone_output(self, load_weights):
input_data = {
"token_ids": tf.constant([[4, 97, 1467, 5]]),
"segment_ids": tf.constant([[0, 0, 0, 0]]),
}
model = FNetBackbone.from_preset(
"f_net_base_en", load_weights=load_weights
)
outputs = model(input_data)
if load_weights:
outputs = outputs["sequence_output"][0, 0, :5]
expected = [4.182479, -0.072181, -0.138097, -0.036582, -0.521765]
self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)

@parameterized.named_parameters(
("f_net_tokenizer", FNetTokenizer),
("f_net_preprocessor", FNetPreprocessor),
("f_net", FNetBackbone),
)
def test_preset_docstring(self, cls):
"""Check we did our docstring formatting correctly."""
for name in cls.presets:
self.assertRegex(cls.from_preset.__doc__, name)

@parameterized.named_parameters(
("f_net_tokenizer", FNetTokenizer),
("f_net_preprocessor", FNetPreprocessor),
("f_net", FNetBackbone),
)
def test_unknown_preset_error(self, cls):
# Not a preset name
with self.assertRaises(ValueError):
cls.from_preset("f_net_base_en_clowntown")


@pytest.mark.extra_large
class FNetPresetFullTest(tf.test.TestCase, parameterized.TestCase):
"""
Test the full enumeration of our preset.

This tests every FNet preset and is only run manually.
Run with:
`pytest keras_nlp/models/f_net/f_net_presets_test.py --run_extra_large`
"""

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_load_f_net(self, load_weights):
for preset in FNetBackbone.presets:
model = FNetBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
"token_ids": tf.random.uniform(
shape=(1, 512), dtype=tf.int64, maxval=model.vocabulary_size
),
"segment_ids": tf.constant(
[0] * 200 + [1] * 312, shape=(1, 512)
),
}
model(input_data)

def test_load_tokenizers(self):
for preset in FNetTokenizer.presets:
tokenizer = FNetTokenizer.from_preset(preset)
tokenizer("The quick brown fox.")

def test_load_preprocessors(self):
for preset in FNetPreprocessor.presets:
preprocessor = FNetPreprocessor.from_preset(preset)
preprocessor("The quick brown fox.")
49 changes: 47 additions & 2 deletions keras_nlp/models/f_net/f_net_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

"""FNet tokenizer."""

import copy
import os

from tensorflow import keras

from keras_nlp.models.f_net.f_net_presets import backbone_presets
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring


@keras.utils.register_keras_serializable(package="keras_nlp")
Expand Down Expand Up @@ -84,12 +88,53 @@ def __init__(self, proto, **kwargs):

@classproperty
def presets(cls):
return {}
return copy.deepcopy(backbone_presets)

@classmethod
@format_docstring(names=", ".join(backbone_presets))
def from_preset(
cls,
preset,
**kwargs,
):
raise NotImplementedError
"""Instantiate an FNet tokenizer from preset vocabulary.

Args:
preset: string. Must be one of {{names}}.

Examples:
```python
# Load a preset tokenizer.
tokenizer = keras_nlp.models.FNetTokenizer.from_preset(
"f_net_base_en",
)

# Tokenize some input.
tokenizer("The quick brown fox tripped.")

# Detokenize some input.
tokenizer.detokenize([5, 6, 7, 8, 9])
```
"""
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
metadata = cls.presets[preset]

spm_proto = keras.utils.get_file(
"vocab.spm",
metadata["spm_proto_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["spm_proto_hash"],
)

config = metadata["preprocessor_config"]
config.update(
{
"proto": spm_proto,
},
)

return cls.from_config({**config, **kwargs})