@@ -183,6 +183,7 @@ def __init__(
         dw_start_stride = stride if not dw_kernel_size_mid else 1
         dw_start_groups = num_groups(group_size, in_chs)
         self.dw_start = ConvNormAct(
+            nn.Conv2d,
             in_chs,
             in_chs,
             kernel_size=dw_kernel_size_start,
@@ -199,6 +200,7 @@ def __init__(
 
         mid_chs = make_divisible(in_chs * exp_ratio)
         self.pw_exp = ConvNormAct(
+            nn.Conv2d,
             in_chs,
             mid_chs,
             kernel_size=1,
@@ -212,11 +214,12 @@ def __init__(
         if dw_kernel_size_mid:
             dw_mid_groups = num_groups(group_size, mid_chs)
             self.dw_mid = ConvNormAct(
+                Conv2dSame,
                 mid_chs,
                 mid_chs,
                 kernel_size=dw_kernel_size_mid,
                 stride=stride,
-                padding=(dw_kernel_size_mid - 1) // 2,
+                padding=0,
                 dilation=dilation,
                 groups=dw_mid_groups,
                 bias=False,
@@ -226,6 +229,7 @@ def __init__(
             self.dw_mid = nn.Identity()
 
         self.pw_proj = ConvNormAct(
+            nn.Conv2d,
             mid_chs,
             out_chs,
             kernel_size=1,
@@ -257,6 +261,7 @@ def __call__(self, x: mx.array) -> mx.array:
 class ConvNormAct(nn.Module):
     def __init__(
         self,
+        conv_cls,
         in_chs: int,
         out_chs: int,
         kernel_size: int = 3,
@@ -270,8 +275,15 @@ def __init__(
     ):
         super().__init__()
         self.out_chs = out_chs
-        self.conv = nn.Conv2d(
-            in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias
+        self.conv = conv_cls(
+            in_chs,
+            out_chs,
+            kernel_size,
+            stride,
+            padding,
+            (dilation, dilation),
+            groups,
+            bias,
         )
         self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
 
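The hunk above threads the convolution class through ConvNormAct instead of hard-wiring nn.Conv2d, so call sites can pick either an explicitly padded convolution or the SAME-padded Conv2dSame defined later in this file. A minimal usage sketch, assuming the ConvNormAct, Conv2dSame, and RMSNormAct2d definitions from this file are in scope and that the remaining keyword arguments keep their defaults (shapes below are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn

# Explicitly padded pointwise conv, as in the pw_exp / pw_proj call sites.
pw = ConvNormAct(nn.Conv2d, 32, 64, kernel_size=1, stride=1, padding=0, eps=1e-05)

# SAME-padded conv: padding stays 0 because Conv2dSame pads the input itself.
dw = ConvNormAct(Conv2dSame, 64, 64, kernel_size=3, stride=2, padding=0, eps=1e-05)

x = mx.zeros((1, 24, 24, 32))  # MLX convolutions expect NHWC
y = dw(pw(x))                  # -> (1, 12, 12, 64)
```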
@@ -288,17 +300,20 @@ def pad_same(
     dilation: List[int] = (1, 1),
     value: float = 0,
 ):
-    ih, iw = x.shape[-2:]
+    """
+    Input is expected in MLX (NHWC) layout.
+    """
+    ih, iw = x.shape[1:3]
     pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
     pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
 
     # MLX pad format: [(low, high), (low, high), ...] for each axis
     # Padding order is reversed compared to PyTorch F.pad
     pad_widths = [
         (0, 0),  # No padding for batch dimension
-        (0, 0),  # No padding for channel dimension
         (pad_h // 2, pad_h - pad_h // 2),  # Height padding
         (pad_w // 2, pad_w - pad_w // 2),  # Width padding
+        (0, 0),  # No padding for channel dimension
     ]
 
     x = mx.pad(x, pad_widths, constant_values=value)
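The channel-axis padding entry moves to the end of pad_widths because MLX tensors are NHWC rather than NCHW. For reference, a sketch of the SAME-padding arithmetic this relies on; get_same_padding is not shown in this hunk, so the formula below is the usual timm-style one and should be treated as an assumption rather than a copy of the file's helper:

```python
import math
import mlx.core as mx

def get_same_padding(x: int, k: int, s: int, d: int) -> int:
    # Total padding so that the output size is ceil(x / s) for kernel k, dilation d.
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

x = mx.zeros((1, 7, 7, 3))             # NHWC, as the new docstring requires
pad_h = get_same_padding(7, 3, 2, 1)   # -> 2, split as (1, 1)
pad_w = get_same_padding(7, 3, 2, 1)   # -> 2, split as (1, 1)
padded = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
print(padded.shape)                    # (1, 9, 9, 3)
```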
@@ -373,12 +388,16 @@ def is_static_pad(kernel_size, stride=1, dilation=1, **_):
 class Conv2dSame(nn.Conv2d):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.kernel_size = self.weight.shape[1:3]
 
-    def forward(self, x: mx.array) -> mx.array:
+    def __call__(self, x: mx.array) -> mx.array:
         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
-        return mx.conv2d(
-            x, self.weight, self.bias, self.stride, (0, 0), self.dilation, self.groups
+        y = mx.conv2d(
+            x, self.weight, self.stride, self.padding, self.dilation, self.groups
         )
+        if "bias" in self:
+            y = y + self.bias
+        return y
 
 
 # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
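The rewritten Conv2dSame leaves the stored padding at whatever was passed to nn.Conv2d (zero at these call sites), pads the input explicitly through pad_same, and adds the bias by hand because mx.conv2d takes no bias argument. A small usage sketch under those assumptions; the values are illustrative, and dilation is passed as a tuple to match how the call sites in this diff do it:

```python
import mlx.core as mx

conv = Conv2dSame(16, 32, kernel_size=3, stride=2, padding=0, dilation=(1, 1), bias=True)
x = mx.random.normal((1, 15, 15, 16))  # odd spatial size, NHWC
y = conv(x)
print(y.shape)  # (1, 8, 8, 32), i.e. ceil(15 / 2) in both spatial dims
```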
@@ -409,14 +428,13 @@ def __init__(
 
         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
 
-        padding = (exp_kernel_size - 1) // 2
-        self.conv_exp = nn.Conv2d(
+        self.conv_exp = Conv2dSame(
             in_chs,
             mid_chs,
             kernel_size=exp_kernel_size,
             stride=stride,
-            padding=padding,
-            dilation=dilation,
+            padding=0,
+            dilation=(dilation, dilation),
             groups=groups,
             bias=False,
         )
@@ -531,7 +549,6 @@ def __call__(self, x: mx.array) -> mx.array:
         # Apply skip connection if available
         if self.has_skip:
             x = self.drop_path(x) + shortcut
-
         return x
 
 
@@ -864,11 +881,12 @@ class VisionTower(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
         self.conv_stem = ConvNormAct(
+            Conv2dSame,
             in_chs=3,
             out_chs=64,
             kernel_size=3,
             stride=2,
-            padding=1,
+            padding=0,
             eps=1e-05,
         )
         msfa_indices = (3, 4)
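With the stem switched to Conv2dSame and padding=0, the stem still halves the spatial resolution, now with SAME semantics handled inside the conv. Illustrative shapes, assuming the definitions above; the 768x768 input is an arbitrary example, not a value taken from the config:

```python
import mlx.core as mx

stem = ConvNormAct(
    Conv2dSame, in_chs=3, out_chs=64, kernel_size=3, stride=2, padding=0, eps=1e-05
)
pixels = mx.zeros((1, 768, 768, 3))  # NHWC, as conv_stem now expects
print(stem(pixels).shape)            # (1, 384, 384, 64)
```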
@@ -954,14 +972,12 @@ def __call__(
         x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
         x = self.conv_stem(x)
         intermediates = []
-        hidden_states = []
 
         if feat_idx in self.msfa_indices:
             intermediates.append(x)
 
         # MBV5 is constructed of 4 stages, each stage is a group of blocks.
         for block_group in self.blocks:
-            # print_array_report(x.transpose(0,3,1,2), f"Stage {feat_idx + 1} input")
             feat_idx += 1
             for block in block_group:
                 x = block(x)