
Commit 4f710be

T5 conversion script and presets
1 parent 3652d64 commit 4f710be

File tree

6 files changed, +393 -11 lines changed


keras_nlp/models/t5/t5_backbone.py
Lines changed: 10 additions & 1 deletion

@@ -14,13 +14,16 @@
 
 """T5 backbone model."""
 
+import copy
+
 import tensorflow as tf
 from tensorflow import keras
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.transformer_layer_utils import compute_causal_mask
 from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
+from keras_nlp.models.t5.t5_presets import backbone_presets
 from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
 from keras_nlp.utils.python_utils import classproperty

@@ -54,6 +57,9 @@ class T5Backbone(Backbone):
         hidden_dim: int. The hidden size of the Transformer layers.
         intermediate_dim: int. The output dimension of the first Dense layer in
             a two-layer feedforward network for each Transformer layer.
+        key_value_dim: int. The dimension of each head of the key/value
+            projections in the multi-head attention layers. Defaults to
+            hidden_dim / num_heads.
         dropout: float. Dropout probability for the Transformer layers.
         activation: activation function (or activation string name). The
             activation to be used in the inner dense blocks of the

@@ -75,6 +81,7 @@ def __init__(
         num_heads,
         hidden_dim,
         intermediate_dim,
+        key_value_dim=None,
         dropout=0.1,
         activation="gelu",
         use_gated_activation=True,

@@ -123,6 +130,7 @@ def __init__(
                 is_decoder=False,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
+                key_value_dim=key_value_dim or hidden_dim // num_heads,
                 dropout=dropout,
                 activation=activation,
                 layer_norm_epsilon=layer_norm_epsilon,

@@ -167,6 +175,7 @@ def __init__(
                 is_decoder=True,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
+                key_value_dim=key_value_dim or hidden_dim // num_heads,
                 dropout=dropout,
                 activation=activation,
                 layer_norm_epsilon=layer_norm_epsilon,

@@ -237,4 +246,4 @@ def token_embedding(self):
 
     @classproperty
     def presets(cls):
-        return {}
+        return copy.deepcopy(backbone_presets)
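
The reason `key_value_dim` is surfaced as its own argument is that T5 checkpoints do not always satisfy hidden_dim == num_heads * head_dim; the `t5_small_en` preset below pairs hidden_dim=512 and num_heads=6 with 64-wide heads. A minimal construction sketch (argument values are illustrative, taken from that preset, not a full conversion recipe):

    from keras_nlp.models.t5.t5_backbone import T5Backbone

    # Explicit per-head key/value width; leaving key_value_dim=None would fall
    # back to hidden_dim // num_heads (512 // 6 == 85 here, not 64).
    backbone = T5Backbone(
        vocabulary_size=32128,
        num_layers=8,
        num_heads=6,
        hidden_dim=512,
        intermediate_dim=1024,
        key_value_dim=64,
    )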

keras_nlp/models/t5/t5_multi_head_attention.py
Lines changed: 2 additions & 1 deletion

@@ -31,6 +31,7 @@ def __init__(
         self,
         is_decoder,
         hidden_dim,
+        key_value_dim,
         num_heads,
         dropout,
         use_relative_attention_bias=False,

@@ -39,7 +40,7 @@ def __init__(
         super().__init__(**kwargs)
         self.is_decoder = is_decoder
         self.hidden_dim = hidden_dim
-        self.key_value_dim = hidden_dim // num_heads
+        self.key_value_dim = key_value_dim
         self.num_heads = num_heads
         self.use_relative_attention_bias = use_relative_attention_bias
keras_nlp/models/t5/t5_presets.py
Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""T5 model preset configurations."""
+
+backbone_presets = {
+    "t5_small_en": {
+        "metadata": {
+            "description": (
+                "8-layer T5 model. Trained on the Colossal Clean Crawled "
+                "Corpus (C4)."
+            ),
+            "params": 0,
+            "official_name": "T5",
+            "path": "t5",
+            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
+        },
+        "config": {
+            "vocabulary_size": 32128,
+            "num_layers": 8,
+            "num_heads": 6,
+            "hidden_dim": 512,
+            "intermediate_dim": 1024,
+            "key_value_dim": 64,
+            "dropout": 0.1,
+            "activation": "gelu",
+            "use_gated_activation": True,
+            "layer_norm_epsilon": 1e-06,
+        },
+        "preprocessor_config": {},
+    },
+    "t5_base_en": {
+        "metadata": {
+            "description": (
+                "12-layer T5 model. Trained on the Colossal Clean Crawled "
+                "Corpus (C4)."
+            ),
+            "params": 0,
+            "official_name": "T5",
+            "path": "t5",
+            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
+        },
+        "config": {
+            "vocabulary_size": 32128,
+            "num_layers": 12,
+            "num_heads": 12,
+            "hidden_dim": 768,
+            "intermediate_dim": 2048,
+            "dropout": 0.1,
+            "activation": "gelu",
+            "use_gated_activation": True,
+            "layer_norm_epsilon": 1e-06,
+        },
+        "preprocessor_config": {},
+    },
+    "t5_large_en": {
+        "metadata": {
+            "description": (
+                "24-layer T5 model. Trained on the Colossal Clean Crawled "
+                "Corpus (C4)."
+            ),
+            "params": 0,
+            "official_name": "T5",
+            "path": "t5",
+            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
+        },
+        "config": {
+            "vocabulary_size": 32128,
+            "num_layers": 24,
+            "num_heads": 16,
+            "hidden_dim": 1024,
+            "intermediate_dim": 2816,
+            "dropout": 0.1,
+            "activation": "gelu",
+            "use_gated_activation": True,
+            "layer_norm_epsilon": 1e-06,
+        },
+        "preprocessor_config": {},
+    },
+    "t5_extra_large_en": {
+        "metadata": {
+            "description": (
+                "24-layer T5 model. Trained on the Colossal Clean Crawled "
+                "Corpus (C4)."
+            ),
+            "params": 0,
+            "official_name": "T5",
+            "path": "t5",
+            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
+        },
+        "config": {
+            "vocabulary_size": 32128,
+            "num_layers": 24,
+            "num_heads": 32,
+            "hidden_dim": 2048,
+            "intermediate_dim": 5120,
+            "dropout": 0.1,
+            "activation": "gelu",
+            "use_gated_activation": True,
+            "layer_norm_epsilon": 1e-06,
+        },
+        "preprocessor_config": {},
+    },
+    "t5_extra_extra_large_en": {
+        "metadata": {
+            "description": (
+                "24-layer T5 model. Trained on the Colossal Clean Crawled "
+                "Corpus (C4)."
+            ),
+            "params": 0,
+            "official_name": "T5",
+            "path": "t5",
+            "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md",
+        },
+        "config": {
+            "vocabulary_size": 32128,
+            "num_layers": 24,
+            "num_heads": 64,
+            "hidden_dim": 4096,
+            "intermediate_dim": 10240,
+            "dropout": 0.1,
+            "activation": "gelu",
+            "use_gated_activation": True,
+            "layer_norm_epsilon": 1e-06,
+        },
+        "preprocessor_config": {},
+    },
+}
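
Once registered on `T5Backbone.presets`, these configs should be reachable through the usual KerasNLP preset constructor. A hedged sketch (it assumes the standard `from_preset(name, load_weights=...)` signature; building from config alone avoids depending on whether converted weights have been published yet):

    from keras_nlp.models.t5.t5_backbone import T5Backbone

    # The preset names registered by this commit.
    print(sorted(T5Backbone.presets))

    # Build the architecture from the "t5_small_en" config only; switch to
    # load_weights=True once converted checkpoints are hosted for T5.
    backbone = T5Backbone.from_preset("t5_small_en", load_weights=False)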

keras_nlp/models/t5/t5_transformer_layer.py
Lines changed: 21 additions & 9 deletions

@@ -23,6 +23,7 @@ def __init__(
         is_decoder,
         hidden_dim,
         intermediate_dim,
+        key_value_dim,
         dropout,
         activation,
         layer_norm_epsilon,

@@ -36,26 +37,37 @@ def __init__(
         self.use_gated_activation = use_gated_activation
 
         self.self_attention = T5MultiHeadAttention(
-            is_decoder,
-            hidden_dim,
-            num_heads,
-            dropout,
+            is_decoder=is_decoder,
+            hidden_dim=hidden_dim,
+            key_value_dim=key_value_dim,
+            num_heads=num_heads,
+            dropout=dropout,
             use_relative_attention_bias=use_relative_attention_bias,
         )
         self.self_attention_layernorm = T5LayerNorm(layer_norm_epsilon)
         self.self_attention_dropout = keras.layers.Dropout(dropout)
 
         if self.is_decoder:
             self.cross_attention = T5MultiHeadAttention(
-                is_decoder,
-                hidden_dim,
-                num_heads,
-                dropout,
+                is_decoder=is_decoder,
+                hidden_dim=hidden_dim,
+                key_value_dim=key_value_dim,
+                num_heads=num_heads,
+                dropout=dropout,
                 use_relative_attention_bias=False,
             )
             self.cross_attention_layernorm = T5LayerNorm(layer_norm_epsilon)
             self.cross_attention_dropout = keras.layers.Dropout(dropout)
 
+        if activation == "gelu":
+
+            def approx_gelu(x):
+                return keras.activations.gelu(x, approximate=True)
+
+            activation = approx_gelu
+        else:
+            activation = keras.activations.get(activation)
+
         self.input_projector = keras.layers.Dense(
             intermediate_dim,
             use_bias=False,

@@ -123,7 +135,7 @@ def call(
         x = self.layer_norm(x)
         if self.use_gated_activation:
             hidden_activation = self.input_projector(x)
-            hidden_linear = self.gate_projector(hidden_states)
+            hidden_linear = self.gate_projector(x)
             x = hidden_activation * hidden_linear
         else:
             x = self.input_projector(x)
tools/checkpoint_conversion/checkpoint_conversion_utils.py
Lines changed: 43 additions & 0 deletions

@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import hashlib
+import inspect
+
+import keras as keraslib
+from keras.engine.base_layer_utils import make_variable
 
 
 def get_md5_checksum(file_path):

@@ -20,3 +24,42 @@ def get_md5_checksum(file_path):
         for byte_block in iter(lambda: f.read(4096), b""):
             md5_hash.update(byte_block)
     return md5_hash.hexdigest()
+
+
+def port_weights_by_creation_order(source_model_fn, dest_model_fn, debug=False):
+    """Assign weights between models by intercepting variable creation.
+
+    For each model, builds a flat list of all variables created, in order."""
+    ALL_VARS = []
+
+    def make_var(name, shape=None, **kwargs):
+        if debug:
+            stack = inspect.stack()
+            instance = stack[1][0].f_locals["self"]
+            cls = instance.__class__.__name__
+            print(f"Class {cls} creating {name} with shape {shape}.")
+
+        v = make_variable(name, shape=shape, **kwargs)
+        ALL_VARS.append(v)
+        return v
+
+    # Patch make variable.
+    keraslib.engine.base_layer_utils.make_variable = make_var
+
+    source_model = source_model_fn()
+    source_model_vars = ALL_VARS[:]
+
+    # Empty the shared list before tracking the destination model's variables.
+    ALL_VARS.clear()
+    dest_model = dest_model_fn()
+    if len(ALL_VARS) != len(source_model_vars):
+        raise ValueError(
+            f"Variable counts do not match. 1st model: {len(source_model_vars)} "
+            f"vars, 2nd model: {len(ALL_VARS)} vars"
+        )
+    for v1, v2 in zip(source_model_vars, ALL_VARS):
+        v2.assign(v1.numpy())
+
+    # Unpatch make variable.
+    keraslib.engine.base_layer_utils.make_variable = make_variable
+
+    return source_model, dest_model
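
A hedged usage sketch of the helper above, with toy factory models standing in for the real conversion pair (the actual script would build the original T5 checkpoint's model and a T5Backbone); it assumes a Keras 2.x environment where `keras.engine.base_layer_utils` exists, which the helper itself already requires:

    from tensorflow import keras

    from tools.checkpoint_conversion.checkpoint_conversion_utils import (
        port_weights_by_creation_order,
    )

    def build_source():
        # Toy stand-in for the source model (e.g. a Hugging Face TF T5).
        return keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

    def build_dest():
        # Toy stand-in for the destination model; it must create its
        # variables in the same order as the source model.
        return keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

    # Both models are built inside the call, while make_variable is patched,
    # and weights are copied purely by creation order.
    source, dest = port_weights_by_creation_order(build_source, build_dest)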
