diff --git a/equimo/models/vit.py b/equimo/models/vit.py index 22dec4a..a0fcfad 100644 --- a/equimo/models/vit.py +++ b/equimo/models/vit.py @@ -403,6 +403,7 @@ def features( key: PRNGKeyArray, mask: Optional[Int[Array, "embed_h embed_w"]] = None, inference: Optional[bool] = None, + **kwargs, ) -> Float[Array, "seqlen dim"]: """Extract features from input image. @@ -434,7 +435,7 @@ def features( x = self._pos_embed(x, h=self.embed_size, w=self.embed_size) for blk, key_block in zip(self.blocks, block_subkeys): - x = blk(x, inference=inference, key=key_block) + x = blk(x, inference=inference, key=key_block, **kwargs) return x @@ -443,6 +444,7 @@ def forward_features( x: Float[Array, "channels height width"], key: PRNGKeyArray, inference: Optional[bool] = None, + **kwargs, ) -> dict: """Process features and return intermediate representations. @@ -458,7 +460,7 @@ def forward_features( - x_norm_patchtokens: Normalized patch tokens - x_prenorm: Pre-normalized features """ - x = self.features(x, inference=inference, key=key) + x = self.features(x, inference=inference, key=key, **kwargs) x_norm = jax.vmap(self.norm)(x) return { @@ -473,6 +475,7 @@ def __call__( x: Float[Array, "channels height width"], key: PRNGKeyArray = jr.PRNGKey(42), inference: Optional[bool] = None, + **kwargs, ) -> Float[Array, "num_classes"]: """Process input image through the full network. @@ -484,7 +487,7 @@ def __call__( Returns: Classification logits """ - x = self.features(x, inference=inference, key=key) + x = self.features(x, inference=inference, key=key, **kwargs) x = jax.vmap(self.norm)(x) x = pool_sd( x,