We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8a6459d commit 5ccca25Copy full SHA for 5ccca25
keras_nlp/models/f_net/f_net_presets_test.py
@@ -85,6 +85,23 @@ def test_classifier_output(self, load_weights):
85
# We don't assert output values, as the head weights are random.
86
model.predict(input_data)
87
88
+ @parameterized.named_parameters(
89
+ ("preset_weights", True), ("random_weights", False)
90
+ )
91
+ def test_backbone_output(self, load_weights):
92
+ input_data = {
93
+ "token_ids": tf.constant([[4, 97, 1467, 5]]),
94
+ "segment_ids": tf.constant([[0, 0, 0, 0]]),
95
+ }
96
+ model = FNetBackbone.from_preset(
97
+ "f_net_base_en", load_weights=load_weights
98
99
+ outputs = model(input_data)
100
+ if load_weights:
101
+ outputs = outputs["sequence_output"][0, 0, :5]
102
+ expected = [4.182479, -0.072181, -0.138097, -0.036582, -0.521765]
103
+ self.assertAllClose(outputs, expected, atol=0.01, rtol=0.01)
104
+
105
@parameterized.named_parameters(
106
("f_net_tokenizer", FNetTokenizer),
107
("f_net_preprocessor", FNetPreprocessor),
0 commit comments