Add phi3 #1597
Merged
Commits (61), all by abuelnasr0 except the last, which is by mattdangerw:

- e2b569a Add phi3
- c963f7b Add phi3 to init
- 0f30b5f Layer naming and some nits
- 1d18572 Decoder layers naming
- 3d24cb2 Remove bias from einsumdense
- a98369a Nit fix for layernorm
- aecd9e2 Add SuRotary embedding
- 368e5a2 Remove print()
- 4267257 Add conversion script
- e946503 Nit fix in script
- 145864b Add phi3_4k as default preset
- c5c78ed Fix doc and nit changes
- 7b0def0 Nit in test
- c78c482 Doc fix
- cd0381a Add length check for rope scaling factors
- 6f9108d Calculate the mean of the absolute difference in conversion script
- 8e99e04 Fix typo
- b3ca8a3 Add tokenizer and preprocessor
- b53a326 Format fix
- 0e37c9a Fix dtype and device in conversion script
- 45ab340 Batch the input
- a459038 Batch the input
- 9c38dec Nit
- 07832e5 Add notify for upload
- fc1cf0b Add causal_lm preprocessor
- aac962d Add causal lm
- 49103f3 Fix format
- 49a5495 Small fixes
- dcead36 Add phi3 to the new api
- 0d990f5 Api gen
- ac8770d Public named sublayers
- 1c2a70e Public named sublayers in decoder layer
- e63430a Simplify dropout
- 9ac44b6 Fix tokenizer tests
- 6355e67 Fix conversion script
- a206480 Use preprocessor
- 968a220 Use preprocessor
- 7c17bd1 Fix keras input
- 0c165a0 Fix keras model input
- 074ea69 Only validate with validate_dtype
- d5c7fab Only validate with validate_dtype
- 0368483 Change seq length
- 1eed34b Change text
- d048f2d Set pad token id to 0
- 3af4096 Default stop at EOS and EOT
- c9f0ad9 Add presets
- 5c5e4ef Add presets and tests to tokenizer
- 5225c6f Add preprocessor preset tests
- b02c0b4 Add preset tests to causal_lm
- 4ab0d32 Add backbone preset tests
- b2b7c55 Naming nits
- 0dff9f1 Clean SuRotaryEmbedding
- 9552750 Lower-case file name
- b76f314 Save SuScaled rope factors as python lists
- ad585b0 Rename original_max_seq_length to training_seq_length
- 9f10b63 Format
- 55e15bf Remove placeholder tokens from spm
- 19fc9ca Edit examples
- f0a4236 Nit in generate
- c205c20 Change training_seq_length to pretraining_seq_length
- b735170 Update links
New file (69 lines), the Phi3 preprocessor tests:

```python
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Phi3PreprocessorTest(TestCase):
    def setUp(self):
        self.tokenizer = Phi3Tokenizer(
            # Generated using create_phi3_test_proto.py
            proto=os.path.join(self.get_test_data_dir(), "phi3_test_vocab.spm")
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 12,
        }
        self.input_data = (
            # Encoded to [3, 5, 6, 4, 3, 9, 7, 11, 3, 15]
            ["the fox <|endoftext|>"],
            [1],  # Pass through labels.
            [1.0],  # Pass through sample_weights.
        )

    def test_preprocessor_basics(self):
        self.run_preprocessor_test(
            cls=Phi3Preprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=(
                {
                    "token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 3, 15, 0]],
                    "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]],
                },
                [1],  # Pass through labels.
                [1.0],  # Pass through sample_weights.
            ),
        )

    def test_errors_for_2d_list_input(self):
        preprocessor = Phi3Preprocessor(**self.init_kwargs)
        ambiguous_input = [["one", "two"], ["three", "four"]]
        with self.assertRaises(ValueError):
            preprocessor(ambiguous_input)

    @pytest.mark.extra_large
    def test_all_presets(self):
        for preset in Phi3Preprocessor.presets:
            self.run_preset_test(
                cls=Phi3Preprocessor,
                preset=preset,
                input_data=self.input_data,
            )
```
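For orientation, a minimal sketch of how the preprocessor under test is typically used on its own. The preset id below is an assumption for illustration; the real ids are defined in phi3_presets.py, which is not shown in this diff:

```python
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor

# "phi3_mini_4k_instruct_en" is an assumed preset id, used for illustration.
preprocessor = Phi3Preprocessor.from_preset(
    "phi3_mini_4k_instruct_en", sequence_length=12
)
# Returns a dict with "token_ids" and "padding_mask", both padded/truncated
# to `sequence_length`, mirroring the expected output in the test above.
features = preprocessor(["the fox <|endoftext|>"])
```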
New file (20 lines), the Phi3 preset registration:

```python
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_presets import backbone_presets
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer))
```
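For context, `register_presets` is what wires the preset metadata into the standard `from_preset` constructors. A rough sketch of the effect, with the preset id assumed rather than taken from this diff:

```python
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer

# After registration, both classes can resolve Phi3 preset names.
# "phi3_mini_4k_instruct_en" is a hypothetical id used for illustration.
backbone = Phi3Backbone.from_preset("phi3_mini_4k_instruct_en")
tokenizer = Phi3Tokenizer.from_preset("phi3_mini_4k_instruct_en")
```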
New file (259 lines), the Phi3 attention layer:

```python
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.backend import keras
from keras_nlp.src.backend import ops
from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.src.models.phi3.phi3_rotary_embedding import (
    Phi3SuScaledRotaryEmbedding,
)
from keras_nlp.src.utils.keras_utils import clone_initializer


class Phi3Attention(keras.layers.Layer):
    """A cached grouped query attention layer."""

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        kernel_initializer="glorot_uniform",
        dropout=0,
        max_sequence_length=4096,
        original_max_sequence_length=4096,
        rope_max_wavelength=10000,
        rope_scaling_type=None,
        rope_scaling_short_factor=None,
        rope_scaling_long_factor=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.dropout = dropout

        self.max_sequence_length = max_sequence_length
        self.original_max_sequence_length = original_max_sequence_length
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_type = rope_scaling_type
        self.rope_scaling_short_factor = rope_scaling_short_factor
        self.rope_scaling_long_factor = rope_scaling_long_factor

        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        hidden_dim = inputs_shape[-1]
        head_dim = hidden_dim // self.num_query_heads
        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, head_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        self.softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build((None, None, self.num_query_heads, head_dim))

        if self.rope_scaling_type is None:
            self.rotary_embedding_layer = RotaryEmbedding(
                max_wavelength=self.rope_max_wavelength,
                dtype=self.dtype_policy,
            )
        elif self.rope_scaling_type == "su":
            if len(self.rope_scaling_short_factor) != head_dim // 2:
                raise ValueError(
                    "`rope_scaling_short_factor` must be of length "
                    "`hidden_dim//num_query_heads//2`. "
                    "`len(rope_scaling_short_factor)` is "
                    f"{len(self.rope_scaling_short_factor)} "
                    f"while it should be {head_dim // 2}."
                )
            if len(self.rope_scaling_long_factor) != head_dim // 2:
                raise ValueError(
                    "`rope_scaling_long_factor` must be of length "
                    "`hidden_dim//num_query_heads//2`. "
                    "`len(rope_scaling_long_factor)` is "
                    f"{len(self.rope_scaling_long_factor)} "
                    f"while it should be {head_dim // 2}."
                )
            self.rotary_embedding_layer = Phi3SuScaledRotaryEmbedding(
                max_sequence_length=self.max_sequence_length,
                original_max_sequence_length=self.original_max_sequence_length,
                inverese_freq_short_factor=self.rope_scaling_short_factor,
                inverese_freq_long_factor=self.rope_scaling_long_factor,
                max_wavelength=self.rope_max_wavelength,
                dtype=self.dtype_policy,
            )
        else:
            raise ValueError(
                '`rope_scaling_type` must be `None` or `"su"`. '
                "If `None` is chosen, `RotaryEmbedding` will be used. "
                'If `"su"` is chosen, `Phi3SuScaledRotaryEmbedding` will '
                "be used."
            )

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)
        key = self.key_dense(hidden_states)
        value = self.value_dense(hidden_states)

        # Compute RoPE for queries and keys.
        query = self.rotary_embedding_layer(query, start_index=start_index)
        key = self.rotary_embedding_layer(key, start_index=start_index)

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key)
                value = ops.slice_update(value_cache, start, value)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self.dropout_layer(
            attention_output, training=training
        )

        attention_output = self.output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        if attention_mask is not None:
            return self.softmax(attention_scores, attention_mask[:, None, :, :])
        return self.softmax(attention_scores)

    def _compute_attention(self, query, key, value, attention_mask=None):
        attention_scores = ops.einsum("bquh,bkuh->buqk", query, key)
        attention_scores = attention_scores / self._norm_factor
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            "buqk,bkuh->bquh", attention_scores, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
                "original_max_sequence_length": self.original_max_sequence_length,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_type": self.rope_scaling_type,
                "rope_scaling_short_factor": self.rope_scaling_short_factor,
                "rope_scaling_long_factor": self.rope_scaling_long_factor,
            }
        )
        return config
```
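A minimal sketch of exercising the layer standalone, assuming only what the code above defines; the tensor sizes are arbitrary toys, not Phi-3's real dimensions:

```python
import numpy as np

from keras_nlp.src.models.phi3.phi3_attention import Phi3Attention

# hidden_dim (64) must divide evenly by num_query_heads (8), and
# num_query_heads by num_key_value_heads (4), giving 2 KV groups.
layer = Phi3Attention(num_query_heads=8, num_key_value_heads=4)
hidden_states = np.random.uniform(size=(2, 16, 64)).astype("float32")
output = layer(hidden_states)  # -> shape (2, 16, 64)

# With a cache, `call` returns (output, updated_cache); the cache is shaped
# [batch, 2, seq_len, num_key_value_heads, head_dim] (key and value stacked).
```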