Skip to content

Commit b1a6dfd

Browse files
add classifier test
1 parent 40ad2ed commit b1a6dfd

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

keras_nlp/src/models/vgg/vgg_image_classifier_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,47 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
17+
from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
18+
from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
19+
from keras_nlp.src.tests.test_case import TestCase
20+
21+
22+
class VGGImageClassifierTest(TestCase):
23+
def setUp(self):
24+
# Setup model.
25+
images = np.ones((2, 224, 224, 3), dtype="float32")
26+
labels = [0, 3]
27+
self.backbone = VGGBackbone(
28+
stackwise_num_repeats=[2, 2, 3, 3, 3],
29+
input_shape=(224, 224, 3),
30+
include_rescaling=False,
31+
pooling="avg",
32+
)
33+
self.init_kwargs = {
34+
"backbone": self.backbone,
35+
"num_classes": 4,
36+
}
37+
self.train_data = (
38+
images,
39+
labels,
40+
)
41+
42+
def test_classifier_basics(self):
43+
pytest.skip(reason="enable after preprocessor flow is figured out")
44+
self.run_task_test(
45+
cls=VGGImageClassifier,
46+
init_kwargs=self.init_kwargs,
47+
train_data=self.train_data,
48+
expected_output_shape=(2, 2),
49+
)
50+
51+
@pytest.mark.large
52+
def test_saved_model(self):
53+
self.run_model_saving_test(
54+
cls=VGGImageClassifier,
55+
init_kwargs=self.init_kwargs,
56+
input_data=self.input_data,
57+
)

0 commit comments

Comments
 (0)