
Commit 80d474d: implement miniPCM embedding
1 parent: 7c78d85

File tree: 4 files changed (+88, -0 lines)


environment.yml (+2)

@@ -27,6 +27,8 @@ dependencies:
   - elasticsearch=8.15.1
   - pillow
   - seaborn
+  - sentencepiece
   - pip:
     - treevizer
     - beir
+    - flash-attn
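Both additions support the new embedding module: sentencepiece is presumably required by the MiniCPM-Embedding tokenizer, and flash-attn backs the attn_implementation="flash_attention_2" setting that miniPcm.py requests below.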

source/embedding/__init__.py (+1)

@@ -1 +1,2 @@
 from source.embedding.bgeBase import BgeBaseEmbedding
+# from source.embedding.miniPcm import MiniPcmEmbedding

source/embedding/miniPcm.py (+74)

@@ -0,0 +1,74 @@
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from transformers import AutoModel, AutoTokenizer
from source.interface import Embedding


class MiniPcmEmbedding(Embedding):
    """
    This class implements the MiniPcm embedding model.
    """

    name = "MiniPcm"
    size = 2304

    def __init__(self, devices: List[int] = [0]) -> None:
        self.devices = devices
        assert len(self.devices) > 0
        self.tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Embedding")
        self.pad_idx = self.tokenizer.pad_token_id
        kwargs = dict()
        kwargs["trust_remote_code"] = True
        kwargs["attn_implementation"] = "flash_attention_2"
        kwargs["torch_dtype"] = torch.float16
        model = AutoModel.from_pretrained("openbmb/MiniCPM-Embedding", **kwargs)
        model = model.eval().to(devices[0])
        self.model = nn.DataParallel(model, devices)

    @torch.inference_mode()
    def forward(self, passages: List[str]) -> Tensor:
        """
        Adapted from https://huggingface.co/openbmb/MiniCPM-Embedding.
        """
        kwargs = dict()
        kwargs["padding"] = True
        kwargs["truncation"] = True
        kwargs["return_tensors"] = "pt"
        kwargs["return_attention_mask"] = True
        encoded = self.tokenizer(passages, **kwargs)
        encoded = encoded.to(self.devices[0])
        outputs = self.model.forward(**encoded)
        masking = encoded["attention_mask"]
        s = torch.sum(outputs.last_hidden_state * masking.unsqueeze(-1).float(), dim=1)
        d = masking.sum(dim=1, keepdim=True).float()
        return F.normalize(s / d, p=2, dim=1)

    # @torch.inference_mode()
    # def forward_prefix(self, passages: List[str]) -> Tuple[Tensor, Any, Any]:
    #     """
    #     @todo: fix the return type.
    #     """
    #     kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    #     encoded = self.tokenizer(passages[0], **kwargs)
    #     input_ids = encoded.input_ids[0]  # Shape: [seq_len]
    #     tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
    #     prefix_input_ids = [input_ids[:i] for i in range(1, len(input_ids) + 1)]
    #     batch_encoded = self.tokenizer.pad({"input_ids": prefix_input_ids}, padding=True, return_tensors="pt")
    #     batch_input_ids = batch_encoded.input_ids.to(self.devices[0])
    #     outputs = self.model(batch_input_ids)
    #     assert isinstance(outputs, BaseModelOutputWithPoolingAndCrossAttentions)
    #     hiddens = outputs.last_hidden_state
    #     return F.normalize(hiddens[:, 0], p=2, dim=1), tokens, input_ids

    # @torch.inference_mode()
    # def forward_tokens(self, tokens: List[List[float]]) -> Tensor:
    #     kwargs = dict(padding=True, truncation=True, return_tensors="pt")
    #     batch_encoded = self.tokenizer.pad({"input_ids": tokens}, padding=True, return_tensors="pt")
    #     batch_input_ids = batch_encoded.input_ids.to(self.devices[0])
    #     outputs = self.model(batch_input_ids)
    #     assert isinstance(outputs, BaseModelOutputWithPoolingAndCrossAttentions)
    #     hiddens = outputs.last_hidden_state
    #     return F.normalize(hiddens[:, 0], p=2, dim=1)
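For reference, forward mean-pools the last hidden states over non-padding tokens and L2-normalizes the result. A minimal CPU-only sketch of that pooling step, using made-up tensor shapes (batch 2, sequence length 3, hidden size 4) rather than anything from the repo:

import torch
import torch.nn.functional as F

# Toy stand-ins for outputs.last_hidden_state and encoded["attention_mask"].
hidden = torch.randn(2, 3, 4)                   # [batch, seq_len, hidden_size]
masking = torch.tensor([[1, 1, 0], [1, 1, 1]])  # 0 marks a padding position

# Masked mean pooling: zero out padded positions, sum, divide by token count.
s = torch.sum(hidden * masking.unsqueeze(-1).float(), dim=1)
d = masking.sum(dim=1, keepdim=True).float()
pooled = F.normalize(s / d, p=2, dim=1)

# Rows are unit-length, so cosine similarity reduces to a dot product.
assert torch.allclose(pooled.norm(dim=1), torch.ones(2), atol=1e-6)

Note that nn.DataParallel scatters the encoded batch across self.devices and gathers the hidden states back on devices[0], so the pooling itself runs once on the primary device.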

source/embedding/test_miniPcm.py (+11)

@@ -0,0 +1,11 @@
from source.embedding.miniPcm import MiniPcmEmbedding


def test_forward():
    """
    Test the forward method.
    """
    embedding = MiniPcmEmbedding()
    passages = ["Hello, world!", "Goodbye, world!"]
    results = embedding.forward(passages)
    assert results.shape == (len(passages), MiniPcmEmbedding.size)
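Constructing MiniPcmEmbedding downloads openbmb/MiniCPM-Embedding and, because of the flash_attention_2 setting, needs a CUDA GPU with flash-attn installed. Assuming pytest is the project's test runner (which the test_ naming suggests), the test can be run with: pytest source/embedding/test_miniPcm.py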
