import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import AutoModel, AutoTokenizer
from source.interface import Embedding


class MiniPcmEmbedding(Embedding):
    """
    Embedding model backed by openbmb/MiniCPM-Embedding, exposed through the
    Embedding interface.
    """

    name = "MiniPcm"
    size = 2304

    def __init__(self, devices: List[int] = [0]) -> None:
        self.devices = devices
        assert len(self.devices) > 0
        self.tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Embedding")
        self.pad_idx = self.tokenizer.pad_token_id
        # Load the weights in fp16 with FlashAttention 2, then replicate the model
        # across the requested GPUs via DataParallel.
        kwargs = dict()
        kwargs["trust_remote_code"] = True
        kwargs["attn_implementation"] = "flash_attention_2"
        kwargs["torch_dtype"] = torch.float16
        model = AutoModel.from_pretrained("openbmb/MiniCPM-Embedding", **kwargs)
        model = model.eval().to(devices[0])
        self.model = nn.DataParallel(model, devices)

    @torch.inference_mode()
    def forward(self, passages: List[str]) -> Tensor:
        """
        Adapted from https://huggingface.co/openbmb/MiniCPM-Embedding.
        """
        kwargs = dict()
        kwargs["padding"] = True
        kwargs["truncation"] = True
        kwargs["return_tensors"] = "pt"
        kwargs["return_attention_mask"] = True
        encoded = self.tokenizer(passages, **kwargs)
        encoded = encoded.to(self.devices[0])
        outputs = self.model(**encoded)
        # Mean-pool the last hidden states over non-padding tokens, then L2-normalize.
        masking = encoded["attention_mask"]
        s = torch.sum(outputs.last_hidden_state * masking.unsqueeze(-1).float(), dim=1)
        d = masking.sum(dim=1, keepdim=True).float()
        return F.normalize(s / d, p=2, dim=1)

    # @torch.inference_mode()
    # def forward_prefix(self, passages: List[str]) -> Tuple[Tensor, Any, Any]:
    #     """
    #     @todo: fix the return type.
    #     """
    #     kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    #     encoded = self.tokenizer(passages[0], **kwargs)
    #     input_ids = encoded.input_ids[0]  # Shape: [seq_len]
    #     tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
    #     prefix_input_ids = [input_ids[:i] for i in range(1, len(input_ids) + 1)]
    #     batch_encoded = self.tokenizer.pad({'input_ids': prefix_input_ids}, padding=True, return_tensors="pt")
    #     batch_input_ids = batch_encoded.input_ids.to(self.devices[0])
    #     outputs = self.model(batch_input_ids)
    #     assert isinstance(outputs, BaseModelOutputWithPoolingAndCrossAttentions)
    #     hiddens = outputs.last_hidden_state
    #     return F.normalize(hiddens[:, 0], p=2, dim=1), tokens, input_ids

    # @torch.inference_mode()
    # def forward_tokens(self, tokens: List[List[float]]) -> Tensor:
    #     kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    #     batch_encoded = self.tokenizer.pad({'input_ids': tokens}, padding=True, return_tensors="pt")
    #     batch_input_ids = batch_encoded.input_ids.to(self.devices[0])
    #     outputs = self.model(batch_input_ids)
    #     assert isinstance(outputs, BaseModelOutputWithPoolingAndCrossAttentions)
    #     hiddens = outputs.last_hidden_state
    #     return F.normalize(hiddens[:, 0], p=2, dim=1)
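

# Minimal usage sketch: encode two example passages on GPU 0 and compare them.
# This block is illustrative, not part of the class; it assumes at least one CUDA
# device is available and the openbmb/MiniCPM-Embedding weights can be downloaded.
if __name__ == "__main__":
    embedding = MiniPcmEmbedding(devices=[0])
    vectors = embedding.forward(["an example passage", "another example passage"])
    # Rows are L2-normalized 2304-dim vectors, so cosine similarity is a dot product.
    print(vectors.shape)
    print(vectors @ vectors.T)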