From 4fd0092e27c95405c6e66c3890012f90186c6784 Mon Sep 17 00:00:00 2001 From: HMUNACHI Date: Mon, 25 Mar 2024 12:21:47 +0000 Subject: [PATCH] touched up Ijepa --- nanodl/__src/models/ijepa.py | 52 +++--------------------------------- tests/test_models.py | 22 --------------- 2 files changed, 3 insertions(+), 71 deletions(-) diff --git a/nanodl/__src/models/ijepa.py b/nanodl/__src/models/ijepa.py index 01f5acc..6b8b5ed 100644 --- a/nanodl/__src/models/ijepa.py +++ b/nanodl/__src/models/ijepa.py @@ -20,9 +20,6 @@ class PatchEmbedding(nn.Module): patch_size (int): Size of square patches from image. embed_dim (int): Dimension of the embeddings for the patches. - Methods: - setup(): Calculates `num_patches` and initialises Conv layer. - __call__(x: jnp.ndarray): Passes image through Conv layer which extracts patches and projects into emebdding space. """ image_size:int patch_size:int @@ -43,7 +40,6 @@ def setup(self): def __call__(self, x:jnp.ndarray) -> jnp.ndarray: x = self.proj(x) x = jnp.reshape(x, (x.shape[0], -1, self.embed_dim)) # (batch_size, num_patches, embed_dim) - return x @@ -57,9 +53,6 @@ class PositionalEmbedding(nn.Module): embed_dim (int): Patch embedding dimensions. num_patches (int): Number of patches in an image which is dependent on the patch size. - Methods: - setup(): Initialises embedding layer - __call__(x: jnp.ndarray): Passes a tensor of positions through the positional embedding and adds the positional embeddings to the patch embeddings. """ embed_dim:int num_patches:int @@ -71,12 +64,9 @@ def setup(self): ) def __call__(self, x:jnp.ndarray) -> jnp.ndarray: - # assuming x of shape (batch_size, num_tokens, embed_dim) positions = jnp.arange(x.shape[1])[jnp.newaxis, :].repeat(x.shape[0], axis=0) embed = self.embedding(positions) - x = x + embed - return x @@ -90,9 +80,6 @@ class MultiHeadedAttention(nn.Module): embed_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(x: jnp.ndarray): Processes the input tensor through the multi-head self-attention mechanism. """ embed_dim:int num_heads:int @@ -107,25 +94,17 @@ def setup(self): def __call__(self, x:jnp.ndarray) -> jnp.ndarray: qkv = self.attn_proj(x) query, key, value = jnp.array_split(qkv, 3, axis=-1) - query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, -1)) key = jnp.reshape(key, (key.shape[0], key.shape[1], self.num_heads, -1)) value = jnp.reshape(value, (value.shape[0], value.shape[1], self.num_heads, -1)) - - # permute to (batch_size, num_heads, seq_len, embed_dim) - query = jnp.permute_dims(query, (0, 2, 1, 3)) key = jnp.permute_dims(key, (0, 2, 1, 3)) value = jnp.permute_dims(value, (0, 2, 1, 3)) - attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / (self.embed_dim **.5) attn_weights = nn.softmax(attn_weights, -1) - attn = jnp.matmul(attn_weights, value) - attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) # convert back to (batch_size, seq_len, embed_dim) - + attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) attn = self.out_proj(attn) - return attn, attn_weights @@ -141,9 +120,6 @@ class TransformerEncoderBlock(nn.Module): feed_forward_dim (int): Dimension of the feed-forward network. dropout_p (float): Dropout rate. - Methods: - setup(): Initializes the attention layer, feed forward layers and norm layers. - __call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block. """ embed_dim:int num_heads:int @@ -171,10 +147,8 @@ def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray: x_, attn_weights = self.attn(self.norm1(x)) x = x + x_ x = self.dropout(x, deterministic=not training) - x = x + self.ff(self.norm2(x)) x = self.dropout(x, deterministic=not training) - return x, attn_weights @@ -190,10 +164,7 @@ class TransformerEncoder(nn.Module): embed_dim (int): Dimensionality of inputs and outputs. num_layers (int): Number of encoder blocks. feed_forward_dim (int): Dimension of the feed-forward network. - - Methods: - setup(): Initializes the attention layer, feed forward layers and norm layers. - __call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block. + """ dropout:float num_heads:int @@ -214,11 +185,9 @@ def setup(self): def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray: attn_maps = [] - for layer in self.layers: x, attn_weights = layer(x, training=training) attn_maps.append(attn_weights) - return x, jnp.array(attn_maps) @@ -240,11 +209,6 @@ class IJEPA(nn.Module): predictor_num_heads (int): Number of transformer encoder heads for embedding predictor. share_patch_embedding (bool): Whether or not to share the patch embeddings across the context and target encoders. - - Methods: - setup(): Initializes the attention layer, feed forward layers and norm layers. - __call__(x:jnp.ndarray, content_mask:jnp.ndarray, target_mask:jnp.ndarray): Applies the context and target masks to the image to get the context and target blocks, then obtains the predicted representations of the target blocks. - Example usage: ```py import jax @@ -390,11 +354,8 @@ def setup(self): self.to_encoder_embed = nn.Dense(self.embed_dim) def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray, training:bool=False) -> Tuple[List[Tuple[jnp.ndarray, jnp.ndarray]], List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]: - # content & target masks of shape (N, M, num_patches) - x_context = self.patch_embedding["context"](x) x_context = self.positional_embedding(x_context) - x_target = self.patch_embedding["target"](x) x_target = self.positional_embedding(x_target) @@ -404,7 +365,6 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar for m in range(context_mask.shape[1]): context, context_attn_weights = self.context_encoder(x_context, training=training) context = context * jnp.expand_dims(context_mask[:, m], -1) # (N, num_patches, E) - target, target_attn_weights = self.target_encoder(x_target, training=training) target = target * jnp.expand_dims(target_mask[:, m], -1) # (N, num_patches, E) @@ -415,14 +375,10 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar predicted_embeddings = self.to_encoder_embed(predicted_embeddings) predicted_embeddings = predicted_embeddings * jnp.expand_dims(target_mask[:, m], -1) - outputs.append((predicted_embeddings, target)) attn_weights.append((context_attn_weights, target_attn_weights, embed_attn_weights)) - return ( - outputs, - attn_weights - ) + return (outputs, attn_weights) class IJEPADataSampler: @@ -543,10 +499,8 @@ def create_train_state(self, rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} context_mask, target_mask = self.data_sampler() - context_mask = jnp.repeat(context_mask[jnp.newaxis], input_shape[0], axis=0) target_mask = jnp.repeat(target_mask[jnp.newaxis], input_shape[0], axis=0) - params = self.model.init(rngs, jnp.ones(input_shape), context_mask, target_mask)['params'] if self.params_path is not None: diff --git a/tests/test_models.py b/tests/test_models.py index e2b9ff3..f078302 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -418,27 +418,5 @@ def test_ijepa_model_initialization_and_processing(self): self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim)) self.assertEqual(outputs[0][0].shape, outputs[0][1].shape) - - def test_ijepa_training(self): - x = jax.random.normal( - jax.random.PRNGKey(0), - (9, self.image_size, self.image_size, self.num_channels) - ) - - dataset = ArrayDataset(x) - - dataloader = DataLoader(dataset, - batch_size=3, - shuffle=True, - drop_last=False) - - data_sampler = IJEPADataSampler( - image_size=self.image_size, - patch_size=self.patch_size - ) - - trainer = IJEPADataParallelTrainer(self.model, x.shape, 'params.pkl', data_sampler=data_sampler) - trainer.train(dataloader, 10, dataloader) - if __name__ == '__main__': unittest.main() \ No newline at end of file