11 changes: 10 additions & 1 deletion keras_nlp/models/t5/t5_backbone.py
@@ -14,13 +14,16 @@

"""T5 backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
from keras_nlp.models.t5.t5_presets import backbone_presets
from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
from keras_nlp.utils.python_utils import classproperty

@@ -54,6 +57,9 @@ class T5Backbone(Backbone):
hidden_dim: int. The hidden size of the Transformer layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each Transformer layer.
key_value_dim: int. The dimension of each head of the key/value
projections in the multi-head attention layers. Defaults to
hidden_dim // num_heads.
dropout: float. Dropout probability for the Transformer layers.
activation: activation function (or activation string name). The
activation to be used in the inner dense blocks of the
@@ -75,6 +81,7 @@ def __init__(
num_heads,
hidden_dim,
intermediate_dim,
key_value_dim=None,
dropout=0.1,
activation="gelu",
use_gated_activation=True,
@@ -123,6 +130,7 @@ def __init__(
is_decoder=False,
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
key_value_dim=key_value_dim or hidden_dim // num_heads,
dropout=dropout,
activation=activation,
layer_norm_epsilon=layer_norm_epsilon,
@@ -167,6 +175,7 @@ def __init__(
is_decoder=True,
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
key_value_dim=key_value_dim or hidden_dim // num_heads,
dropout=dropout,
activation=activation,
layer_norm_epsilon=layer_norm_epsilon,
@@ -237,4 +246,4 @@ def token_embedding(self):

@classproperty
def presets(cls):
return {}
return copy.deepcopy(backbone_presets)
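The new key_value_dim argument decouples the per-head key/value projection width from hidden_dim // num_heads, which the T5 1.1 configurations need (for example, the t5_small_en preset below pairs 6 heads and 64-dim key/value projections with a 512-dim hidden state). A minimal construction sketch, assuming the module path shown in this diff and defaults for the remaining arguments:

from keras_nlp.models.t5.t5_backbone import T5Backbone

# Illustrative small configuration where key_value_dim != hidden_dim // num_heads.
backbone = T5Backbone(
    vocabulary_size=32128,
    num_layers=8,
    num_heads=6,
    hidden_dim=512,
    intermediate_dim=1024,
    key_value_dim=64,
)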
3 changes: 2 additions & 1 deletion keras_nlp/models/t5/t5_multi_head_attention.py
@@ -31,6 +31,7 @@ def __init__(
self,
is_decoder,
hidden_dim,
key_value_dim,
num_heads,
dropout,
use_relative_attention_bias=False,
@@ -39,7 +40,7 @@ def __init__(
super().__init__(**kwargs)
self.is_decoder = is_decoder
self.hidden_dim = hidden_dim
self.key_value_dim = hidden_dim // num_heads
self.key_value_dim = key_value_dim
self.num_heads = num_heads
self.use_relative_attention_bias = use_relative_attention_bias

138 changes: 138 additions & 0 deletions keras_nlp/models/t5/t5_presets.py
@@ -0,0 +1,138 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLM-RoBERTa model preset configurations."""

backbone_presets = {
"t5_small_en": {
"metadata": {
"description": (
"8-layer T5 model. Trained on the Colossal Clean Crawled "
"Corpus (C4)."
),
"params": 0,
"official_name": "T5",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
},
"config": {
"vocabulary_size": 32128,
"num_layers": 8,
"num_heads": 6,
"hidden_dim": 512,
"intermediate_dim": 1024,
"key_value_dim": 64,
"dropout": 0.1,
"activation": "gelu",
"use_gated_activation": True,
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
},
"t5_base_en": {
"metadata": {
"description": (
"12-layer T5 model. Trained on the Colossal Clean Crawled "
"Corpus (C4)."
),
"params": 0,
"official_name": "T5",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
},
"config": {
"vocabulary_size": 32128,
"num_layers": 12,
"num_heads": 12,
"hidden_dim": 768,
"intermediate_dim": 2048,
"dropout": 0.1,
"activation": "gelu",
"use_gated_activation": True,
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
},
"t5_large_en": {
"metadata": {
"description": (
"24-layer T5 model. Trained on the Colossal Clean Crawled "
"Corpus (C4)."
),
"params": 0,
"official_name": "T5",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
},
"config": {
"vocabulary_size": 32128,
"num_layers": 24,
"num_heads": 16,
"hidden_dim": 1024,
"intermediate_dim": 2816,
"dropout": 0.1,
"activation": "gelu",
"use_gated_activation": True,
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
},
"t5_extra_large_en": {
"metadata": {
"description": (
"24-layer T5 model. Trained on the Colossal Clean Crawled "
"Corpus (C4)."
),
"params": 0,
"official_name": "T5",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
},
"config": {
"vocabulary_size": 32128,
"num_layers": 24,
"num_heads": 32,
"hidden_dim": 2048,
"intermediate_dim": 5120,
"dropout": 0.1,
"activation": "gelu",
"use_gated_activation": True,
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
},
"t5_extra_extra_large_en": {
"metadata": {
"description": (
"24-layer T5 model. Trained on the Colossal Clean Crawled "
"Corpus (C4)."
),
"params": 0,
"official_name": "T5",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
},
"config": {
"vocabulary_size": 32128,
"num_layers": 24,
"num_heads": 64,
"hidden_dim": 4096,
"intermediate_dim": 10240,
"dropout": 0.1,
"activation": "gelu",
"use_gated_activation": True,
"layer_norm_epsilon": 1e-06,
},
"preprocessor_config": {},
},
}
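With the presets classproperty on T5Backbone now returning a deep copy of backbone_presets, a configuration can be instantiated by name through the standard KerasNLP from_preset constructor. A hedged sketch, assuming the usual from_preset signature and that weights for these presets have been published:

from keras_nlp.models.t5.t5_backbone import T5Backbone

# Build the architecture for the "t5_small_en" preset; switch to
# load_weights=True once pretrained weights are available for the preset.
backbone = T5Backbone.from_preset("t5_small_en", load_weights=False)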
30 changes: 21 additions & 9 deletions keras_nlp/models/t5/t5_transformer_layer.py
@@ -23,6 +23,7 @@ def __init__(
is_decoder,
hidden_dim,
intermediate_dim,
key_value_dim,
dropout,
activation,
layer_norm_epsilon,
@@ -36,26 +37,37 @@ def __init__(
self.use_gated_activation = use_gated_activation

self.self_attention = T5MultiHeadAttention(
is_decoder,
hidden_dim,
num_heads,
dropout,
is_decoder=is_decoder,
hidden_dim=hidden_dim,
key_value_dim=key_value_dim,
num_heads=num_heads,
dropout=dropout,
use_relative_attention_bias=use_relative_attention_bias,
)
self.self_attention_layernorm = T5LayerNorm(layer_norm_epsilon)
self.self_attention_dropout = keras.layers.Dropout(dropout)

if self.is_decoder:
self.cross_attention = T5MultiHeadAttention(
is_decoder,
hidden_dim,
num_heads,
dropout,
is_decoder=is_decoder,
hidden_dim=hidden_dim,
key_value_dim=key_value_dim,
num_heads=num_heads,
dropout=dropout,
use_relative_attention_bias=False,
)
self.cross_attention_layernorm = T5LayerNorm(layer_norm_epsilon)
self.cross_attention_dropout = keras.layers.Dropout(dropout)

if activation == "gelu":

def approx_gelu(x):
return keras.activations.gelu(x, approximate=True)

activation = approx_gelu
else:
activation = keras.activations.get(activation)

self.input_projector = keras.layers.Dense(
intermediate_dim,
use_bias=False,
@@ -123,7 +135,7 @@ def call(
x = self.layer_norm(x)
if self.use_gated_activation:
hidden_activation = self.input_projector(x)
hidden_linear = self.gate_projector(hidden_states)
hidden_linear = self.gate_projector(x)
x = hidden_activation * hidden_linear
else:
x = self.input_projector(x)
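The one-line fix in call above (gate_projector applied to x rather than hidden_states) makes the gated feedforward use the layer-normalized tensor for both branches, i.e. the GEGLU-style pattern act(W_in x) * (W_gate x) followed by the output projection. A standalone sketch of that pattern, with hypothetical layer names and a plain LayerNormalization standing in for T5LayerNorm:

import tensorflow as tf
from tensorflow import keras

hidden_dim, intermediate_dim = 512, 1024
layer_norm = keras.layers.LayerNormalization(epsilon=1e-6)
input_projector = keras.layers.Dense(intermediate_dim, use_bias=False)
gate_projector = keras.layers.Dense(intermediate_dim, use_bias=False)
output_projector = keras.layers.Dense(hidden_dim, use_bias=False)

def gated_feedforward(hidden_states):
    # Normalize once, then feed the same normalized tensor to both projections.
    x = layer_norm(hidden_states)
    hidden_activation = keras.activations.gelu(input_projector(x), approximate=True)
    hidden_linear = gate_projector(x)
    return output_projector(hidden_activation * hidden_linear)

out = gated_feedforward(tf.random.normal((2, 10, hidden_dim)))  # shape (2, 10, 512)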
43 changes: 43 additions & 0 deletions tools/checkpoint_conversion/checkpoint_conversion_utils.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import inspect

import keras as keraslib
from keras.engine.base_layer_utils import make_variable


def get_md5_checksum(file_path):
@@ -20,3 +24,42 @@ def get_md5_checksum(file_path):
for byte_block in iter(lambda: f.read(4096), b""):
md5_hash.update(byte_block)
return md5_hash.hexdigest()


def port_weights_by_creation_order(source_model_fn, dest_model_fn, debug=False):
"""Assign weights between models by intercepting variable creation.

For each model, builds a flat list of all variables created, in order,
then copies values from the first model's variables to the second's positionally."""
ALL_VARS = []

def make_var(name, shape=None, **kwargs):
if debug:
stack = inspect.stack()
instance = stack[1][0].f_locals["self"]
cls = instance.__class__.__name__
print(f"Class {cls} creating {name} with shape {shape}.")

v = make_variable(name, shape=shape, **kwargs)
ALL_VARS.append(v)
return v

# Patch make variable.
keraslib.engine.base_layer_utils.make_variable = make_var

source_model = source_model_fn()
source_model_vars = ALL_VARS[:]

# Empty the shared list in place so it only tracks the destination model's variables.
ALL_VARS.clear()
dest_model = dest_model_fn()
if len(ALL_VARS) != len(source_model_vars):
raise ValueError(
f"Variable counts do not match. 1st model: {len(source_model_vars)} "
"vars, 2nd model: {len(ALL_VARS)} vars"
)
for v1, v2 in zip(source_model_vars, ALL_VARS):
v2.assign(v1.numpy())

# Unpatch make variable.
keraslib.engine.base_layer_utils.make_variable = make_variable

return source_model, dest_model
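A hedged usage sketch for the helper above: two toy constructors that create variables with identical shapes in identical order (a single Dense layer's kernel, then bias). In practice the callables would build the reference implementation and the KerasNLP model respectively; the constructor bodies here are purely illustrative.

from tensorflow import keras

def build_source():
    # Toy "source" model: one Dense layer -> kernel (8, 4), then bias (4,).
    return keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

def build_dest():
    # Toy "destination" model with matching variable shapes and creation order.
    return keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

source, dest = port_weights_by_creation_order(build_source, build_dest)
# dest's kernel and bias now hold copies of source's values.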