Merged
Changes from 52 commits
61 commits
e2b569a
Add phi3
abuelnasr0 Apr 25, 2024
c963f7b
Add phi3 to init
abuelnasr0 Apr 25, 2024
0f30b5f
layer naming and some nits
abuelnasr0 Apr 25, 2024
1d18572
Decoder layers naming
abuelnasr0 Apr 25, 2024
3d24cb2
Remove bias from einsumdense
abuelnasr0 Apr 26, 2024
a98369a
nit fix for layernorm
abuelnasr0 Apr 26, 2024
aecd9e2
Add SuRotary embedding
abuelnasr0 Apr 26, 2024
368e5a2
Remove print()
abuelnasr0 Apr 26, 2024
4267257
Add conversion script
abuelnasr0 Apr 27, 2024
e946503
Nit fix in script
abuelnasr0 Apr 27, 2024
145864b
Add phi3_4k as default preset
abuelnasr0 Apr 27, 2024
c5c78ed
Fix Doc and nit changes
abuelnasr0 Apr 27, 2024
7b0def0
Nit in test
abuelnasr0 Apr 27, 2024
c78c482
Doc fix
abuelnasr0 Apr 29, 2024
cd0381a
Add length check for rope scaling factors
abuelnasr0 Apr 29, 2024
6f9108d
Calculate the mean of the absolute difference in conversion script
abuelnasr0 Apr 29, 2024
8e99e04
Fix typo
abuelnasr0 Apr 29, 2024
b3ca8a3
Add tokenizer and preprocessor
abuelnasr0 May 2, 2024
b53a326
Format fix
abuelnasr0 May 2, 2024
0e37c9a
Fix dtype and device in conversion script
abuelnasr0 May 2, 2024
45ab340
Batch the input
abuelnasr0 May 2, 2024
a459038
Batch the input
abuelnasr0 May 2, 2024
9c38dec
Nit
abuelnasr0 May 2, 2024
07832e5
Add notify for upload
abuelnasr0 May 2, 2024
fc1cf0b
Add causal_lm preprocessor
abuelnasr0 May 2, 2024
aac962d
Add causal lm
abuelnasr0 May 2, 2024
49103f3
Fix format
abuelnasr0 May 2, 2024
49a5495
small fixes
abuelnasr0 May 2, 2024
dcead36
Add phi3 to the new api
abuelnasr0 May 2, 2024
0d990f5
Api gen
abuelnasr0 May 2, 2024
ac8770d
Public named sublayers
abuelnasr0 May 6, 2024
1c2a70e
Public named sublayers in decoder layer
abuelnasr0 May 6, 2024
e63430a
Simplify dropout
abuelnasr0 May 6, 2024
9ac44b6
Fix tokenizer tests
abuelnasr0 May 6, 2024
6355e67
Fix conversion script
abuelnasr0 May 6, 2024
a206480
use preprocessor
abuelnasr0 May 6, 2024
968a220
use preprocessor
abuelnasr0 May 6, 2024
7c17bd1
Fix keras input
abuelnasr0 May 6, 2024
0c165a0
Fix keras model input
abuelnasr0 May 6, 2024
074ea69
Only validate with validate_dtype
abuelnasr0 May 6, 2024
d5c7fab
Only validate with validate_dtype
abuelnasr0 May 6, 2024
0368483
Change seq length
abuelnasr0 May 6, 2024
1eed34b
Change text
abuelnasr0 May 6, 2024
d048f2d
Set pad token id to 0
abuelnasr0 May 7, 2024
3af4096
Default stop at EOS and EOT
abuelnasr0 May 7, 2024
c9f0ad9
Add presets
abuelnasr0 May 7, 2024
5c5e4ef
Add presets and tests to tokenizer
abuelnasr0 May 7, 2024
5225c6f
Add preprocessor preset tests
abuelnasr0 May 7, 2024
b02c0b4
Add preset tests to causal_lm
abuelnasr0 May 7, 2024
4ab0d32
Add backbone preset tests
abuelnasr0 May 7, 2024
b2b7c55
Naming nits
abuelnasr0 May 7, 2024
0dff9f1
Clean SuRotary embedding
abuelnasr0 May 7, 2024
9552750
Lower case file name
abuelnasr0 May 9, 2024
b76f314
Save SuScaled rope factors as python lists
abuelnasr0 May 9, 2024
ad585b0
Rename original_max seq_length to training seq_length
abuelnasr0 May 9, 2024
9f10b63
Format
abuelnasr0 May 9, 2024
55e15bf
Remove placeholders tokens from spm
abuelnasr0 May 9, 2024
19fc9ca
Edit examples
abuelnasr0 May 9, 2024
f0a4236
Nit in generate
abuelnasr0 May 10, 2024
c205c20
Change training_seq_length to pretraining_seq_length
abuelnasr0 May 14, 2024
b735170
Update links
mattdangerw May 17, 2024
7 changes: 7 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -152,6 +152,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
7 changes: 7 additions & 0 deletions keras_nlp/src/models/__init__.py
@@ -143,6 +143,13 @@
)
from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
Phi3CausalLMPreprocessor,
)
from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_classifier import RobertaClassifier
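With these imports merged, the Phi3 classes are exposed through the public `keras_nlp.models` API. A minimal usage sketch follows; the preset name is illustrative, not necessarily one shipped by this PR:

import keras_nlp

# Hypothetical preset name for illustration; list the real ones via
# keras_nlp.models.Phi3Backbone.presets.
causal_lm = keras_nlp.models.Phi3CausalLM.from_preset("phi3_mini_4k_instruct_en")
causal_lm.generate("The quick brown fox", max_length=30)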
69 changes: 69 additions & 0 deletions keras_nlp/src/models/phi3/Phi3_preprocessor_test.py
@@ -0,0 +1,69 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from keras_nlp.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.tests.test_case import TestCase


class Phi3PreprocessorTest(TestCase):
def setUp(self):
self.tokenizer = Phi3Tokenizer(
# Generated using create_phi3_test_proto.py
proto=os.path.join(self.get_test_data_dir(), "phi3_test_vocab.spm")
)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 12,
}
self.input_data = (
# Encoded to [3, 5, 6, 4, 3, 9, 7, 11, 3, 15]
["the fox <|endoftext|>"],
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
)

def test_preprocessor_basics(self):
self.run_preprocessor_test(
cls=Phi3Preprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output=(
{
"token_ids": [[1, 3, 5, 6, 4, 3, 9, 7, 11, 3, 15, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]],
},
[1], # Pass through labels.
[1.0], # Pass through sample_weights.
),
)

def test_errors_for_2d_list_input(self):
preprocessor = Phi3Preprocessor(**self.init_kwargs)
ambiguous_input = [["one", "two"], ["three", "four"]]
with self.assertRaises(ValueError):
preprocessor(ambiguous_input)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in Phi3Preprocessor.presets:
self.run_preset_test(
cls=Phi3Preprocessor,
preset=preset,
input_data=self.input_data,
)
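The `phi3_test_vocab.spm` proto loaded in `setUp` comes from a helper script. A rough sketch of how such a tiny SentencePiece vocabulary can be trained is below; the training options are assumptions, and the actual create_phi3_test_proto.py may differ:

import io

import sentencepiece

bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    # Tiny fixed corpus; pinning the sentences keeps token ids stable
    # across runs, so the expected ids in the test stay valid.
    sentence_iterator=iter(["the quick brown fox", "the earth is round"]),
    model_writer=bytes_io,
    vocab_size=16,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    user_defined_symbols=["<|endoftext|>"],
)
with open("phi3_test_vocab.spm", "wb") as f:
    f.write(bytes_io.getvalue())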
20 changes: 20 additions & 0 deletions keras_nlp/src/models/phi3/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_presets import backbone_presets
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer))
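Registering the presets against both `Phi3Backbone` and `Phi3Tokenizer` lets either class resolve preset names directly. A quick sketch, with an illustrative preset name:

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone

print(Phi3Backbone.presets.keys())  # Names registered above.
backbone = Phi3Backbone.from_preset("phi3_mini_4k_instruct_en")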
259 changes: 259 additions & 0 deletions keras_nlp/src/models/phi3/phi3_attention.py
@@ -0,0 +1,259 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.backend import keras
from keras_nlp.src.backend import ops
from keras_nlp.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.src.models.phi3.phi3_rotary_embedding import (
Phi3SuScaledRotaryEmbedding,
)
from keras_nlp.src.utils.keras_utils import clone_initializer


class Phi3Attention(keras.layers.Layer):
"""A cached grounded query attention layer."""

def __init__(
self,
num_query_heads,
num_key_value_heads,
kernel_initializer="glorot_uniform",
dropout=0,
max_sequence_length=4096,
original_max_sequence_length=4096,
rope_max_wavelength=10000,
rope_scaling_type=None,
rope_scaling_short_factor=None,
rope_scaling_long_factor=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.dropout = dropout

self.max_sequence_length = max_sequence_length
self.original_max_sequence_length = original_max_sequence_length
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_type = rope_scaling_type
self.rope_scaling_short_factor = rope_scaling_short_factor
self.rope_scaling_long_factor = rope_scaling_long_factor

self.kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
hidden_dim = inputs_shape[-1]
head_dim = hidden_dim // self.num_query_heads
self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
head_dim,
),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
self.value_dense.build(inputs_shape)

self.softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, hidden_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self.output_dense.build((None, None, self.num_query_heads, head_dim))

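        # Choose the rotary embedding implementation: plain `RotaryEmbedding`
        # when no scaling is requested, or Phi3's "su" long-context variant.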
if self.rope_scaling_type is None:
self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
dtype=self.dtype_policy,
)
elif self.rope_scaling_type == "su":
if len(self.rope_scaling_short_factor) != head_dim // 2:
raise ValueError(
"`rope_scaling_short_factor` must be of length "
"`hidden_dim//num_query_heads//2`. "
"`len(rope_scaling_short_factor)` is "
f"{len(self.rope_scaling_short_factor)} "
f"while it should be {head_dim // 2}."
)
if len(self.rope_scaling_long_factor) != head_dim // 2:
raise ValueError(
"`rope_scaling_long_factor` must be of length "
"`hidden_dim//num_query_heads//2`. "
"`len(rope_scaling_long_factor)` is "
f"{len(self.rope_scaling_long_factor)} "
f"while it should be {head_dim // 2}."
)
self.rotary_embedding_layer = Phi3SuScaledRotaryEmbedding(
max_sequence_length=self.max_sequence_length,
original_max_sequence_length=self.original_max_sequence_length,
inverese_freq_short_factor=self.rope_scaling_short_factor,
inverese_freq_long_factor=self.rope_scaling_long_factor,
max_wavelength=self.rope_max_wavelength,
dtype=self.dtype_policy,
)
else:
raise ValueError(
                '`rope_scaling_type` must be `None` or `"su"`. '
                "If `None` is chosen, `RotaryEmbedding` will be used. "
                'If `"su"` is chosen, `Phi3SuScaledRotaryEmbedding` '
                "will be used."
)

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)

query = self.query_dense(hidden_states)
key = self.key_dense(hidden_states)
value = self.value_dense(hidden_states)

        # Apply rotary position embeddings (RoPE) to queries and keys.
query = self.rotary_embedding_layer(query, start_index=start_index)
key = self.rotary_embedding_layer(key, start_index=start_index)

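        # `cache` stacks the key and value caches along axis 1, with shape
        # [batch, 2, max_cache_length, num_key_value_heads, head_dim].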
if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key)
value = ops.slice_update(value_cache, start, value)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum("bquh,bkuh->buqk", query, key)
attention_scores = attention_scores / self._norm_factor
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = ops.cast(attention_scores, self.compute_dtype)
attention_output = ops.einsum(
"buqk,bkuh->bquh", attention_scores, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
"original_max_sequence_length": self.original_max_sequence_length,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_type": self.rope_scaling_type,
"rope_scaling_short_factor": self.rope_scaling_short_factor,
"rope_scaling_long_factor": self.rope_scaling_long_factor,
}
)
return config
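A minimal standalone smoke test of the layer above, assuming the default (unscaled) rotary embedding; shapes follow the einsum comments in `build`:

import numpy as np

# hidden_dim (64) must divide evenly by num_query_heads (8), and
# num_query_heads by num_key_value_heads (4).
layer = Phi3Attention(num_query_heads=8, num_key_value_heads=4)
hidden_states = np.random.uniform(size=(2, 16, 64)).astype("float32")
output = layer(hidden_states)  # -> shape (2, 16, 64)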