Skip to content

Commit 741f98d

Browse files
james77777778 authored and mattdangerw committed
Add CLIP and T5XXL for StableDiffusionV3 (#1790)
* Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`.
* Make `CLIPTextEncoder` a Backbone
* Add `T5XXLPreprocessor` and remove `T5XXLTokenizer`
* Add `CLIPPreprocessor`
* Use `tf = None` at the top
* Replace manual implementation of `CLIPAttention` with `MultiHeadAttention`
1 parent 3e9aaba commit 741f98d

File tree

10 files changed

+960
-0
lines changed

10 files changed

+960
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from keras import layers
15+
from keras import ops
16+
17+
18+
def quick_gelu(x):
    """Apply the "quick GELU" activation, `x * sigmoid(1.702 * x)`.

    Args:
        x: Input tensor.

    Returns:
        The elementwise product of `x` and `sigmoid(1.702 * x)`.
    """
    gate = ops.sigmoid(1.702 * x)
    return x * gate
20+
21+
22+
class CLIPEncoderBlock(layers.Layer):
    """Pre-norm transformer block with causal self-attention and an MLP.

    The block applies layer normalization followed by multi-head
    self-attention (with a causal mask) and a residual connection, then
    layer normalization followed by a two-layer feed-forward network and a
    second residual connection.

    Args:
        hidden_dim: Size of the last axis of the block's output; also used
            to derive the per-head size (`hidden_dim // num_heads`).
        num_heads: Number of attention heads. Must divide `hidden_dim`.
        intermediate_dim: Number of units of the first feed-forward layer.
        intermediate_activation: Activation applied between the two
            feed-forward layers. The string `"quick_gelu"` selects the
            module-level `quick_gelu` function; any other value is handed
            to `layers.Activation` unchanged.
        **kwargs: Forwarded to the `layers.Layer` constructor.

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        intermediate_activation="quick_gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        remainder = hidden_dim % num_heads
        if remainder != 0:
            raise ValueError(
                "`hidden_dim` must be divisible by `num_heads`. "
                f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
            )

        # Keep the constructor arguments as given so `get_config()` can
        # report them verbatim (the activation stays in its original form).
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.intermediate_activation = intermediate_activation

        # Resolve the "quick_gelu" alias to the actual callable.
        activation_fn = (
            quick_gelu
            if intermediate_activation == "quick_gelu"
            else intermediate_activation
        )
        policy = self.dtype_policy

        # Sub-layers are created in the order they are used in `call()`.
        self.layer_norm_1 = layers.LayerNormalization(
            epsilon=1e-5, dtype=policy, name="layer_norm_1"
        )
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=hidden_dim // num_heads,
            dtype=policy,
            name="attention",
        )
        self.layer_norm_2 = layers.LayerNormalization(
            epsilon=1e-5, dtype=policy, name="layer_norm_2"
        )
        self.dense_1 = layers.Dense(
            intermediate_dim, dtype=policy, name="dense_1"
        )
        self.activation = layers.Activation(
            activation_fn, dtype=policy, name="activation"
        )
        self.dense_2 = layers.Dense(hidden_dim, dtype=policy, name="dense_2")

    def build(self, input_shape):
        """Build every sub-layer for inputs of shape `input_shape`."""
        self.layer_norm_1.build(input_shape)
        # Self-attention: query, value and key all share the input shape.
        self.attention.build(input_shape, input_shape, input_shape)
        self.layer_norm_2.build(input_shape)
        self.dense_1.build(input_shape)
        # The second dense layer consumes the first dense layer's output.
        intermediate_shape = self.dense_1.compute_output_shape(input_shape)
        self.dense_2.build(intermediate_shape)

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` as a list with the last axis set to `hidden_dim`."""
        shape = list(inputs_shape)
        shape[-1] = self.hidden_dim
        return shape

    def call(self, x, training=None):
        """Run the attention and feed-forward sub-blocks on `x`."""
        # Attention sub-block: norm -> causal self-attention -> residual.
        normed = self.layer_norm_1(x)
        attended = self.attention(
            normed, normed, normed, training=training, use_causal_mask=True
        )
        x = ops.add(x, attended)

        # Feed-forward sub-block: norm -> dense -> activation -> dense
        # -> residual.
        hidden = self.layer_norm_2(x)
        hidden = self.dense_1(hidden)
        hidden = self.activation(hidden)
        hidden = self.dense_2(hidden)
        return ops.add(x, hidden)

    def get_config(self):
        """Return the base layer config extended with this block's arguments."""
        return {
            **super().get_config(),
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "intermediate_activation": self.intermediate_activation,
        }
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
16+
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
17+
from keras_nlp.src.models.preprocessor import Preprocessor
18+
from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import (
19+
CLIPTokenizer,
20+
)
21+
from keras_nlp.src.utils.tensor_utils import preprocessing_function
22+
23+
try:
24+
import tensorflow as tf
25+
except ImportError:
26+
tf = None
27+
28+
29+
class CLIPPreprocessor(Preprocessor):
    """Tokenize strings and pack them into fixed-length CLIP inputs.

    Calling the layer returns a dict with `"token_ids"` and
    `"padding_mask"` entries, combined with `y` and `sample_weight` through
    `keras.utils.pack_x_y_sample_weight`.

    Args:
        tokenizer: Tokenizer used to convert strings to token ids. It must
            expose `start_token_id`, `end_token_id` and `pad_token_id`.
        sequence_length: Length of the packed output sequences.
        add_start_token: Whether the packer prepends the start token.
        add_end_token: Whether the packer appends the end token.
        to_lower: Whether inputs are lowercased before tokenization.
        pad_with_end_token: If true, padding uses the tokenizer's end token
            id instead of its pad token id.
        **kwargs: Forwarded to the `Preprocessor` constructor.
    """

    tokenizer_cls = CLIPTokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=77,
        add_start_token=True,
        add_end_token=False,
        to_lower=True,
        pad_with_end_token=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        # Packing options; all of these are echoed back by `get_config()`.
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token
        self.to_lower = to_lower
        self.pad_with_end_token = pad_with_end_token

    def build(self, input_shape):
        """Create the `StartEndPacker` from the tokenizer's special ids."""
        # The packer is created here rather than in `__init__()` so that the
        # tokenizer assets are guaranteed to be loaded when a saved model is
        # being restored.
        tokenizer = self.tokenizer
        default_pad = tokenizer.pad_token_id
        pad_value = (
            tokenizer.end_token_id if self.pad_with_end_token else default_pad
        )

        self.packer = StartEndPacker(
            start_value=tokenizer.start_token_id,
            end_value=tokenizer.end_token_id,
            pad_value=pad_value,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None, sequence_length=None):
        """Tokenize and pack `x`, passing `y` and `sample_weight` through.

        Args:
            x: String input(s) to preprocess.
            y: Optional labels, returned unchanged.
            sample_weight: Optional sample weights, returned unchanged.
            sequence_length: Optional per-call override of the layer's
                `sequence_length`; a falsy value falls back to the default.
        """
        if self.to_lower:
            x = tf.strings.lower(x)

        tokens = self.tokenizer(x)
        target_length = sequence_length or self.sequence_length
        token_ids, padding_mask = self.packer(
            tokens,
            sequence_length=target_length,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        features = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        return keras.utils.pack_x_y_sample_weight(features, y, sample_weight)

    def get_config(self):
        """Return the base config extended with the packing options."""
        return {
            **super().get_config(),
            "sequence_length": self.sequence_length,
            "add_start_token": self.add_start_token,
            "add_end_token": self.add_end_token,
            "to_lower": self.to_lower,
            "pad_with_end_token": self.pad_with_end_token,
        }
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
from keras_nlp.src.models.stable_diffusion_v3.clip_preprocessor import (
17+
CLIPPreprocessor,
18+
)
19+
from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import (
20+
CLIPTokenizer,
21+
)
22+
from keras_nlp.src.tests.test_case import TestCase
23+
24+
25+
class CLIPPreprocessorTest(TestCase):
    """Tests for `CLIPPreprocessor` using a tiny hand-written vocabulary."""

    def setUp(self):
        # Token ids start at 1, in the order the tokens are listed.
        tokens = [
            "air",
            "plane</w>",
            "port</w>",
            "<|endoftext|>",
            "<|startoftext|>",
        ]
        vocab = {token: index for index, token in enumerate(tokens, start=1)}
        merges = [
            "a i",
            "p l",
            "n e</w>",
            "p o",
            "r t</w>",
            "ai r",
            "pl a",
            "po rt</w>",
            "pla ne</w>",
        ]
        self.tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges)
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 8,
        }
        self.input_data = [" airplane airport"]

    def test_preprocessor_basics(self):
        # Default settings: start token prepended, padded with the end token.
        expected = {
            "token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]],
            "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
        }
        self.run_preprocessing_layer_test(
            cls=CLIPPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=expected,
        )

    def test_no_start_end_token(self):
        # With every special-token option disabled only the word pieces and
        # zero padding remain.
        batch_size = 4
        batch = [" airplane airport"] * batch_size
        preprocessor = CLIPPreprocessor(
            tokenizer=self.tokenizer,
            sequence_length=8,
            add_start_token=False,
            add_end_token=False,
            pad_with_end_token=False,
        )
        outputs = preprocessor(batch)
        self.assertAllEqual(
            outputs["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * batch_size
        )
        self.assertAllEqual(
            outputs["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * batch_size
        )

    def test_sequence_length_override(self):
        # A call-time `sequence_length` takes precedence over the configured
        # one.
        preprocessor = CLIPPreprocessor(**self.init_kwargs)
        outputs = preprocessor(" airplane airport", sequence_length=4)
        self.assertAllEqual(outputs["token_ids"], [5, 1, 2, 1])

    @pytest.mark.kaggle_key_required
    @pytest.mark.extra_large
    def test_all_presets(self):
        # Skipped until presets exist; the loop below is kept ready for then.
        self.skipTest("TODO")
        for preset in CLIPPreprocessor.presets:
            self.run_preset_test(
                cls=CLIPPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )

0 commit comments

Comments
 (0)