Skip to content

Commit 5ccca25

Browse files
Update f_net_presets_test.py
1 parent 8a6459d commit 5ccca25

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

keras_nlp/models/f_net/f_net_presets_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,23 @@ def test_classifier_output(self, load_weights):
8585
# We don't assert output values, as the head weights are random.
8686
model.predict(input_data)
8787

88+
@parameterized.named_parameters(
89+
("preset_weights", True), ("random_weights", False)
90+
)
91+
def test_backbone_output(self, load_weights):
92+
input_data = {
93+
"token_ids": tf.constant([[4, 97, 1467, 5]]),
94+
"segment_ids": tf.constant([[0, 0, 0, 0]]),
95+
}
96+
model = FNetBackbone.from_preset(
97+
"f_net_base_en", load_weights=load_weights
98+
)
99+
outputs = model(input_data)
100+
if load_weights:
101+
outputs = outputs["sequence_output"][0, 0, :5]
102+
expected = [4.182479, -0.072181, -0.138097, -0.036582, -0.521765]
103+
self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)
104+
88105
@parameterized.named_parameters(
89106
("f_net_tokenizer", FNetTokenizer),
90107
("f_net_preprocessor", FNetPreprocessor),

0 commit comments

Comments
 (0)