add multilingual clip
wanliAlex committed Jan 10, 2023
1 parent 266d2f0 commit cd39c46
Showing 4 changed files with 108 additions and 7 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -24,3 +24,4 @@ pandas==1.5.1
optimum==1.4.1
opencv-python-headless==4.6.0.66
psutil==5.9.4
multilingual-clip==1.0.10
3 changes: 2 additions & 1 deletion setup.py
@@ -31,7 +31,8 @@
        "uvicorn[standard]",
        "fastapi_utils",
        "opencv-python-headless",
-       "psutil"
+       "psutil",
+       "multilingual_clip"
],
name="marqo-engine",
version="0.1.10",
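As a quick check that the new dependency is wired up, here is a minimal sketch (assuming the updated requirements have been installed with pip): the PyPI project is named multilingual-clip but imports under the underscore name, and it exposes the MultilingualCLIP class consumed by clip_utils.py below.

from multilingual_clip import pt_multilingual_clip

# The text-tower class used in clip_utils.py should be available.
assert hasattr(pt_multilingual_clip, "MultilingualCLIP")
print("multilingual_clip import OK")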
70 changes: 70 additions & 0 deletions src/marqo/s2_inference/clip_utils.py
@@ -8,9 +8,13 @@
import torch
from PIL import Image, UnidentifiedImageError
import open_clip
from multilingual_clip import pt_multilingual_clip
import transformers

from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
import marqo.s2_inference.model_registry as model_registry

logger = get_logger(__name__)


@@ -271,6 +275,72 @@ class MULTILINGUAL_CLIP(CLIP):
    def __init__(self, model_type: str = "multilingual-clip/XLM-Roberta-Large-Vit-L-14", device: str = 'cpu',
                 embedding_dim: int = None, truncate: bool = True, **kwargs) -> None:

        self.model_name = model_type
        self.model_info = model_registry._get_multilingual_clip_properties()[self.model_name]
        self.visual_name = self.model_info["visual_model"]
        self.textual_name = self.model_info["textual_model"]
        self.device = device
        # Initialise the lazily-loaded components so the `is None` checks in
        # encode_text()/encode_image() work before load() has been called.
        self.visual_model = None
        self.textual_model = None
        self.tokenizer = None
        self.preprocess = None


    def load(self) -> None:
        # The visual tower comes from either OpenAI CLIP or open_clip,
        # depending on the prefix of the registry's "visual_model" field.
        if self.visual_name.startswith("openai/"):
            clip_name = self.visual_name.replace("openai/", "")
            self.visual_model, self.preprocess = clip.load(name=clip_name, device="cpu", jit=False)
            self.visual_model = self.visual_model.to(self.device)

        elif self.visual_name.startswith("open_clip/"):
            clip_name = self.visual_name.replace("open_clip/", "")
            self.visual_model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name=clip_name.split("/")[0], pretrained=clip_name.split("/")[1], device=self.device)

        # The text tower is always an M-CLIP checkpoint with its HF tokenizer.
        self.textual_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(self.textual_name)
        self.textual_model = self.textual_model.to(self.device)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.textual_name)

        self.textual_model.eval()
        self.visual_model.eval()

    def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatTensor:

        if self.textual_model is None:
            self.load()

        with torch.no_grad():
            outputs = self.textual_model.forward(sentence, self.tokenizer)

        if normalize:
            _shape_before = outputs.shape
            outputs /= self.normalize(outputs)
            assert outputs.shape == _shape_before

        return self._convert_output(outputs)

    def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
                     normalize=True) -> FloatTensor:

        if self.visual_model is None:
            self.load()

        # default to batch encoding
        if isinstance(images, list):
            image_input = format_and_load_CLIP_images(images)
        else:
            image_input = [format_and_load_CLIP_image(images)]

        self.image_input_processed = torch.stack([self.preprocess(_img).to(self.device) for _img in image_input])

        with torch.no_grad():
            outputs = self.visual_model.encode_image(self.image_input_processed)

        if normalize:
            _shape_before = outputs.shape
            outputs /= self.normalize(outputs)
            assert outputs.shape == _shape_before
        return self._convert_output(outputs)
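A minimal usage sketch of the class above (not part of the commit): the model name is one of the registry keys added below, the image URL is a placeholder, and the output types follow the parent CLIP helpers.

from marqo.s2_inference.clip_utils import MULTILINGUAL_CLIP

# The towers are loaded lazily: encode_text()/encode_image() call load()
# the first time they run.
model = MULTILINGUAL_CLIP(model_type="multilingual-clip/XLM-Roberta-Large-Vit-B-32", device="cpu")

text_vectors = model.encode_text(["一匹在草地上奔跑的马"])              # non-English query ("a horse running on grass")
image_vectors = model.encode_image("https://example.com/horse.jpg")  # placeholder URL

# Both should be 512-dimensional for this registry entry.
print(text_vectors.shape, image_vectors.shape)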







41 changes: 35 additions & 6 deletions src/marqo/s2_inference/model_registry.py
@@ -2,7 +2,7 @@
from marqo.s2_inference.sbert_onnx_utils import SBERT_ONNX
from marqo.s2_inference.sbert_utils import SBERT, TEST
from marqo.s2_inference.random_utils import Random
-from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP
+from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP, MULTILINGUAL_CLIP
from marqo.s2_inference.types import Any, Dict, List, Optional, Union, FloatTensor
from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX

@@ -498,14 +498,40 @@ def _get_sbert_onnx_properties() -> Dict:

def _get_multilingual_clip_properties() -> Dict:
    MULTILINGUAL_CLIP_PROPERTIES = {
-       "multilingual-clip/ViT-L/14" :
+       "multilingual-clip/XLM-Roberta-Large-Vit-L-14" :
            {
-               "model_name" : "multilingual-clip/ViT-L/14",
-               "visual_name" : "openai/ViT-L/14",
-               "textual_name" : 'M-CLIP/XLM-Roberta-Large-Vit-L-14',
+               "name" : "multilingual-clip/XLM-Roberta-Large-Vit-L-14",
+               "visual_model" : "openai/ViT-L/14",
+               "textual_model" : 'M-CLIP/XLM-Roberta-Large-Vit-L-14',
                "dimensions" : 768,
-               "type": "multilingual-clip",
+               "type": "multilingual_clip",
            },
+       "multilingual-clip/XLM-R Large Vit-B/16+":
+           {
+               "name": "multilingual-clip/XLM-R Large Vit-B/16+",
+               "visual_model": "open_clip/ViT-B-16-plus-240/laion400m_e32",
+               "textual_model": 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus',
+               "dimensions": 640,
+               "type": "multilingual_clip",
+           },
+       "multilingual-clip/XLM-Roberta-Large-Vit-B-32":
+           {
+               "name": "multilingual-clip/XLM-Roberta-Large-Vit-B-32",
+               "visual_model": "openai/ViT-B/32",
+               "textual_model": 'M-CLIP/XLM-Roberta-Large-Vit-B-32',
+               "dimensions": 512,
+               "type": "multilingual_clip",
+           },
+       "multilingual-clip/LABSE-Vit-L-14":
+           {
+               "name": "multilingual-clip/LABSE-Vit-L-14",
+               "visual_model": "openai/ViT-L/14",
+               "textual_model": 'M-CLIP/LABSE-Vit-L-14',
+               "dimensions": 768,
+               "type": "multilingual_clip",
+           }
    }

    return MULTILINGUAL_CLIP_PROPERTIES
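To illustrate how these entries drive MULTILINGUAL_CLIP.load() above, here is a hedged sketch (the helper name is illustrative, not from the commit): the "visual_model" prefix selects between OpenAI CLIP and open_clip image towers, while "textual_model" always names an M-CLIP checkpoint.

def _describe_multilingual_entry(name: str) -> str:
    # Mirrors the prefix dispatch in MULTILINGUAL_CLIP.load().
    props = _get_multilingual_clip_properties()[name]
    visual = props["visual_model"]
    if visual.startswith("openai/"):
        backend = "clip.load(%r)" % visual.replace("openai/", "")
    elif visual.startswith("open_clip/"):
        arch, pretrained = visual.replace("open_clip/", "").split("/")
        backend = "open_clip.create_model_and_transforms(%r, pretrained=%r)" % (arch, pretrained)
    else:
        backend = "unknown visual backend"
    return "%s -> %s + %s (dim %d)" % (name, backend, props["textual_model"], props["dimensions"])

# e.g. the Vit-B/16+ entry resolves to the open_clip ViT-B-16-plus-240 /
# laion400m_e32 image tower plus the M-CLIP XLM-Roberta-Large-Vit-B-16Plus text tower.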
@@ -1544,6 +1570,7 @@ def _get_model_load_mappings() -> Dict:
        'test': TEST,
        'sbert_onnx': SBERT_ONNX,
        'clip_onnx': CLIP_ONNX,
        "multilingual_clip": MULTILINGUAL_CLIP,
        'random': Random,
        'hf': HF_MODEL}

@@ -1560,6 +1587,7 @@ def load_model_properties() -> Dict:
    hf_model_properties = _get_hf_properties()
    open_clip_model_properties = _get_open_clip_properties()
    onnx_clip_model_properties = _get_onnx_clip_properties()
    multilingual_clip_model_properties = _get_multilingual_clip_properties()

    # combine the above dicts
    model_properties = dict(clip_model_properties.items())
@@ -1570,6 +1598,7 @@
    model_properties.update(hf_model_properties)
    model_properties.update(open_clip_model_properties)
    model_properties.update(onnx_clip_model_properties)
    model_properties.update(multilingual_clip_model_properties)

    all_properties = dict()
    all_properties['models'] = model_properties
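Together with the new "multilingual_clip": MULTILINGUAL_CLIP entry in _get_model_load_mappings(), the merged properties dict lets a registered name resolve to its loader class. A rough sketch of that lookup (the helper below is illustrative only, not part of Marqo's loading code):

def _resolve_loader(model_name: str):
    # Look up the model's properties, then map its "type" to a loader class.
    properties = load_model_properties()["models"][model_name]
    loader = _get_model_load_mappings()[properties["type"]]  # e.g. "multilingual_clip" -> MULTILINGUAL_CLIP
    return properties, loader

props, loader = _resolve_loader("multilingual-clip/LABSE-Vit-L-14")
model = loader(model_type=props["name"], device="cpu", embedding_dim=props["dimensions"])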
