touched up Ijepa #29

Merged: 1 commit, Mar 25, 2024.
52 changes: 3 additions & 49 deletions nanodl/__src/models/ijepa.py
@@ -20,9 +20,6 @@ class PatchEmbedding(nn.Module):
         patch_size (int): Size of square patches from image.
         embed_dim (int): Dimension of the embeddings for the patches.
 
-    Methods:
-        setup(): Calculates `num_patches` and initialises Conv layer.
-        __call__(x: jnp.ndarray): Passes image through Conv layer which extracts patches and projects into emebdding space.
     """
     image_size:int
     patch_size:int
@@ -43,7 +40,6 @@ def setup(self):
     def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
         x = self.proj(x)
         x = jnp.reshape(x, (x.shape[0], -1, self.embed_dim)) # (batch_size, num_patches, embed_dim)
-
         return x
 
 
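For readers skimming the diff, a minimal usage sketch of `PatchEmbedding` as it stands after this change. The field names come from the hunk above; the concrete sizes and the import path (derived from the file location) are illustrative assumptions.

```py
import jax
import jax.numpy as jnp
from nanodl.__src.models.ijepa import PatchEmbedding  # path assumed from the diff

patch_embed = PatchEmbedding(image_size=32, patch_size=4, embed_dim=64)
x = jnp.ones((2, 32, 32, 3))  # (batch, height, width, channels)
variables = patch_embed.init(jax.random.PRNGKey(0), x)
patches = patch_embed.apply(variables, x)
print(patches.shape)  # (2, 64, 64): num_patches = (32 // 4) ** 2 = 64
```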
@@ -57,9 +53,6 @@ class PositionalEmbedding(nn.Module):
         embed_dim (int): Patch embedding dimensions.
         num_patches (int): Number of patches in an image which is dependent on the patch size.
 
-    Methods:
-        setup(): Initialises embedding layer
-        __call__(x: jnp.ndarray): Passes a tensor of positions through the positional embedding and adds the positional embeddings to the patch embeddings.
     """
     embed_dim:int
     num_patches:int
@@ -71,12 +64,9 @@ def setup(self):
         )
 
     def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
-        # assuming x of shape (batch_size, num_tokens, embed_dim)
         positions = jnp.arange(x.shape[1])[jnp.newaxis, :].repeat(x.shape[0], axis=0)
         embed = self.embedding(positions)
-
         x = x + embed
-
         return x
 
 
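A matching sketch for `PositionalEmbedding`: positions are generated per batch row and the learned embedding is added to the patch embeddings, so the output shape equals the input shape. Sizes are again illustrative assumptions.

```py
import jax
import jax.numpy as jnp
from nanodl.__src.models.ijepa import PositionalEmbedding  # path assumed

pos_embed = PositionalEmbedding(embed_dim=64, num_patches=64)
patches = jnp.ones((2, 64, 64))  # (batch, num_patches, embed_dim)
variables = pos_embed.init(jax.random.PRNGKey(0), patches)
out = pos_embed.apply(variables, patches)
print(out.shape)  # (2, 64, 64): unchanged, position information added
```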
@@ -90,9 +80,6 @@ class MultiHeadedAttention(nn.Module):
         embed_dim (int): Dimensionality of the input and output features.
         num_heads (int): Number of attention heads.
 
-    Methods:
-        setup(): Initializes projection matrices for queries, keys, values, and the output projection.
-        __call__(x: jnp.ndarray): Processes the input tensor through the multi-head self-attention mechanism.
     """
     embed_dim:int
     num_heads:int
@@ -107,25 +94,17 @@ def setup(self):
     def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
         qkv = self.attn_proj(x)
         query, key, value = jnp.array_split(qkv, 3, axis=-1)
-
         query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, -1))
         key = jnp.reshape(key, (key.shape[0], key.shape[1], self.num_heads, -1))
         value = jnp.reshape(value, (value.shape[0], value.shape[1], self.num_heads, -1))
-
-        # permute to (batch_size, num_heads, seq_len, embed_dim)
-
         query = jnp.permute_dims(query, (0, 2, 1, 3))
         key = jnp.permute_dims(key, (0, 2, 1, 3))
         value = jnp.permute_dims(value, (0, 2, 1, 3))
-
         attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / (self.embed_dim **.5)
         attn_weights = nn.softmax(attn_weights, -1)
-
         attn = jnp.matmul(attn_weights, value)
-        attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) # convert back to (batch_size, seq_len, embed_dim)
-
+        attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim))
         attn = self.out_proj(attn)
-
         return attn, attn_weights
 
 
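One thing worth flagging in the hunk above: the attention scores are scaled by `self.embed_dim ** .5` rather than the per-head dimension, which differs from the usual 1/sqrt(d_head) convention. A shape sketch under assumed sizes:

```py
import jax
import jax.numpy as jnp
from nanodl.__src.models.ijepa import MultiHeadedAttention  # path assumed

attn = MultiHeadedAttention(embed_dim=64, num_heads=8)
x = jnp.ones((2, 64, 64))  # (batch, seq_len, embed_dim)
variables = attn.init(jax.random.PRNGKey(0), x)
out, weights = attn.apply(variables, x)
print(out.shape)      # (2, 64, 64)
print(weights.shape)  # (2, 8, 64, 64): one seq_len x seq_len map per head
```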
@@ -141,9 +120,6 @@ class TransformerEncoderBlock(nn.Module):
         feed_forward_dim (int): Dimension of the feed-forward network.
         dropout_p (float): Dropout rate.
 
-    Methods:
-        setup(): Initializes the attention layer, feed forward layers and norm layers.
-        __call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block.
     """
     embed_dim:int
     num_heads:int
@@ -171,10 +147,8 @@ def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray:
         x_, attn_weights = self.attn(self.norm1(x))
         x = x + x_
         x = self.dropout(x, deterministic=not training)
-
         x = x + self.ff(self.norm2(x))
         x = self.dropout(x, deterministic=not training)
-
         return x, attn_weights
 
 
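The block is pre-norm: LayerNorm feeds the attention and feed-forward sublayers, each wrapped in a residual connection with dropout. A sketch with assumed sizes (`training=False` keeps dropout deterministic, so no dropout RNG is needed):

```py
import jax
import jax.numpy as jnp
from nanodl.__src.models.ijepa import TransformerEncoderBlock  # path assumed

block = TransformerEncoderBlock(embed_dim=64, num_heads=8,
                                feed_forward_dim=256, dropout_p=0.1)
x = jnp.ones((2, 64, 64))
variables = block.init(jax.random.PRNGKey(0), x, training=False)
out, attn_weights = block.apply(variables, x, training=False)
print(out.shape)  # (2, 64, 64)
```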
@@ -190,10 +164,7 @@ class TransformerEncoder(nn.Module):
         embed_dim (int): Dimensionality of inputs and outputs.
         num_layers (int): Number of encoder blocks.
         feed_forward_dim (int): Dimension of the feed-forward network.
 
-    Methods:
-        setup(): Initializes the attention layer, feed forward layers and norm layers.
-        __call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block.
     """
     dropout:float
     num_heads:int
@@ -214,11 +185,9 @@ def setup(self):
 
     def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray:
         attn_maps = []
-
         for layer in self.layers:
             x, attn_weights = layer(x, training=training)
             attn_maps.append(attn_weights)
-
         return x, jnp.array(attn_maps)
 
 
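The encoder stacks `num_layers` of these blocks and collects every block's attention weights, so the second return value stacks the maps layer-first. A sketch with assumed sizes, using the fields listed in the docstring above:

```py
import jax
import jax.numpy as jnp
from nanodl.__src.models.ijepa import TransformerEncoder  # path assumed

encoder = TransformerEncoder(dropout=0.1, num_heads=8, feed_forward_dim=256,
                             num_layers=4, embed_dim=64)
x = jnp.ones((2, 64, 64))
variables = encoder.init(jax.random.PRNGKey(0), x, training=False)
out, attn_maps = encoder.apply(variables, x, training=False)
print(out.shape)        # (2, 64, 64)
print(attn_maps.shape)  # (4, 2, 8, 64, 64): (layers, batch, heads, seq, seq)
```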
@@ -240,11 +209,6 @@ class IJEPA(nn.Module):
         predictor_num_heads (int): Number of transformer encoder heads for embedding predictor.
         share_patch_embedding (bool): Whether or not to share the patch embeddings across the context and target encoders.
 
-
-    Methods:
-        setup(): Initializes the attention layer, feed forward layers and norm layers.
-        __call__(x:jnp.ndarray, content_mask:jnp.ndarray, target_mask:jnp.ndarray): Applies the context and target masks to the image to get the context and target blocks, then obtains the predicted representations of the target blocks.
-
     Example usage:
     ```py
     import jax
@@ -390,11 +354,8 @@ def setup(self):
         self.to_encoder_embed = nn.Dense(self.embed_dim)
 
     def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray, training:bool=False) -> Tuple[List[Tuple[jnp.ndarray, jnp.ndarray]], List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]:
-        # content & target masks of shape (N, M, num_patches)
-
         x_context = self.patch_embedding["context"](x)
         x_context = self.positional_embedding(x_context)
-
         x_target = self.patch_embedding["target"](x)
         x_target = self.positional_embedding(x_target)
 
@@ -404,7 +365,6 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar
         for m in range(context_mask.shape[1]):
             context, context_attn_weights = self.context_encoder(x_context, training=training)
             context = context * jnp.expand_dims(context_mask[:, m], -1) # (N, num_patches, E)
-
             target, target_attn_weights = self.target_encoder(x_target, training=training)
             target = target * jnp.expand_dims(target_mask[:, m], -1) # (N, num_patches, E)
 
@@ -415,14 +375,10 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndar
 
             predicted_embeddings = self.to_encoder_embed(predicted_embeddings)
             predicted_embeddings = predicted_embeddings * jnp.expand_dims(target_mask[:, m], -1)
-
             outputs.append((predicted_embeddings, target))
             attn_weights.append((context_attn_weights, target_attn_weights, embed_attn_weights))
 
-        return (
-            outputs,
-            attn_weights
-        )
+        return (outputs, attn_weights)
 
 
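The loop above produces, for each of the M mask pairs, a `(predicted_embeddings, target)` tuple with both tensors masked to the target block. A sketch of how these outputs could feed a JEPA-style regression loss; the stop-gradient on the target branch is typical for such objectives but is an assumption here, since the trainer's actual loss is outside this diff (`model`, `params`, and the masks are assumed set up as in the file's own example usage):

```py
import jax
import jax.numpy as jnp

outputs, attn_weights = model.apply({'params': params}, x,
                                    context_mask, target_mask)
# Mean-squared error between predicted and target embeddings per mask pair.
loss = jnp.mean(jnp.stack([
    jnp.mean((pred - jax.lax.stop_gradient(tgt)) ** 2)
    for pred, tgt in outputs
]))
```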
class IJEPADataSampler:
@@ -543,10 +499,8 @@ def create_train_state(self,
 
         rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
         context_mask, target_mask = self.data_sampler()
-
         context_mask = jnp.repeat(context_mask[jnp.newaxis], input_shape[0], axis=0)
         target_mask = jnp.repeat(target_mask[jnp.newaxis], input_shape[0], axis=0)
-
         params = self.model.init(rngs, jnp.ones(input_shape), context_mask, target_mask)['params']
 
         if self.params_path is not None:
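The hunk above shows the trainer's pattern: sample one set of context/target masks, then repeat it across the batch so the shapes line up with the `(N, M, num_patches)` masks that `IJEPA.__call__` expects. A standalone sketch (constructor arguments mirror the removed test below; the import location is assumed):

```py
import jax.numpy as jnp
from nanodl import IJEPADataSampler  # import location assumed

data_sampler = IJEPADataSampler(image_size=32, patch_size=4)
context_mask, target_mask = data_sampler()  # one (M, num_patches) mask set
batch_size = 8
context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0)
target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0)
# both masks are now (batch_size, M, num_patches)
```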
22 changes: 0 additions & 22 deletions tests/test_models.py
@@ -418,27 +418,5 @@ def test_ijepa_model_initialization_and_processing(self):
         self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim))
         self.assertEqual(outputs[0][0].shape, outputs[0][1].shape)
 
-
-    def test_ijepa_training(self):
-        x = jax.random.normal(
-            jax.random.PRNGKey(0),
-            (9, self.image_size, self.image_size, self.num_channels)
-        )
-
-        dataset = ArrayDataset(x)
-
-        dataloader = DataLoader(dataset,
-                                batch_size=3,
-                                shuffle=True,
-                                drop_last=False)
-
-        data_sampler = IJEPADataSampler(
-            image_size=self.image_size,
-            patch_size=self.patch_size
-        )
-
-        trainer = IJEPADataParallelTrainer(self.model, x.shape, 'params.pkl', data_sampler=data_sampler)
-        trainer.train(dataloader, 10, dataloader)
-
 if __name__ == '__main__':
     unittest.main()
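The deleted test doubled as an end-to-end training example. For reference, a condensed sketch of the same flow, using exactly the calls from the removed lines; the import locations, the meaning of the positional `10` as an epoch count, and `model` (a configured IJEPA instance) are assumptions:

```py
import jax
from nanodl import (ArrayDataset, DataLoader, IJEPADataSampler,
                    IJEPADataParallelTrainer)  # imports assumed

x = jax.random.normal(jax.random.PRNGKey(0), (9, 32, 32, 3))
dataset = ArrayDataset(x)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=False)
data_sampler = IJEPADataSampler(image_size=32, patch_size=4)
trainer = IJEPADataParallelTrainer(model, x.shape, 'params.pkl',
                                   data_sampler=data_sampler)
trainer.train(dataloader, 10, dataloader)  # train loader, epochs, val loader
```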