Add CLIP to ivy models #34
Open
kgmann wants to merge 18 commits into ivy-llc:main from kgmann:feat/clip-model
+1,102 −4
Changes from 10 commits

Commits (18):
dd0edea  feat: add CLIP models to ivy (kgmann)
dea23d9  feat: added tests and fixed some bugs (kgmann)
2a6d794  fix: Add forgotten test file (kgmann)
ad30582  refactor: refactored the code to use the new MHA layer (kgmann)
a5d3777  refactor: changes files' structure to adhere to the repo's guidelines (kgmann)
81e3854  refactor: refactored to build layers in _build and use initializers i… (kgmann)
96d278d  Merge branch 'master' into feat/clip-model (juliagsy)
898180f  fix: added docs, refactored to use existing weights helpers, fixed so… (kgmann)
782b3fe  Merge remote-tracking branch 'refs/remotes/origin/feat/clip-model' in… (kgmann)
ac43df0  refactor: minor changes in the test file (kgmann)
99747a4  refactor: address PR reviews (kgmann)
d7de8ac  refactor: minor refactoring (kgmann)
5c828d4  Merge branch 'master' into feat/clip-model (kgmann)
8bfa751  refactor: refactored test to test only one model variant (kgmann)
9307b4e  Merge branch 'main' into feat/clip-model (juliagsy)
c578985  fix: updated the RandomNormal init calls (kgmann)
679d5b7  Merge remote-tracking branch 'refs/remotes/origin/feat/clip-model' in… (kgmann)
b12e84d  Merge branch 'main' into feat/clip-model (kgmann)
@@ -29,3 +29,6 @@

```python
from . import densenet
from .densenet import *

from . import clip
from .clip import *
```
@@ -0,0 +1,5 @@

```python
# This submodule is heavily inspired by OpenAI's original implementation -
# https://github.com/openai/CLIP

from . import clip
from .clip import *
```
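The two hunks above only wire the new submodule into the package's import surface. As a minimal sketch, assuming the layout these hunks imply (a `clip` subpackage whose public names are re-exported from the top-level `ivy_models` package), the new API would be reachable like this:

```python
# Minimal sketch, assuming the package layout implied by the hunks above.
import ivy_models

# `from .clip import *` re-exports the names listed in clip.py's __all__
# (CLIP, load_clip, tokenize, get_processors) at the package level.
from ivy_models import CLIP, load_clip, tokenize, get_processors

print(ivy_models.clip)  # the subpackage itself is also importable
```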
Binary file not shown.
@@ -0,0 +1,237 @@

```python
from typing import Tuple, Union

import numpy as np
import ivy
from ivy.stateful.initializers import Ones

from .layers import *
import ivy_models
from .misc import (
    get_model_args,
    get_clip_weights_url,
    load_clip_state_dict,
    tokenize,
    get_processors,
)

__all__ = ["CLIP", "load_clip", "tokenize", "get_processors"]


class CLIP(ivy.Module):
    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        # ivy
        device=None,
        v=None,
    ):
        """
        An ivy implementation of the CLIP model in fp32.
        The image encoder from the original implementation can be one of the following:
        - Modified ResNet variants (RN50, RN101, RN50x4, RN50x16, RN50x64)
        - ViT variants (ViT-B/32, ViT-B/16, ViT-L/14, ViT-L/14@336px)

        Parameters
        ----------
        embed_dim :
            Feature dimension that the text and image features are projected to.
        image_resolution :
            Input image resolution expected by the image encoder (e.g. 224 for 'RN101').
        vision_layers :
            For the ViT image encoders, an integer giving the number of residual
            attention blocks. For the modified ResNets, a tuple of four integers
            giving the number of residual blocks in each of the four residual layers.
        vision_width :
            For the ResNets, the number of channels in the first residual layer.
            For the ViT, the transformer's feature dimension.
            (In both cases the final visual features are projected to embed_dim.)
        vision_patch_size :
            The patch size of the ViT encoder. Not applicable to the ResNets.
        context_length :
            The context length of the text encoder.
        vocab_size :
            The size of the vocabulary, used in the embedding layer.
        transformer_width :
            The feature dimension of the text encoder (later projected to embed_dim).
        transformer_heads :
            Number of attention heads per residual attention block in the text encoder.
        transformer_layers :
            Number of residual attention blocks in the text encoder.
        """

        self.embed_dim = embed_dim
        self.image_resolution = image_resolution
        self.vision_layers = vision_layers
        self.vision_width = vision_width
        self.vision_patch_size = vision_patch_size

        self.context_length = context_length
        self.vocab_size = vocab_size
        self.transformer_width = transformer_width
        self.transformer_heads = transformer_heads
        self.transformer_layers = transformer_layers

        self._pos_embed_shape = (self.context_length, self.transformer_width)
        self._text_proj_shape = (self.transformer_width, self.embed_dim)
        self._scale_init = Ones()

        super().__init__(device=device, v=v)

    def _build(self, *args, **kwargs):
        if isinstance(self.vision_layers, (tuple, list)):
            vision_heads = self.vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=self.vision_layers,
                output_dim=self.embed_dim,
                heads=vision_heads,
                input_resolution=self.image_resolution,
                width=self.vision_width,
            )
        else:
            vision_heads = self.vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=self.image_resolution,
                patch_size=self.vision_patch_size,
                width=self.vision_width,
                layers=self.vision_layers,
                heads=vision_heads,
                output_dim=self.embed_dim,
            )

        self.transformer = Transformer(
            width=self.transformer_width,
            layers=self.transformer_layers,
            heads=self.transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.token_embedding = Embedding(self.vocab_size, self.transformer_width)
        self.ln_final = ivy.LayerNorm([self.transformer_width])

    def _create_variables(self, *, device=None, dtype=None):
        v = {
            "positional_embedding": ivy.empty(
                self._pos_embed_shape, dtype=dtype, device=device
            ),
            "text_projection": ivy.empty(
                self._text_proj_shape, dtype=dtype, device=device
            ),
            # Cast to float32 because of an issue with avg_pool2d on the jax backend
            # when jax_enable_x64 is set to True.
            "logit_scale": self._scale_init.create_variables([], device, dtype=dtype)
            * np.log(1 / 0.07).astype(ivy.float32),
        }
        return v

    def build_attention_mask(self):
        # Lazily create the causal attention mask for the text transformer.
        # PyTorch uses an additive float mask, but ivy expects a boolean mask,
        # so a lower-triangular matrix of ones is used here instead
        # (ivy's behavior for float masks differs from torch's).
        mask = ivy.ones((self.context_length, self.context_length))
        mask = mask.tril(k=0)
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.v.w.dtype

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        x = x + self.v.positional_embedding
        x = x.permute_dims((1, 0, 2))  # NLD -> LND
        x = self.transformer(x)
        x = x.permute_dims((1, 0, 2))  # LND -> NLD

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[ivy.arange(x.shape[0]), text.argmax(axis=-1)] @ self.v.text_projection

        return x

    def _forward(
        self,
        image: Union[ivy.Array, ivy.NativeArray],
        text: Union[ivy.Array, ivy.NativeArray],
    ):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.vector_norm(
            axis=1, keepdims=True
        )
        text_features = text_features / text_features.vector_norm(axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = self.v.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def _clip_torch_mapping(old_key, new_key):
    new_mapping = new_key

    if "conv" in old_key:
        if "/weight" in old_key:
            new_mapping = {"key_chain": new_key, "pattern": "o c h w -> h w c o "}
    if "downsample" in old_key:
        if "/0/weight" in old_key:
            new_mapping = {"key_chain": new_key, "pattern": "o c h w -> h w c o "}

    return new_mapping


def load_clip(name: str, pretrained=True):
    """
    Load a CLIP model.

    Parameters
    ----------
    name : str
        A model name listed in `clip.available_models()`, i.e. one of
        ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
         'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'].

    Returns
    -------
    model : ivy.Module
        The CLIP model.
    """
    url = get_clip_weights_url(name)
    state_dict = load_clip_state_dict(url)
    args = get_model_args(state_dict)
    model = CLIP(*args)

    if not pretrained:
        return model

    raw_keys_to_prune = [
        "context_length",
        "input_resolution",
        "vocab_size",
        "num_batches_tracked",
    ]
    clean_weights = ivy_models.helpers.load_torch_weights(
        url,
        model,
        raw_keys_to_prune=raw_keys_to_prune,
        custom_mapping=_clip_torch_mapping,
        jit=True,
        data_type=ivy.float32,
    )
    model = CLIP(*args, v=clean_weights)
    return model
```
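For context, here is a rough end-to-end sketch of how the loader and the forward pass above fit together. The preprocessing helpers (`get_processors`, `tokenize`) live in the `misc` module, which is not part of this diff, so their exact signatures here are assumptions modeled on OpenAI's reference code; the NHWC image layout is likewise an assumption suggested by the `o c h w -> h w c o` kernel mapping.

```python
import ivy
from ivy_models import load_clip, tokenize

ivy.set_backend("torch")  # any backend supported by ivy

# Fetches the torch checkpoint and converts it via load_torch_weights above.
model = load_clip("ViT-B/32")

# A random tensor stands in for a real preprocessed image here; in practice the
# transforms returned by get_processors(...) would be applied to an input image.
# NHWC layout is assumed (see the h w c o kernel mapping above).
image = ivy.random_normal(shape=(1, 224, 224, 3))
text = tokenize(["a photo of a cat", "a photo of a dog"])  # assumed: list[str] -> [N, context_length]

logits_per_image, logits_per_text = model(image, text)
probs = ivy.softmax(logits_per_image, axis=-1)  # per-image probabilities over the two captions
```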
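On the attention mask: `build_attention_mask` returns a boolean-style lower-triangular mask rather than the additive float mask used in OpenAI's torch code, as the inline comment notes. A small sketch of the two conventions for an illustrative context length of 4:

```python
import ivy

context_length = 4

# ivy-style mask, as built above: 1 (True) where a position may attend, 0 elsewhere.
bool_mask = ivy.ones((context_length, context_length)).tril(k=0)
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]

# torch-style additive mask: 0 where attention is allowed, -inf where it is blocked;
# it is added to the attention scores before the softmax.
additive_mask = ivy.triu(
    ivy.full((context_length, context_length), float("-inf")), k=1
)
```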
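And on the weight conversion: `_clip_torch_mapping` tags convolution kernels with an einops-style pattern, which `ivy_models.helpers.load_torch_weights` presumably applies when building the ivy weight container. A standalone sketch of what the `o c h w -> h w c o` pattern does, using einops directly on a dummy array:

```python
import numpy as np
from einops import rearrange

# torch stores conv kernels as (out_channels, in_channels, kH, kW), i.e. OIHW.
torch_kernel = np.zeros((64, 3, 7, 7), dtype=np.float32)

# The pattern used above rearranges them to (kH, kW, in_channels, out_channels),
# i.e. HWIO, which matches ivy's default conv filter layout.
ivy_kernel = rearrange(torch_kernel, "o c h w -> h w c o")
print(ivy_kernel.shape)  # (7, 7, 3, 64)
```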
Review comment: this should be named `clip_<version>` to match previous conventions.
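Presumably this refers to exposing per-variant constructors in a `clip_<version>` style consistent with the repo's other models, rather than (or in addition to) the single `load_clip(name)` entry point. A hypothetical sketch; the wrapper names below are illustrative only and are not part of this PR:

```python
from ivy_models import load_clip

# Hypothetical per-variant wrappers following a clip_<version> naming convention.
def clip_vit_b32(pretrained=True):
    return load_clip("ViT-B/32", pretrained=pretrained)


def clip_rn50(pretrained=True):
    return load_clip("RN50", pretrained=pretrained)
```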