1515from  keras  import  layers 
1616
1717from  keras_nlp .src .api_export  import  keras_nlp_export 
18- from  keras_nlp .src .models .backbone  import  Backbone 
18+ from  keras_nlp .src .models .feature_pyramid_backbone  import  FeaturePyramidBackbone 
1919
2020
2121@keras_nlp_export ("keras_nlp.models.CSPDarkNetBackbone" ) 
22- class  CSPDarkNetBackbone (Backbone ):
22+ class  CSPDarkNetBackbone (FeaturePyramidBackbone ):
2323    """This class represents Keras Backbone of CSPDarkNet model. 
2424
2525    This class implements a CSPDarkNet backbone as described in 
@@ -65,12 +65,15 @@ def __init__(
6565        self ,
6666        stackwise_num_filters ,
6767        stackwise_depth ,
68-         include_rescaling ,
68+         include_rescaling = True ,
6969        block_type = "basic_block" ,
70-         image_shape = (224 ,  224 , 3 ),
70+         image_shape = (None ,  None , 3 ),
7171        ** kwargs ,
7272    ):
7373        # === Functional Model === 
74+         channel_axis  =  (
75+             - 1  if  keras .config .image_data_format () ==  "channels_last"  else  1 
76+         )
7477        apply_ConvBlock  =  (
7578            apply_darknet_conv_block_depthwise 
7679            if  block_type  ==  "depthwise_block" 
@@ -83,15 +86,22 @@ def __init__(
8386        if  include_rescaling :
8487            x  =  layers .Rescaling (scale = 1  /  255.0 )(x )
8588
86-         x  =  apply_focus (name = "stem_focus" )(x )
89+         x  =  apply_focus (channel_axis ,  name = "stem_focus" )(x )
8790        x  =  apply_darknet_conv_block (
88-             base_channels , kernel_size = 3 , strides = 1 , name = "stem_conv" 
91+             base_channels ,
92+             channel_axis ,
93+             kernel_size = 3 ,
94+             strides = 1 ,
95+             name = "stem_conv" ,
8996        )(x )
97+ 
98+         pyramid_outputs  =  {}
9099        for  index , (channels , depth ) in  enumerate (
91100            zip (stackwise_num_filters , stackwise_depth )
92101        ):
93102            x  =  apply_ConvBlock (
94103                channels ,
104+                 channel_axis ,
95105                kernel_size = 3 ,
96106                strides = 2 ,
97107                name = f"dark{ index  +  2 }  ,
@@ -100,17 +110,20 @@ def __init__(
100110            if  index  ==  len (stackwise_depth ) -  1 :
101111                x  =  apply_spatial_pyramid_pooling_bottleneck (
102112                    channels ,
113+                     channel_axis ,
103114                    hidden_filters = channels  //  2 ,
104115                    name = f"dark{ index  +  2 }  ,
105116                )(x )
106117
107118            x  =  apply_cross_stage_partial (
108119                channels ,
120+                 channel_axis ,
109121                num_bottlenecks = depth ,
110122                block_type = "basic_block" ,
111123                residual = (index  !=  len (stackwise_depth ) -  1 ),
112124                name = f"dark{ index  +  2 }  ,
113125            )(x )
126+             pyramid_outputs [f"P{ index  +  2 }  ] =  x 
114127
115128        super ().__init__ (inputs = image_input , outputs = x , ** kwargs )
116129
@@ -120,6 +133,7 @@ def __init__(
120133        self .include_rescaling  =  include_rescaling 
121134        self .block_type  =  block_type 
122135        self .image_shape  =  image_shape 
136+         self .pyramid_outputs  =  pyramid_outputs 
123137
124138    def  get_config (self ):
125139        config  =  super ().get_config ()
@@ -135,7 +149,7 @@ def get_config(self):
135149        return  config 
136150
137151
138- def  apply_focus (name = None ):
152+ def  apply_focus (channel_axis ,  name = None ):
139153    """A block used in CSPDarknet to focus information into channels of the 
140154    image. 
141155
@@ -151,7 +165,7 @@ def apply_focus(name=None):
151165    """ 
152166
153167    def  apply (x ):
154-         return  layers .Concatenate (name = name )(
168+         return  layers .Concatenate (axis = channel_axis ,  name = name )(
155169            [
156170                x [..., ::2 , ::2 , :],
157171                x [..., 1 ::2 , ::2 , :],
@@ -164,7 +178,13 @@ def apply(x):
164178
165179
166180def  apply_darknet_conv_block (
167-     filters , kernel_size , strides , use_bias = False , activation = "silu" , name = None 
181+     filters ,
182+     channel_axis ,
183+     kernel_size ,
184+     strides ,
185+     use_bias = False ,
186+     activation = "silu" ,
187+     name = None ,
168188):
169189    """ 
170190    The basic conv block used in Darknet. Applies Conv2D followed by a 
@@ -193,11 +213,12 @@ def apply(inputs):
193213            kernel_size ,
194214            strides ,
195215            padding = "same" ,
216+             data_format = keras .config .image_data_format (),
196217            use_bias = use_bias ,
197218            name = name  +  "_conv" ,
198219        )(inputs )
199220
200-         x  =  layers .BatchNormalization (name = name  +  "_bn" )(x )
221+         x  =  layers .BatchNormalization (axis = channel_axis ,  name = name  +  "_bn" )(x )
201222
202223        if  activation  ==  "silu" :
203224            x  =  layers .Lambda (lambda  x : keras .activations .silu (x ))(x )
@@ -212,7 +233,7 @@ def apply(inputs):
212233
213234
214235def  apply_darknet_conv_block_depthwise (
215-     filters , kernel_size , strides , activation = "silu" , name = None 
236+     filters , channel_axis ,  kernel_size , strides , activation = "silu" , name = None 
216237):
217238    """ 
218239    The depthwise conv block used in CSPDarknet. 
@@ -236,9 +257,13 @@ def apply_darknet_conv_block_depthwise(
236257
237258    def  apply (inputs ):
238259        x  =  layers .DepthwiseConv2D (
239-             kernel_size , strides , padding = "same" , use_bias = False 
260+             kernel_size ,
261+             strides ,
262+             padding = "same" ,
263+             data_format = keras .config .image_data_format (),
264+             use_bias = False ,
240265        )(inputs )
241-         x  =  layers .BatchNormalization ()(x )
266+         x  =  layers .BatchNormalization (axis = channel_axis )(x )
242267
243268        if  activation  ==  "silu" :
244269            x  =  layers .Lambda (lambda  x : keras .activations .swish (x ))(x )
@@ -248,7 +273,11 @@ def apply(inputs):
248273            x  =  layers .LeakyReLU (0.1 )(x )
249274
250275        x  =  apply_darknet_conv_block (
251-             filters , kernel_size = 1 , strides = 1 , activation = activation 
276+             filters ,
277+             channel_axis ,
278+             kernel_size = 1 ,
279+             strides = 1 ,
280+             activation = activation ,
252281        )(x )
253282
254283        return  x 
@@ -258,6 +287,7 @@ def apply(inputs):
258287
259288def  apply_spatial_pyramid_pooling_bottleneck (
260289    filters ,
290+     channel_axis ,
261291    hidden_filters = None ,
262292    kernel_sizes = (5 , 9 , 13 ),
263293    activation = "silu" ,
@@ -291,6 +321,7 @@ def apply_spatial_pyramid_pooling_bottleneck(
291321    def  apply (x ):
292322        x  =  apply_darknet_conv_block (
293323            hidden_filters ,
324+             channel_axis ,
294325            kernel_size = 1 ,
295326            strides = 1 ,
296327            activation = activation ,
@@ -304,13 +335,15 @@ def apply(x):
304335                    kernel_size ,
305336                    strides = 1 ,
306337                    padding = "same" ,
338+                     data_format = keras .config .image_data_format (),
307339                    name = f"{ name } { kernel_size }  ,
308340                )(x [0 ])
309341            )
310342
311-         x  =  layers .Concatenate (name = f"{ name }  )(x )
343+         x  =  layers .Concatenate (axis = channel_axis ,  name = f"{ name }  )(x )
312344        x  =  apply_darknet_conv_block (
313345            filters ,
346+             channel_axis ,
314347            kernel_size = 1 ,
315348            strides = 1 ,
316349            activation = activation ,
@@ -324,6 +357,7 @@ def apply(x):
324357
325358def  apply_cross_stage_partial (
326359    filters ,
360+     channel_axis ,
327361    num_bottlenecks ,
328362    residual = True ,
329363    block_type = "basic_block" ,
@@ -361,6 +395,7 @@ def apply(inputs):
361395
362396        x1  =  apply_darknet_conv_block (
363397            hidden_channels ,
398+             channel_axis ,
364399            kernel_size = 1 ,
365400            strides = 1 ,
366401            activation = activation ,
@@ -369,6 +404,7 @@ def apply(inputs):
369404
370405        x2  =  apply_darknet_conv_block (
371406            hidden_channels ,
407+             channel_axis ,
372408            kernel_size = 1 ,
373409            strides = 1 ,
374410            activation = activation ,
@@ -379,13 +415,15 @@ def apply(inputs):
379415            residual_x  =  x1 
380416            x1  =  apply_darknet_conv_block (
381417                hidden_channels ,
418+                 channel_axis ,
382419                kernel_size = 1 ,
383420                strides = 1 ,
384421                activation = activation ,
385422                name = f"{ name } { i }  ,
386423            )(x1 )
387424            x1  =  ConvBlock (
388425                hidden_channels ,
426+                 channel_axis ,
389427                kernel_size = 3 ,
390428                strides = 1 ,
391429                activation = activation ,
@@ -399,6 +437,7 @@ def apply(inputs):
399437        x  =  layers .Concatenate (name = f"{ name }  )([x1 , x2 ])
400438        x  =  apply_darknet_conv_block (
401439            filters ,
440+             channel_axis ,
402441            kernel_size = 1 ,
403442            strides = 1 ,
404443            activation = activation ,
0 commit comments