Skip to content

Commit 4845b6a

Browse files
committed
Fix tests
1 parent 6797231 commit 4845b6a

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

keras_hub/src/models/differential_binarization/differential_binarization_backbone.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import keras
1516
from keras import layers
1617

1718
from keras_hub.src.api_export import keras_hub_export
@@ -53,9 +54,16 @@ def __init__(
5354
def get_config(self):
5455
config = super().get_config()
5556
config["fpn_channels"] = self.fpn_channels
56-
config["image_encoder"] = self.image_encoder
57+
config["image_encoder"] = keras.layers.serialize(self.image_encoder)
5758
return config
5859

60+
@classmethod
61+
def from_config(cls, config):
62+
config["image_encoder"] = keras.layers.deserialize(
63+
config["image_encoder"]
64+
)
65+
return cls(**config)
66+
5967

6068
def diffbin_fpn_model(inputs, out_channels):
6169
in2 = layers.Conv2D(

keras_hub/src/models/differential_binarization/differential_binarization_test.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,40 @@
1818
from keras_hub.src.models.differential_binarization.differential_binarization import (
1919
DifferentialBinarization,
2020
)
21+
from keras_hub.src.models.differential_binarization.differential_binarization_backbone import (
22+
DifferentialBinarizationBackbone,
23+
)
2124
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
25+
from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
26+
ResNetImageClassifierPreprocessor,
27+
)
2228
from keras_hub.src.tests.test_case import TestCase
2329

2430

2531
class DifferentialBinarizationTest(TestCase):
2632
def setUp(self):
2733
self.images = ops.ones((2, 224, 224, 3))
2834
self.labels = ops.zeros((2, 224, 224, 4))
29-
self.backbone = ResNetBackbone(
35+
image_encoder = ResNetBackbone(
3036
input_conv_filters=[64],
3137
input_conv_kernel_sizes=[7],
3238
stackwise_num_filters=[64, 128, 256, 512],
3339
stackwise_num_blocks=[3, 4, 6, 3],
3440
stackwise_num_strides=[1, 2, 2, 2],
3541
block_type="bottleneck_block",
3642
image_shape=(224, 224, 3),
37-
include_rescaling=False,
3843
)
44+
self.backbone = DifferentialBinarizationBackbone(
45+
image_encoder=image_encoder
46+
)
47+
self.preprocessor = ResNetImageClassifierPreprocessor()
3948
self.init_kwargs = {
4049
"backbone": self.backbone,
50+
"preprocessor": self.preprocessor,
4151
}
4252
self.train_data = (self.images, self.labels)
4353

4454
def test_basics(self):
45-
pytest.skip(
46-
reason="TODO: enable after preprocessor flow is figured out"
47-
)
4855
self.run_task_test(
4956
cls=DifferentialBinarization,
5057
init_kwargs=self.init_kwargs,

0 commit comments

Comments
 (0)