Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions equimo/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def features(
key: PRNGKeyArray,
mask: Optional[Int[Array, "embed_h embed_w"]] = None,
inference: Optional[bool] = None,
+ **kwargs,
) -> Float[Array, "seqlen dim"]:
"""Extract features from input image.

Expand Down Expand Up @@ -434,7 +435,7 @@ def features(
x = self._pos_embed(x, h=self.embed_size, w=self.embed_size)

for blk, key_block in zip(self.blocks, block_subkeys):
- x = blk(x, inference=inference, key=key_block)
+ x = blk(x, inference=inference, key=key_block, **kwargs)

return x

Expand All @@ -443,6 +444,7 @@ def forward_features(
x: Float[Array, "channels height width"],
key: PRNGKeyArray,
inference: Optional[bool] = None,
+ **kwargs,
) -> dict:
"""Process features and return intermediate representations.

Expand All @@ -458,7 +460,7 @@ def forward_features(
- x_norm_patchtokens: Normalized patch tokens
- x_prenorm: Pre-normalized features
"""
- x = self.features(x, inference=inference, key=key)
+ x = self.features(x, inference=inference, key=key, **kwargs)
x_norm = jax.vmap(self.norm)(x)

return {
Expand All @@ -473,6 +475,7 @@ def __call__(
x: Float[Array, "channels height width"],
key: PRNGKeyArray = jr.PRNGKey(42),
inference: Optional[bool] = None,
+ **kwargs,
) -> Float[Array, "num_classes"]:
"""Process input image through the full network.

Expand All @@ -484,7 +487,7 @@ def __call__(
Returns:
Classification logits
"""
- x = self.features(x, inference=inference, key=key)
+ x = self.features(x, inference=inference, key=key, **kwargs)
x = jax.vmap(self.norm)(x)
x = pool_sd(
x,
Expand Down