Merged
Changes from 7 commits
8 changes: 8 additions & 0 deletions keras_nlp/models/mistral/mistral_backbone.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
@@ -19,9 +21,11 @@
from keras_nlp.models.mistral.mistral_layer_norm import (
MistralLayerNormalization,
)
from keras_nlp.models.mistral.mistral_presets import backbone_presets
from keras_nlp.models.mistral.mistral_transformer_decoder import (
MistralTransformerDecoder,
)
from keras_nlp.utils.python_utils import classproperty


def _mistral_kernel_initializer(stddev=0.02):
@@ -196,3 +200,7 @@ def get_config(self):
}
)
return config

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
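
With `backbone_presets` wired into the `presets` classproperty, the backbone should be loadable by name through the standard `from_preset` constructor. A minimal sketch, assuming Kaggle access for the weight download (the parameter count echoes the "params" field in the preset metadata, it is not computed here):

import keras_nlp

backbone = keras_nlp.models.MistralBackbone.from_preset("mistral_7b_en")
# Roughly 7.24B parameters, matching the preset metadata.
print(backbone.count_params())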
26 changes: 26 additions & 0 deletions keras_nlp/models/mistral/mistral_backbone_test.py
@@ -54,3 +54,29 @@ def test_num_parameters(self):
model = MistralBackbone(**self.init_kwargs)
# Reference value calculated using the PyTorch model
self.assertEqual(model.count_params(), 2704)

@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
cls=MistralBackbone,
preset="mistral_7b_en",
input_data={
"token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]),
"padding_mask": ops.ones((1, 6), dtype="int32"),
},
expected_output_shape=(1, 6, 4096),
# The forward pass from a preset should be stable!
# Reference values computed using PyTorch HF model.
expected_partial_output=ops.array(
[-1.6875, 0.5117, -1.7188, 2.3125, -0.0996]
),
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in MistralBackbone.presets:
self.run_preset_test(
cls=MistralBackbone,
preset=preset,
input_data=self.input_data,
)
6 changes: 6 additions & 0 deletions keras_nlp/models/mistral/mistral_causal_lm.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
@@ -20,6 +21,7 @@
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
from keras_nlp.models.mistral.mistral_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty


@@ -211,3 +213,7 @@ def next(prompt, cache, index):
"token_ids": token_ids,
"padding_mask": padding_mask,
}

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
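
The same classproperty on the task model means the full causal LM should also resolve presets by name. A hedged sketch (the prompt follows Mistral's published [INST] format; the call downloads several gigabytes of weights, and the output is whatever the model generates):

import keras_nlp

causal_lm = keras_nlp.models.MistralCausalLM.from_preset("mistral_instruct_7b_en")
causal_lm.generate("[INST] What is Keras? [/INST]", max_length=64)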
14 changes: 14 additions & 0 deletions keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py
@@ -131,6 +131,20 @@ def generate_preprocess(
x,
sequence_length=None,
):
"""Covert strings to integer token input for generation.

Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.

Unlike calling the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
if not self.built:
self.build(None)

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
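
A small usage sketch of the contract described in the new docstring (the vocabulary file name is a made-up placeholder): `generate_preprocess` returns only features, with no labels and no end token, so generation can continue from the prompt's last token.

from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
    MistralCausalLMPreprocessor,
)
from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer

tokenizer = MistralTokenizer(proto="mistral_test_vocab.spm")  # placeholder asset
preprocessor = MistralCausalLMPreprocessor(tokenizer=tokenizer, sequence_length=8)
features = preprocessor.generate_preprocess(["the quick brown fox"])
# features["token_ids"] and features["padding_mask"] both have shape (1, 8).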
11 changes: 11 additions & 0 deletions keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py
@@ -14,6 +14,8 @@

import os

import pytest

from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
@@ -79,3 +81,12 @@ def test_generate_postprocess(self):
preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "the quick brown fox")

@pytest.mark.extra_large
def test_all_presets(self):
for preset in MistralCausalLMPreprocessor.presets:
self.run_preset_test(
cls=MistralCausalLMPreprocessor,
preset=preset,
input_data=self.input_data,
)
20 changes: 16 additions & 4 deletions keras_nlp/models/mistral/mistral_preprocessor.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.models.mistral.mistral_presets import backbone_presets
from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
@@ -121,15 +123,21 @@ def __init__(
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.packer = None
self.add_start_token = add_start_token
self.add_end_token = add_end_token
self.sequence_length = sequence_length

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
# assets have loaded when restoring a saved model.
self.packer = StartEndPacker(
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
- sequence_length=sequence_length,
+ sequence_length=self.sequence_length,
return_padding_mask=True,
)
- self.add_start_token = add_start_token
- self.add_end_token = add_end_token
- self.sequence_length = sequence_length
self.built = True

def get_config(self):
config = super().get_config()
Expand Down Expand Up @@ -184,3 +192,7 @@ def sequence_length(self, value):
@classproperty
def tokenizer_cls(cls):
return MistralTokenizer

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
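
Because the `StartEndPacker` is now created in `build()` rather than `__init__()`, the tokenizer's SentencePiece proto only has to be available when the layer first runs, which is what makes restoring a saved preset safe. A hedged usage sketch (the sequence length is arbitrary):

import keras_nlp

preprocessor = keras_nlp.models.MistralPreprocessor.from_preset(
    "mistral_7b_en", sequence_length=128
)
features = preprocessor("The quick brown fox jumped.")
# The packer is only constructed on this first call, after the tokenizer
# assets have been restored.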
11 changes: 11 additions & 0 deletions keras_nlp/models/mistral/mistral_preprocessor_test.py
@@ -14,6 +14,8 @@

import os

import pytest

from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.tests.test_case import TestCase
@@ -57,3 +59,12 @@ def test_errors_for_2d_list_input(self):
ambiguous_input = [["one", "two"], ["three", "four"]]
with self.assertRaises(ValueError):
preprocessor(ambiguous_input)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in MistralPreprocessor.presets:
self.run_preset_test(
cls=MistralPreprocessor,
preset=preset,
input_data=self.input_data,
)
38 changes: 38 additions & 0 deletions keras_nlp/models/mistral/mistral_presets.py
@@ -0,0 +1,38 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mistral model preset configurations."""

# Metadata for loading pretrained model weights.
backbone_presets = {
"mistral_7b_en": {
"metadata": {
"description": "Mistral 7B base model",
"params": 7241732096,
"official_name": "Mistral",
"path": "mistral",
"model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3",
},
"mistral_instruct_7b_en": {
"metadata": {
"description": "Mistral 7B instruct model",
"params": 7241732096,
"official_name": "Mistral",
"path": "mistral",
"model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3",
},
}
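
Since every `presets` classproperty returns a deep copy of this dict, the registered names and metadata can be inspected without mutating the module-level configuration. A small sketch:

import keras_nlp

for name, preset in keras_nlp.models.MistralBackbone.presets.items():
    print(name, preset["metadata"]["params"])
# mistral_7b_en 7241732096
# mistral_instruct_7b_en 7241732096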
8 changes: 8 additions & 0 deletions keras_nlp/models/mistral/mistral_tokenizer.py
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.mistral.mistral_presets import backbone_presets
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.MistralTokenizer")
@@ -77,3 +81,7 @@ def set_proto(self, proto):
else:
self.start_token_id = None
self.end_token_id = None

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
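
The tokenizer picks up the same preset registry, so it should be loadable on its own; the expected token ids in the comment are the ones asserted by `test_smallest_preset` in the test file that follows.

import keras_nlp

tokenizer = keras_nlp.models.MistralTokenizer.from_preset("mistral_7b_en")
tokenizer(["The quick brown fox."])
# [[415, 2936, 9060, 285, 1142, 28723]]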
20 changes: 20 additions & 0 deletions keras_nlp/models/mistral/mistral_tokenizer_test.py
@@ -14,6 +14,8 @@

import os

import pytest

from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.tests.test_case import TestCase

@@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self):
self.get_test_data_dir(), "no_special_token_vocab.spm"
)
)

@pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
cls=MistralTokenizer,
preset="mistral_7b_en",
input_data=["The quick brown fox."],
expected_output=[[415, 2936, 9060, 285, 1142, 28723]],
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in MistralTokenizer.presets:
self.run_preset_test(
cls=MistralTokenizer,
preset=preset,
input_data=self.input_data,
)