diff --git a/ivy_models/__init__.py b/ivy_models/__init__.py
index 4f1c1d2b..a06f506b 100644
--- a/ivy_models/__init__.py
+++ b/ivy_models/__init__.py
@@ -17,10 +17,13 @@
 from .squeezenet import *
 from .densenet import *
+from .vit import *
+
+from . import clip
+from .clip import *
 
 from . import bart
 from .bart import *
 
 from . import bert
 from .bert import *
 
-from .vit import *
diff --git a/ivy_models/clip/__init__.py b/ivy_models/clip/__init__.py
new file mode 100644
index 00000000..55a53dd1
--- /dev/null
+++ b/ivy_models/clip/__init__.py
@@ -0,0 +1,5 @@
+# This submodule is heavily inspired by OpenAI's original implementation -
+# https://github.com/openai/CLIP
+
+from . import clip
+from .clip import *
\ No newline at end of file
diff --git a/ivy_models/clip/bpe_simple_vocab_16e6.txt.gz b/ivy_models/clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 00000000..7b5088a5
Binary files /dev/null and b/ivy_models/clip/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/ivy_models/clip/clip.py b/ivy_models/clip/clip.py
new file mode 100644
index 00000000..c0062050
--- /dev/null
+++ b/ivy_models/clip/clip.py
@@ -0,0 +1,247 @@
+from typing import Tuple, Union
+
+import numpy as np
+import ivy
+from ivy.stateful.initializers import Ones
+
+from .layers import (
+    CLIPModifiedResNet,
+    CLIPTransformer,
+    CLIPVisionTransformer,
+    Embedding,
+)
+import ivy_models
+from .misc import (
+    get_model_args,
+    get_clip_weights_url,
+    load_clip_state_dict,
+    tokenize,
+    get_processors,
+)
+
+__all__ = ["CLIP", "clip", "tokenize", "get_processors"]
+
+
+class CLIP(ivy.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        # vision
+        image_resolution: int,
+        vision_layers: Union[Tuple[int, int, int, int], int],
+        vision_width: int,
+        vision_patch_size: int,
+        # text
+        context_length: int,
+        vocab_size: int,
+        transformer_width: int,
+        transformer_heads: int,
+        transformer_layers: int,
+        # ivy
+        device=None,
+        v=None,
+    ):
+        """
+        An ivy implementation of the CLIP model in fp32.
+        The image encoder from the original implementation can be one of the following:
+        - Modified ResNet variants (RN50, RN101, RN50x4, RN50x16, RN50x64)
+        - ViT variants (ViT-B/32, ViT-B/16, ViT-L/14, ViT-L/14@336px)
+
+        Parameters
+        ----------
+        embed_dim :
+            Feature dimension that the text and image features are projected to.
+        image_resolution :
+            Input image resolution expected by the image encoder (e.g. 224).
+        vision_layers :
+            For the ViT image encoders, the number of residual attention blocks.
+            For the modified ResNets, a tuple of four integers giving the
+            number of residual blocks in each of the four residual layers.
+        vision_width :
+            For the ResNets, the number of channels in the first residual layer.
+            For the ViT, the transformer's feature dimension.
+            (In both cases the final visual features are projected to embed_dim.)
+        vision_patch_size :
+            The patch size of the ViT encoder. Not applicable to the ResNets.
+        context_length :
+            The context length of the text encoder.
+        vocab_size :
+            The size of the vocabulary. Used in the embedding layer.
+        transformer_width :
+            The feature dimension of the text encoder.
+            (It is later projected to embed_dim.)
+        transformer_heads :
+            Number of attention heads per residual attention block in the text encoder.
+        transformer_layers :
+            Number of residual attention blocks in the text encoder.
+ """ + self.embed_dim = embed_dim + self.image_resolution = image_resolution + self.vision_layers = vision_layers + self.vision_width = vision_width + self.vision_patch_size = vision_patch_size + + self.context_length = context_length + self.vocab_size = vocab_size + self.transformer_width = transformer_width + self.transformer_heads = transformer_heads + self.transformer_layers = transformer_layers + + self._pos_embed_shape = (self.context_length, self.transformer_width) + self._text_proj_shape = (self.transformer_width, self.embed_dim) + self._scale_init = Ones() + + super().__init__(device=device, v=v) + + def _build(self, *args, **kwargs): + if isinstance(self.vision_layers, (tuple, list)): + vision_heads = self.vision_width * 32 // 64 + self.visual = CLIPModifiedResNet( + layers=self.vision_layers, + output_dim=self.embed_dim, + heads=vision_heads, + input_resolution=self.image_resolution, + width=self.vision_width, + ) + else: + vision_heads = self.vision_width // 64 + self.visual = CLIPVisionTransformer( + input_resolution=self.image_resolution, + patch_size=self.vision_patch_size, + width=self.vision_width, + layers=self.vision_layers, + heads=vision_heads, + output_dim=self.embed_dim, + ) + + self.transformer = CLIPTransformer( + width=self.transformer_width, + layers=self.transformer_layers, + heads=self.transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.token_embedding = Embedding(self.vocab_size, self.transformer_width) + self.ln_final = ivy.LayerNorm([self.transformer_width]) + + def _create_variables(self, *, device=None, dtype=None): + v = { + "positional_embedding": ivy.empty( + self._pos_embed_shape, dtype=dtype, device=device + ), + "text_projection": ivy.empty( + self._text_proj_shape, dtype=dtype, device=device + ), + # Casting to float32 because of an issue with avg_pool2d for jax backend + # when jax_enable_x64 is set to True + "logit_scale": self._scale_init.create_variables([], device, dtype=dtype) + * np.log(1 / 0.07).astype(ivy.float32), + } + return v + + def build_attention_mask(self): + # Create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask for floats; but ivy expect a boolean mask + mask = ivy.ones((self.context_length, self.context_length)) + mask = mask.tril(k=0) + return mask + + @property + def dtype(self): + return self.visual.conv1.v.w.dtype + + def encode_image(self, image): + return self.visual(image) + + def encode_text(self, text): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.v.positional_embedding + x = x.permute_dims((1, 0, 2)) # NLD -> LND + x = self.transformer(x) + x = x.permute_dims((1, 0, 2)) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[ivy.arange(x.shape[0]), text.argmax(axis=-1)] @ self.v.text_projection + + return x + + def _forward( + self, + image: Union[ivy.Array, ivy.NativeArray], + text: Union[ivy.Array, ivy.NativeArray], + ): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.vector_norm( + axis=1, keepdims=True + ) + text_features = text_features / text_features.vector_norm(axis=1, keepdims=True) + + # cosine similarity as logits + logit_scale = self.v.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logits_per_image.T + + 
# shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def _clip_torch_mapping(old_key, new_key): + new_mapping = new_key + + if "conv" in old_key: + if "/weight" in old_key: + new_mapping = {"key_chain": new_key, "pattern": "o c h w -> h w c o "} + if "downsample" in old_key: + if "/0/weight" in old_key: + new_mapping = {"key_chain": new_key, "pattern": "o c h w -> h w c o "} + + return new_mapping + + +def clip(name: str, pretrained=True): + """ + Load a pretrained CLIP model variant. + + Parameters + ---------- + name : str + A model name listed in `clip.available_models()`. + It's actually the pretrained image encoder that'll be used in the model. + One in this list ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', + 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'] + + Returns + ------- + model : ivy.Module + The pretrained CLIP model + """ + url = get_clip_weights_url(name) + state_dict = load_clip_state_dict(url) + args = get_model_args(state_dict) + model = CLIP(*args) + + if not pretrained: + return model + + raw_keys_to_prune = [ + "context_length", + "input_resolution", + "vocab_size", + "num_batches_tracked", + ] + clean_weights = ivy_models.helpers.load_torch_weights( + url, + model, + raw_keys_to_prune=raw_keys_to_prune, + custom_mapping=_clip_torch_mapping, + jit=True, + data_type=ivy.float32, + ) + model = CLIP(*args, v=clean_weights) + return model diff --git a/ivy_models/clip/layers.py b/ivy_models/clip/layers.py new file mode 100644 index 00000000..fc348f1b --- /dev/null +++ b/ivy_models/clip/layers.py @@ -0,0 +1,424 @@ +from typing import Union + +import ivy + + +# TODO: Refactor once the layer is added to ivy's API. It's a generic layer. +class Identity(ivy.Module): + def __init__(self): + super(Identity, self).__init__() + + def _forward(self, x): + return x + + +# TODO: Refactor once the layer is added to ivy's API. It's a generic layer. +class Embedding(ivy.Module): + def __init__(self, vocab_size, embed_dim, max_norm=None, device=None, dtype=None): + self.vocab_size = vocab_size + self.embed_dim = embed_dim + self.max_norm = max_norm + self._w_init = ivy.RandomNormal(0.0, 1.0) + super(Embedding, self).__init__(device=device, dtype=dtype) + + def _create_variables(self, device=None, dtype=None): + v = { + "weight": self._w_init.create_variables( + var_shape=(self.vocab_size, self.embed_dim), device=device, dtype=dtype + ) + } + return v + + def _forward(self, x): + return ivy.embedding(self.v.weight, x, max_norm=self.max_norm) + + +class CLIPBottleneck(ivy.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + self.inplanes = inplanes + self.planes = planes + self.stride = stride + super().__init__() + + def _build(self, *args, **kwargs): + # all conv layers have stride 1. 
+ # An avgpool is performed after the second convolution when stride > 1 + self.conv1 = ivy.Conv2D( + self.inplanes, + self.planes, + [1, 1], + (1, 1), + 0, + with_bias=False, + data_format="NCHW", + ) + self.bn1 = ivy.BatchNorm2D(self.planes, data_format="NCS") + self.relu1 = ivy.ReLU() + + self.conv2 = ivy.Conv2D( + self.planes, + self.planes, + [3, 3], + (1, 1), + 1, + with_bias=False, + data_format="NCHW", + ) + self.bn2 = ivy.BatchNorm2D(self.planes, data_format="NCS") + self.relu2 = ivy.ReLU() + + self.avgpool = ( + ivy.AvgPool2D(self.stride, self.stride, 0, data_format="NCHW") + if self.stride > 1 + else Identity() + ) + + self.conv3 = ivy.Conv2D( + self.planes, + self.planes * self.expansion, + [1, 1], + (1, 1), + 0, + with_bias=False, + data_format="NCHW", + ) + self.bn3 = ivy.BatchNorm2D(self.planes * self.expansion, data_format="NCS") + self.relu3 = ivy.ReLU() + + self.downsample = None + + if self.stride > 1 or self.inplanes != self.planes * CLIPBottleneck.expansion: + # downsampling layer is prepended with an avgpool + # and the subsequent conv has stride 1 + self.downsample = ivy.Sequential( + *[ + ivy.AvgPool2D(self.stride, self.stride, 0, data_format="NCHW"), + ivy.Conv2D( + self.inplanes, + self.planes * self.expansion, + [1, 1], + (1, 1), + 0, + with_bias=False, + data_format="NCHW", + ), + ivy.BatchNorm2D(self.planes * self.expansion, data_format="NCS"), + ] + ) + + def _forward(self, x: ivy.Array): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class CLIPAttentionPool2d(ivy.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + self.spacial_dim = spacial_dim + self.embed_dim = embed_dim + self.num_heads = num_heads + self.output_dim = output_dim + self.dot_prod_scale = (embed_dim // self.num_heads) ** -0.5 + self._pos_embed_init = ivy.RandomNormal(0.0, 1.0) + super().__init__() + + def _build(self, *args, **kwargs): + self.q_proj = ivy.Linear(self.embed_dim, self.embed_dim) + self.k_proj = ivy.Linear(self.embed_dim, self.embed_dim) + self.v_proj = ivy.Linear(self.embed_dim, self.embed_dim) + self.c_proj = ivy.Linear(self.embed_dim, self.output_dim or self.embed_dim) + + def _create_variables(self, device=None, dtype=None): + v = { + "positional_embedding": self._pos_embed_init.create_variables( + var_shape=(self.spacial_dim**2 + 1, self.embed_dim), + device=device, + dtype=dtype, + ) + / self.embed_dim**0.5 + } + return v + + def _forward(self, x): + x = x.flatten(start_dim=2).permute_dims((2, 0, 1)) # NCHW -> (HW)NC + x = ivy.concat([x.mean(axis=0, keepdims=True), x], axis=0) # (HW+1)NC + x = x + self.v.positional_embedding[:, None, :] # (HW+1)NC + # Ivy expects the query in NLE, not LNE + x = x.permute_dims((1, 0, 2)) + x = ivy.multi_head_attention( + x[:, :1, :], + x, + x, + num_heads=self.num_heads, + scale=self.dot_prod_scale, + q_proj_weights=self.v.q_proj.w, + k_proj_weights=self.v.k_proj.w, + v_proj_weights=self.v.v_proj.w, + out_proj_weights=self.v.c_proj.w, + in_proj_bias=ivy.concat( + [self.v.q_proj.b, self.v.k_proj.b, self.v.v_proj.b] + ), + out_proj_bias=self.v.c_proj.b, + ) # N1E + return x.squeeze(axis=1) # N1E -> NE + + +class CLIPModifiedResNet(ivy.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - 
Now we have 3 "stem" convs as opposed to 1, with an avgpool instead of a max pool. + - Anti-aliasing strided convs, where an avgpool is prepended to convs with stride>1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + self.layers = layers + self.output_dim = output_dim + self.heads = heads + self.input_resolution = input_resolution + self.width = width + super().__init__() + + def _build(self, *args, **kwargs): + # the 3-layer stem + self.conv1 = ivy.Conv2D( + 3, self.width // 2, [3, 3], (2, 2), 1, with_bias=False, data_format="NCHW" + ) + self.bn1 = ivy.BatchNorm2D(self.width // 2, data_format="NCS") + self.relu1 = ivy.ReLU() + self.conv2 = ivy.Conv2D( + self.width // 2, + self.width // 2, + [3, 3], + (1, 1), + 1, + with_bias=False, + data_format="NCHW", + ) + self.bn2 = ivy.BatchNorm2D(self.width // 2, data_format="NCS") + self.relu2 = ivy.ReLU() + self.conv3 = ivy.Conv2D( + self.width // 2, + self.width, + [3, 3], + (1, 1), + 1, + with_bias=False, + data_format="NCHW", + ) + self.bn3 = ivy.BatchNorm2D(self.width, data_format="NCS") + self.relu3 = ivy.ReLU() + self.avgpool = ivy.AvgPool2D(2, 2, 0, data_format="NCHW") + + # residual layers + self._inplanes = ( + self.width + ) # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(self.width, self.layers[0]) + self.layer2 = self._make_layer(self.width * 2, self.layers[1], stride=2) + self.layer3 = self._make_layer(self.width * 4, self.layers[2], stride=2) + self.layer4 = self._make_layer(self.width * 8, self.layers[3], stride=2) + + embed_dim = self.width * 32 # the ResNet feature dimension + self.attnpool = CLIPAttentionPool2d( + self.input_resolution // 32, embed_dim, self.heads, self.output_dim + ) + + def _make_layer(self, planes, blocks, stride=1): + layers = [CLIPBottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * CLIPBottleneck.expansion + for _ in range(1, blocks): + layers.append(CLIPBottleneck(self._inplanes, planes)) + + return ivy.Sequential(*layers) + + def _forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class CLIPQuickGELU(ivy.Module): + def _forward(self, x: Union[ivy.Array, ivy.NativeArray]): + return x * ivy.sigmoid(1.702 * x) + + +class CLIPResidualAttentionBlock(ivy.Module): + def __init__( + self, + d_model: int, + n_head: int, + attn_mask: Union[ivy.Array, ivy.NativeArray] = None, + ): + self.d_model = d_model + self.n_head = n_head + self.attn_mask = attn_mask + super().__init__() + + def _build(self, *args, **kwargs): + self.attn = ivy.MultiHeadAttention(self.d_model, num_heads=self.n_head) + self.ln_1 = ivy.LayerNorm([self.d_model]) + self.mlp = ivy.Sequential( + ivy.Linear(self.d_model, self.d_model * 4), + CLIPQuickGELU(), + ivy.Linear(self.d_model * 4, self.d_model), + ) + self.ln_2 = ivy.LayerNorm([self.d_model]) + + def attention(self, x: Union[ivy.Array, ivy.NativeArray]): + if self.attn_mask is not None: + self.attn_mask = self.attn_mask.to_device(x.device) + return self.attn(x, attention_mask=self.attn_mask) + + def _forward(self, x: Union[ivy.Array, ivy.NativeArray]): + x = x.permute_dims( + (1, 0, 2) + ) # LND -> NLD : ivy's MultiHeadAtention layer expects 
NLD + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + x = x.permute_dims((1, 0, 2)) # NLD -> LND + return x + + +class CLIPTransformer(ivy.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + attn_mask: Union[ivy.Array, ivy.NativeArray] = None, + ): + self.width = width + self.layers = layers + self.heads = heads + self.attn_mask = attn_mask + super().__init__() + + def _build(self, *args, **kwargs): + self.resblocks = ivy.Sequential( + *[ + CLIPResidualAttentionBlock(self.width, self.heads, self.attn_mask) + for _ in range(self.layers) + ] + ) + + def _forward(self, x: Union[ivy.Array, ivy.NativeArray]): + return self.resblocks(x) + + +class CLIPVisionTransformer(ivy.Module): + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + ): + self.input_resolution = input_resolution + self.patch_size = patch_size + self.width = width + self.layers = layers + self.heads = heads + self.output_dim = output_dim + + self._scale = width**-0.5 + self._class_embed_init = ivy.RandomNormal(0.0, 1.0) + self._pos_embed_init = ivy.RandomNormal(0.0, 1.0) + self._proj_init = ivy.RandomNormal(0.0, 1.0) + super().__init__() + + def _build(self, *args, **kwargs): + self.conv1 = ivy.Conv2D( + 3, + self.width, + [ + self.patch_size, + ] + * 2, + (self.patch_size,) * 2, + 0, + with_bias=False, + data_format="NCHW", + ) + self.ln_pre = ivy.LayerNorm([self.width]) + self.transformer = CLIPTransformer(self.width, self.layers, self.heads) + self.ln_post = ivy.LayerNorm([self.width]) + + def _create_variables(self, device=None, dtype=None): + v = { + "class_embedding": self._class_embed_init.create_variables( + var_shape=self.width, device=device, dtype=dtype + ) + * self._scale, + "positional_embedding": self._pos_embed_init.create_variables( + var_shape=( + (self.input_resolution // self.patch_size) ** 2 + 1, + self.width, + ), + device=device, + dtype=dtype, + ) + * self._scale, + "proj": self._proj_init.create_variables( + var_shape=(self.width, self.output_dim), device=device, dtype=dtype + ) + * self._scale, + } + return v + + def _forward(self, x: Union[ivy.Array, ivy.NativeArray]): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape((x.shape[0], x.shape[1], -1)) # shape = [*, width, grid ** 2] + x = x.permute_dims((0, 2, 1)) # shape = [*, grid ** 2, width] + x = ivy.concat( + [ + self.v.class_embedding + + ivy.zeros( + (x.shape[0], 1, x.shape[-1]), dtype=x.dtype, device=x.device + ), + x, + ], + axis=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.v.positional_embedding + x = self.ln_pre(x) + + x = x.permute_dims((1, 0, 2)) # NLD -> LND + x = self.transformer(x) + x = x.permute_dims((1, 0, 2)) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.v.proj is not None: + x = x @ self.v.proj + + return x diff --git a/ivy_models/clip/misc.py b/ivy_models/clip/misc.py new file mode 100644 index 00000000..d5e2c2c9 --- /dev/null +++ b/ivy_models/clip/misc.py @@ -0,0 +1,193 @@ +import warnings +from typing import Union, List +from pkg_resources import packaging + + +import ivy +import torch +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image + +from .simple_tokenizer import CLIPTokenizer + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < 
packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +_tokenizer = CLIPTokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # noqa + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # noqa + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", # noqa + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # noqa + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", # noqa + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # noqa + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # noqa + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", # noqa + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", # noqa +} + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose( + [ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ), + lambda x: ivy.array(x.numpy()), + ] + ) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def get_model_args(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 + ) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round( + (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 + ) + vision_patch_size = None + assert ( + output_width**2 + 1 + == state_dict["visual.attnpool.positional_embedding"].shape[0] + ) + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks") + ) + ) + + return ( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + 
transformer_heads, + transformer_layers, + ) + + +def get_clip_weights_url(name): + if name not in _MODELS: + raise ValueError( + f"Model '{name}' not found; available models = {available_models()}" + ) + return _MODELS[name] + + +def load_clip_state_dict(url: str): + return torch.hub.load_state_dict_from_url(url, map_location="cpu").state_dict() + + +def get_processors(model): + """Returns the text tokenizer and the image transform for a given model.""" + return tokenize, _transform(model.visual.input_resolution) + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False +) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text if its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, + shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, + since older index_select requires indices to be long. + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f"Input {texts[i]} is too long for context length {context_length}" + ) + result[i, : len(tokens)] = torch.tensor(tokens) + + return ivy.array(result.numpy()) diff --git a/ivy_models/clip/simple_tokenizer.py b/ivy_models/clip/simple_tokenizer.py new file mode 100644 index 00000000..13ee2c0d --- /dev/null +++ b/ivy_models/clip/simple_tokenizer.py @@ -0,0 +1,151 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K + for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. 
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r"\s+", " ", text)
+    text = text.strip()
+    return text
+
+
+class CLIPTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+        merges = merges[1 : 49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+        for merge in merges:
+            vocab.append("".join(merge))
+        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {
+            "<|startoftext|>": "<|startoftext|>",
+            "<|endoftext|>": "<|endoftext|>",
+        }
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",  # noqa
+            re.IGNORECASE,
+        )
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(
+                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+            )
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = "".join([self.decoder[token] for token in tokens])
+        text = (
+            bytearray([self.byte_decoder[c] for c in text])
+            .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+        )
+        return text
diff --git a/ivy_models/helpers/weights_helpers.py b/ivy_models/helpers/weights_helpers.py
index bb42ffd9..9e7d99ab 100644
--- a/ivy_models/helpers/weights_helpers.py
+++ b/ivy_models/helpers/weights_helpers.py
@@ -139,12 +139,26 @@ def load_torch_weights(
     ref_keys_to_prune=[],
     custom_mapping=None,
     map_location=torch.device("cpu"),
+    jit=False,
+    data_type=None,
 ):
     ivy_torch = ivy.with_backend("torch")
     weights = torch.hub.load_state_dict_from_url(url, map_location=map_location)
-    weights_raw = ivy.Container(
-        ivy_torch.to_numpy(ivy_torch.Container(weights)).cont_to_dict()
-    )
+
+    if jit:
+        weights = weights.state_dict()
+
+    if data_type:
+        weights_raw = ivy.Container(
+            ivy_torch.to_numpy(
ivy_torch.Container(weights).astype(data_type) + ).cont_to_dict() + ) + else: + weights_raw = ivy.Container( + ivy_torch.to_numpy(ivy_torch.Container(weights)).cont_to_dict() + ) + weights_raw, weights_ref, pruned_ref = _prune_keys( weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune ) diff --git a/ivy_models_tests/clip/test_clip.py b/ivy_models_tests/clip/test_clip.py new file mode 100644 index 00000000..55291a27 --- /dev/null +++ b/ivy_models_tests/clip/test_clip.py @@ -0,0 +1,59 @@ +import os +import random + +import ivy +import numpy as np +from PIL import Image + +from ivy_models import clip, get_processors + + +VARIANTS_LOGITS = { + "RN50": ivy.array([12.2088346, 14.9655876, 21.0058422]), + "RN101": ivy.array([35.2729797, 36.0812988, 42.8816681]), + "RN50x4": ivy.array([27.0689335, 29.3602104, 35.6379929]), + "RN50x16": ivy.array([16.3668022, 20.4796104, 27.0634518]), + "RN50x64": ivy.array([8.9237432, 14.1180887, 20.0675087]), + "ViT-B/32": ivy.array([17.3713417, 19.5949516, 25.5068512]), + "ViT-B/16": ivy.array([17.9823151, 20.7719479, 26.9038792]), + "ViT-L/14": ivy.array([11.9254637, 15.6604385, 22.6723843]), + "ViT-L/14@336px": ivy.array([10.9720955, 13.5543489, 21.4815979]), +} + +load_weights = random.choice([False, True]) +model_var = random.choice(list(VARIANTS_LOGITS.keys())) +model = clip(model_var, pretrained=load_weights) +v = ivy.to_numpy(model.v) + + +def test_all_clip_img_classification( + device, + f, + fw, +): + """Test one CLIP variant for zero shot image classification.""" + image_name = "cat.jpg" + one_shot_labels = ["a diagram", "a dog", "a cat"] + batch_shape = [1] + num_classes = len(one_shot_labels) + this_dir = os.path.dirname(os.path.realpath(__file__)) + + # Load image and processors + tokenize, im_tfms = get_processors(model) + img = Image.open(os.path.join(this_dir, "..", "..", "images", image_name)) + img = ivy.expand_dims(im_tfms(img), axis=0) + text = tokenize(one_shot_labels) + + # Get logits and probs + model.v = ivy.asarray(v) + logits_per_image, logits_per_text = model(img, text) + calc_probs = ivy.to_numpy(logits_per_image.softmax(axis=-1)[0]) + true_probs = ivy.to_numpy(VARIANTS_LOGITS[model_var].softmax()) + + # Cardinality test + assert logits_per_image.shape == tuple([ivy.to_scalar(batch_shape), num_classes]) + + # Value test + # Probs instead of logits because the raw weights are in fp16 and we used float32. + if load_weights: + assert np.allclose(true_probs, calc_probs, atol=5e-3) diff --git a/optional.txt b/optional.txt index bcdc8625..1ba13af6 100644 --- a/optional.txt +++ b/optional.txt @@ -3,4 +3,6 @@ numpy torchvision Pillow # mod_name=PIL huggingface_hub +ftfy +regex transformers
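
A minimal zero-shot classification sketch against the API this diff adds, mirroring `ivy_models_tests/clip/test_clip.py`; the backend choice, image path, and label list are illustrative assumptions rather than part of the change:

```python
import ivy
from PIL import Image

from ivy_models import clip, get_processors

ivy.set_backend("torch")  # assumption: any installed ivy backend should do

# Load a pretrained variant by name (see _MODELS in ivy_models/clip/misc.py).
model = clip("ViT-B/32", pretrained=True)

# The processors are matched to the model: the BPE tokenizer and the image
# transform resized to model.visual.input_resolution.
tokenize, im_tfms = get_processors(model)

img = ivy.expand_dims(im_tfms(Image.open("cat.jpg")), axis=0)  # hypothetical image path
text = tokenize(["a diagram", "a dog", "a cat"])  # arbitrary candidate labels

logits_per_image, logits_per_text = model(img, text)
probs = ivy.to_numpy(logits_per_image.softmax(axis=-1)[0])
print("Label probabilities:", probs)
```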
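
And a short sketch of the tokenizer on its own, showing the `</w>` end-of-word marker that the BPE merges and `decode` rely on (the sample text is arbitrary):

```python
from ivy_models.clip.simple_tokenizer import CLIPTokenizer

# Loads the bpe_simple_vocab_16e6.txt.gz file shipped with this submodule.
tok = CLIPTokenizer()

ids = tok.encode("a photo of a cat")  # list of BPE token ids
print(ids)
print(tok.decode(ids))  # "a photo of a cat " -- each </w> becomes a space

# bpe() shows the merge result for one word: sub-tokens are space-separated and
# the last one carries the </w> marker (e.g. "cat</w>" if "cat" is a single merge).
print(tok.bpe("cat"))
```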