| 
18 | 18 | from keras_hub.src.models.differential_binarization.differential_binarization import (  | 
19 | 19 |     DifferentialBinarization,  | 
20 | 20 | )  | 
 | 21 | +from keras_hub.src.models.differential_binarization.differential_binarization_backbone import (  | 
 | 22 | +    DifferentialBinarizationBackbone,  | 
 | 23 | +)  | 
21 | 24 | from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone  | 
 | 25 | +from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (  | 
 | 26 | +    ResNetImageClassifierPreprocessor,  | 
 | 27 | +)  | 
22 | 28 | from keras_hub.src.tests.test_case import TestCase  | 
23 | 29 | 
 
  | 
24 | 30 | 
 
  | 
25 | 31 | class DifferentialBinarizationTest(TestCase):  | 
26 | 32 |     def setUp(self):  | 
27 | 33 |         self.images = ops.ones((2, 224, 224, 3))  | 
28 | 34 |         self.labels = ops.zeros((2, 224, 224, 4))  | 
29 |  | -        self.backbone = ResNetBackbone(  | 
 | 35 | +        image_encoder = ResNetBackbone(  | 
30 | 36 |             input_conv_filters=[64],  | 
31 | 37 |             input_conv_kernel_sizes=[7],  | 
32 | 38 |             stackwise_num_filters=[64, 128, 256, 512],  | 
33 | 39 |             stackwise_num_blocks=[3, 4, 6, 3],  | 
34 | 40 |             stackwise_num_strides=[1, 2, 2, 2],  | 
35 | 41 |             block_type="bottleneck_block",  | 
36 | 42 |             image_shape=(224, 224, 3),  | 
37 |  | -            include_rescaling=False,  | 
38 | 43 |         )  | 
 | 44 | +        self.backbone = DifferentialBinarizationBackbone(  | 
 | 45 | +            image_encoder=image_encoder  | 
 | 46 | +        )  | 
 | 47 | +        self.preprocessor = ResNetImageClassifierPreprocessor()  | 
39 | 48 |         self.init_kwargs = {  | 
40 | 49 |             "backbone": self.backbone,  | 
 | 50 | +            "preprocessor": self.preprocessor,  | 
41 | 51 |         }  | 
42 | 52 |         self.train_data = (self.images, self.labels)  | 
43 | 53 | 
 
  | 
44 | 54 |     def test_basics(self):  | 
45 |  | -        pytest.skip(  | 
46 |  | -            reason="TODO: enable after preprocessor flow is figured out"  | 
47 |  | -        )  | 
48 | 55 |         self.run_task_test(  | 
49 | 56 |             cls=DifferentialBinarization,  | 
50 | 57 |             init_kwargs=self.init_kwargs,  | 
 | 
0 commit comments