
Commit f3128bd

[gemma3n] Fix vision encoder implementation (Blaizzy#410)
1 parent 74021c2 commit f3128bd
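
The changes below thread a conv class through ConvNormAct so that TensorFlow-style "same" padding (Conv2dSame) can be used where the timm reference model uses it, and they fix pad_same to pad MLX's NHWC layout rather than NCHW. A minimal sketch of the "same" padding rule the new code relies on, assuming the usual timm-style formula (the helper's body is not part of this diff):

    import math

    def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int) -> int:
        # Assumed timm-style formula, not copied from vision.py: total padding so
        # the output spatial size equals ceil(x / stride), i.e. TF-style "SAME".
        return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)

    # Example: a 3x3, stride-2 conv over a 768-pixel edge needs 1 pixel of total padding:
    # (384 - 1) * 2 + (3 - 1) + 1 - 768 = 1, split as (0, 1) by pad_same below.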

File tree

1 file changed: +32 −16 lines

mlx_vlm/models/gemma3n/vision.py

Lines changed: 32 additions & 16 deletions
@@ -183,6 +183,7 @@ def __init__(
         dw_start_stride = stride if not dw_kernel_size_mid else 1
         dw_start_groups = num_groups(group_size, in_chs)
         self.dw_start = ConvNormAct(
+            nn.Conv2d,
             in_chs,
             in_chs,
             kernel_size=dw_kernel_size_start,
@@ -199,6 +200,7 @@ def __init__(
 
         mid_chs = make_divisible(in_chs * exp_ratio)
         self.pw_exp = ConvNormAct(
+            nn.Conv2d,
             in_chs,
             mid_chs,
             kernel_size=1,
@@ -212,11 +214,12 @@ def __init__(
         if dw_kernel_size_mid:
             dw_mid_groups = num_groups(group_size, mid_chs)
             self.dw_mid = ConvNormAct(
+                Conv2dSame,
                 mid_chs,
                 mid_chs,
                 kernel_size=dw_kernel_size_mid,
                 stride=stride,
-                padding=(dw_kernel_size_mid - 1) // 2,
+                padding=0,
                 dilation=dilation,
                 groups=dw_mid_groups,
                 bias=False,
@@ -226,6 +229,7 @@ def __init__(
             self.dw_mid = nn.Identity()
 
         self.pw_proj = ConvNormAct(
+            nn.Conv2d,
             mid_chs,
             out_chs,
             kernel_size=1,
@@ -257,6 +261,7 @@ def __call__(self, x: mx.array) -> mx.array:
 class ConvNormAct(nn.Module):
     def __init__(
         self,
+        conv_cls,
         in_chs: int,
         out_chs: int,
         kernel_size: int = 3,
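
For illustration, the new conv_cls argument lets each ConvNormAct pick between a plain convolution and a "same"-padded one (hypothetical channel counts, not taken from the diff):

    # Pointwise 1x1 conv: no spatial padding needed, plain nn.Conv2d is enough.
    pw = ConvNormAct(nn.Conv2d, in_chs=64, out_chs=128, kernel_size=1)
    # Depthwise 3x3 stride-2 conv: use Conv2dSame and let it compute the padding.
    dw = ConvNormAct(Conv2dSame, in_chs=128, out_chs=128, kernel_size=3, stride=2, padding=0, groups=128)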
@@ -270,8 +275,15 @@ def __init__(
     ):
         super().__init__()
         self.out_chs = out_chs
-        self.conv = nn.Conv2d(
-            in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias
+        self.conv = conv_cls(
+            in_chs,
+            out_chs,
+            kernel_size,
+            stride,
+            padding,
+            (dilation, dilation),
+            groups,
+            bias,
         )
         self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
 
@@ -288,17 +300,20 @@ def pad_same(
     dilation: List[int] = (1, 1),
     value: float = 0,
 ):
-    ih, iw = x.shape[-2:]
+    """
+    Input should be in MLX format
+    """
+    ih, iw = x.shape[1:3]
     pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
     pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
 
     # MLX pad format: [(low, high), (low, high), ...] for each axis
     # Padding order is reversed compared to PyTorch F.pad
     pad_widths = [
         (0, 0),  # No padding for batch dimension
-        (0, 0),  # No padding for channel dimension
         (pad_h // 2, pad_h - pad_h // 2),  # Height padding
         (pad_w // 2, pad_w - pad_w // 2),  # Width padding
+        (0, 0),  # No padding for channel dimension
     ]
 
     x = mx.pad(x, pad_widths, constant_values=value)
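
A worked example of the fixed NHWC padding (illustrative sizes):

    # 3x3 stride-2 depthwise conv over a (B, 224, 224, C) tensor:
    #   pad_h = pad_w = get_same_padding(224, 3, 2, 1) = 1
    #   pad_widths = [(0, 0), (0, 1), (0, 1), (0, 0)]   # batch, height, width, channels
    # so x grows from (B, 224, 224, C) to (B, 225, 225, C) before the convolution.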
@@ -373,12 +388,16 @@ def is_static_pad(kernel_size, stride=1, dilation=1, **_):
 class Conv2dSame(nn.Conv2d):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.kernel_size = self.weight.shape[1:3]
 
-    def forward(self, x: mx.array) -> mx.array:
+    def __call__(self, x: mx.array) -> mx.array:
         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
-        return mx.conv2d(
-            x, self.weight, self.bias, self.stride, (0, 0), self.dilation, self.groups
+        y = mx.conv2d(
+            x, self.weight, self.stride, self.padding, self.dilation, self.groups
         )
+        if "bias" in self:
+            y = y + self.bias
+        return y
 
 
 # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
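
A usage sketch for the rewritten Conv2dSame (shapes are illustrative; the bias is added manually because mx.conv2d itself takes no bias argument):

    import mlx.core as mx

    conv = Conv2dSame(32, 64, kernel_size=3, stride=2, padding=0, bias=False)
    x = mx.zeros((1, 224, 224, 32))   # MLX convolutions expect NHWC input
    y = conv(x)                       # pad_same pads H/W, output shape (1, 112, 112, 64)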
@@ -409,14 +428,13 @@ def __init__(
 
         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
 
-        padding = (exp_kernel_size - 1) // 2
-        self.conv_exp = nn.Conv2d(
+        self.conv_exp = Conv2dSame(
             in_chs,
             mid_chs,
             kernel_size=exp_kernel_size,
             stride=stride,
-            padding=padding,
-            dilation=dilation,
+            padding=0,
+            dilation=(dilation, dilation),
             groups=groups,
             bias=False,
         )
@@ -531,7 +549,6 @@ def __call__(self, x: mx.array) -> mx.array:
         # Apply skip connection if available
         if self.has_skip:
             x = self.drop_path(x) + shortcut
-
         return x
 
 
@@ -864,11 +881,12 @@ class VisionTower(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
         self.conv_stem = ConvNormAct(
+            Conv2dSame,
             in_chs=3,
             out_chs=64,
             kernel_size=3,
             stride=2,
-            padding=1,
+            padding=0,
             eps=1e-05,
         )
         msfa_indices = (3, 4)
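
With Conv2dSame the stem output stays at ceil(input / stride) per spatial dimension, with any odd padding placed on the bottom/right edge as in the TensorFlow/timm convention. Illustrative arithmetic (input size chosen for the example, not stated in the diff):

    # 768 x 768 x 3 input, 3x3 stem conv, stride 2:
    #   total "same" padding per dim = (ceil(768 / 2) - 1) * 2 + 3 - 768 = 1, split as (0, 1)
    #   output feature map           = (1, 384, 384, 64)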
@@ -954,14 +972,12 @@ def __call__(
         x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
         x = self.conv_stem(x)
         intermediates = []
-        hidden_states = []
 
         if feat_idx in self.msfa_indices:
             intermediates.append(x)
 
         # MBV5 is constructed of 4 stages, each stage is a group of blocks.
         for block_group in self.blocks:
-            # print_array_report(x.transpose(0,3,1,2), f"Stage {feat_idx + 1} input")
             feat_idx += 1
             for block in block_group:
                 x = block(x)
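
For reference, the NCHW-to-NHWC transpose at the top of this block only reorders axes (example shape, not from the diff):

    # (batch, channels, height, width) -> (batch, height, width, channels)
    # e.g. (1, 3, 768, 768).transpose(0, 2, 3, 1) -> (1, 768, 768, 3),
    # which is the layout mx.conv2d and the blocks above operate on.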

0 commit comments