Add BGE-M3 Encoder #22

Merged · 21 commits · Nov 22, 2024
76 changes: 67 additions & 9 deletions README.md
@@ -2,28 +2,43 @@

This provides various Dense Retrieval functionality for [PyTerrier](https://github.com/terrier-org/pyterrier).


## Installation

This repository can be installed using pip.

```bash
pip install pyterrier-dr
```

If you want the latest version of `pyterrier_dr`, you can install it directly from the GitHub repo:

```bash
pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git
```

If you want to use the BGE-M3 encoder with `pyterrier_dr`, you can install the package with the `bgem3` extra:

```bash
pip install pyterrier-dr[bgem3]
```
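
Note that some shells (e.g. zsh) treat square brackets as glob characters, so you may need to quote the requirement:

```bash
pip install 'pyterrier-dr[bgem3]'
```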

---
You'll also need to install FAISS.

On Colab:

```bash
!pip install faiss-cpu
```

On Anaconda:

```bash
# CPU-only version
conda install -c pytorch faiss-cpu

# GPU(+CPU) version
conda install -c pytorch faiss-gpu
```

You can then import the package and PyTerrier in Python:

@@ -40,6 +55,7 @@ import pyterrier_dr
| [`TasB`](https://arxiv.org/abs/2104.06967) | ✅ | ✅ | ✅ |
| [`Ance`](https://arxiv.org/abs/2007.00808) | ✅ | ✅ | ✅ |
| [`Query2Query`](https://neeva.com/blog/state-of-the-art-query2query-similarity) | ✅ | | |
| [`BGE-M3`](https://arxiv.org/abs/2402.03216) | ✅ | ✅ | ✅ |

## Inference

@@ -166,6 +182,48 @@ retr_pipeline = model >> index.faiss_hnsw_retriever()
# ...
```

## BGE-M3 Encoder

`pyterrier_dr` also supports using BGE-M3 for indexing and retrieval with the following encoders:

1. `query_encoder()`: Encodes queries into single-vector representations only.
2. `doc_encoder()`: Encodes documents into single-vector representations only.
3. `query_multi_encoder()`: Allows the user to encode queries into dense, sparse, and/or multi-vector representations.
4. `doc_multi_encoder()`: Allows the user to encode documents into dense, sparse, and/or multi-vector representations.

The encodings returned by `query_multi_encoder()` and `doc_multi_encoder()` are controlled by the `return_dense`, `return_sparse`, and `return_colbert_vecs` parameters; by default, all three are set to `True`.
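
For example, a sparse-only query encoder can be obtained by switching the other two representations off (a minimal sketch; the resulting transformer adds only a `query_toks` column):

```python
from pyterrier_dr import BGEM3

factory = BGEM3()
# keep only the learned sparse (lexical) weights; skip dense and ColBERT vectors
sparse_query_encoder = factory.query_multi_encoder(
    return_dense=False, return_sparse=True, return_colbert_vecs=False)
```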

### Dependencies

The BGE-M3 encoder requires the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) library. You can install it using pip, or as part of the `bgem3` extra of `pyterrier_dr` (see the Installation section):

```bash
pip install -U FlagEmbedding
```

### Indexing

```python
import pyterrier as pt
from pyterrier_dr import BGEM3, FlexIndex

factory = BGEM3(batch_size=32, max_length=1024, verbose=True)
encoder = factory.doc_encoder()

index = FlexIndex("mmarco/v2/fr_bgem3", verbose=True)
indexing_pipeline = encoder >> index

indexing_pipeline.index(pt.get_dataset("irds:mmarco/v2/fr").get_corpus_iter())
```
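
Here `doc_encoder()` adds a dense `doc_vec` column to each batch of documents, which `FlexIndex` then stores; the encoder reuses the factory's `batch_size` and `max_length` settings unless they are overridden.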

### Retrieval

```python
import pyterrier as pt
from pyterrier_dr import BGEM3, FlexIndex

factory = BGEM3(batch_size=32, max_length=1024)
encoder = factory.query_encoder()

index = FlexIndex("mmarco/v2/fr_bgem3", verbose=True)

pipeline = encoder >> index.np_retriever()
```
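
As a usage sketch, the pipeline can then be run over a set of topics (assuming the index built above; `mmarco/v2/fr/dev/small` is an illustrative dataset choice):

```python
topics = pt.get_dataset("irds:mmarco/v2/fr/dev/small").get_topics()
results = pipeline(topics)  # dataframe with qid, docno, score, rank columns
```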

## References

- PyTerrier: PyTerrier: Declarative Experimentation in Python from BM25 to Dense Retrieval (Macdonald et al., CIKM 2021)
1 change: 1 addition & 0 deletions pyterrier_dr/__init__.py
@@ -8,4 +8,5 @@
from .sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
from .tctcolbert_model import TctColBert
from .electra import ElectraScorer
from .bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
from .cde import CDE, CDECache
146 changes: 146 additions & 0 deletions pyterrier_dr/bge_m3.py
@@ -0,0 +1,146 @@
import pyterrier as pt
import pandas as pd
import numpy as np
import torch
from .biencoder import BiEncoder


class BGEM3(BiEncoder):
    def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
        super().__init__(batch_size, text_field, verbose)
        self.model_name = model_name
        self.use_fp16 = use_fp16
        self.max_length = max_length
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError as e:
            raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install -U FlagEmbedding'") from e

        self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)

    def __repr__(self):
        return f'BGEM3({repr(self.model_name)})'

    def encode_queries(self, texts, batch_size=None):
        return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
                                 return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']

    def encode_docs(self, texts, batch_size=None):
        return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
                                 return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']

    # These only do single-vector (dense) encoding
    def query_encoder(self, verbose=None, batch_size=None):
        return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)

    def doc_encoder(self, verbose=None, batch_size=None):
        return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)

    # These can do dense, sparse and colbert (multi-vector) encodings
    def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
        return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)

    def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
        return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)

class BGEM3QueryEncoder(pt.Transformer):
    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
        self.bge_factory = bge_factory
        self.verbose = verbose if verbose is not None else bge_factory.verbose
        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
        self.max_length = max_length if max_length is not None else bge_factory.max_length

        self.dense = return_dense
        self.sparse = return_sparse
        self.multivecs = return_colbert_vecs

    def encode(self, texts):
        return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
                                             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        assert all(c in inp.columns for c in ['query'])

        # if inp is empty, just add the expected (empty) output columns
        if len(inp) == 0:
            if self.dense:
                inp = inp.assign(query_vec=[])
            if self.sparse:
                inp = inp.assign(query_toks=[])
            if self.multivecs:
                inp = inp.assign(query_embs_toks=[])
            return inp

        # encode each unique query once, then map the results back to all rows
        it = inp['query'].values
        it, inv = np.unique(it, return_inverse=True)
        if self.verbose:
            it = pt.tqdm(it, desc='Encoding Queries', unit='query')
        bgem3_results = self.encode(it)

        if self.dense:
            query_vec = [bgem3_results['dense_vecs'][i] for i in inv]
            inp = inp.assign(query_vec=query_vec)
        if self.sparse:
            # for sparse, convert token ids to the actual tokens
            query_toks = self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
            # map the unique-query results back to the original row order
            query_toks = [query_toks[i] for i in inv]
            inp = inp.assign(query_toks=query_toks)
        if self.multivecs:
            query_embs_toks = [bgem3_results['colbert_vecs'][i] for i in inv]
            inp = inp.assign(query_embs_toks=query_embs_toks)

        return inp

    def __repr__(self):
        return f'{repr(self.bge_factory)}.query_encoder()'

class BGEM3DocEncoder(pt.Transformer):
    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
        self.bge_factory = bge_factory
        self.verbose = verbose if verbose is not None else bge_factory.verbose
        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
        self.max_length = max_length if max_length is not None else bge_factory.max_length

        self.dense = return_dense
        self.sparse = return_sparse
        self.multivecs = return_colbert_vecs

    def encode(self, texts):
        return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
                                             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)

    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        # check that the input dataframe contains the configured text_field
        assert all(c in inp.columns for c in [self.bge_factory.text_field])
        # if inp is empty, just add the expected (empty) output columns
        if len(inp) == 0:
            if self.dense:
                inp = inp.assign(doc_vec=[])
            if self.sparse:
                inp = inp.assign(toks=[])
            if self.multivecs:
                inp = inp.assign(doc_embs_toks=[])
            return inp

        it = inp[self.bge_factory.text_field]
        if self.verbose:
            it = pt.tqdm(it, desc='Encoding Documents', unit='doc')
        bgem3_results = self.encode(it)

        if self.dense:
            doc_vec = bgem3_results['dense_vecs']
            inp = inp.assign(doc_vec=list(doc_vec))
        if self.sparse:
            toks = bgem3_results['lexical_weights']
            # for sparse, convert token ids to the actual tokens
            toks = self.bge_factory.model.convert_id_to_token(toks)
            inp = inp.assign(toks=toks)
        if self.multivecs:
            doc_embs_toks = bgem3_results['colbert_vecs']
            inp = inp.assign(doc_embs_toks=list(doc_embs_toks))

        return inp

    def __repr__(self):
        return f'{repr(self.bge_factory)}.doc_encoder()'
3 changes: 3 additions & 0 deletions setup.py
@@ -33,6 +33,9 @@ def get_version(rel_path):
    long_description_content_type="text/markdown",
    packages=setuptools.find_packages(),
    install_requires=requirements,
    extras_require={
        'bgem3': ['FlagEmbedding'],
    },
    python_requires='>=3.6',
    entry_points={
        'pyterrier.artifact': [
50 changes: 50 additions & 0 deletions tests/test_models.py
@@ -104,6 +104,46 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test
        self.assertTrue('docno' in retr_res.columns)
        self.assertTrue('score' in retr_res.columns)
        self.assertTrue('rank' in retr_res.columns)

    def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
        dataset = pt.get_dataset('irds:vaswani')

        docs = list(itertools.islice(dataset.get_corpus_iter(), 200))
        docs_df = pd.DataFrame(docs)

        if test_query_multivec_encoder:
            with self.subTest('query_multivec_encoder'):
                topics = dataset.get_topics()
                enc_topics = model(topics)
                self.assertEqual(len(enc_topics), len(topics))
                self.assertTrue('query_toks' in enc_topics.columns)
                self.assertTrue('query_embs_toks' in enc_topics.columns)
                self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
                self.assertEqual(enc_topics.query_toks.dtype, object)
                self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
                self.assertEqual(enc_topics.query_embs_toks.dtype, object)
                self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs_toks))
            with self.subTest('query_multivec_encoder empty'):
                enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
                self.assertEqual(len(enc_topics_empty), 0)
                self.assertTrue('query_toks' in enc_topics_empty.columns)
                self.assertTrue('query_embs_toks' in enc_topics_empty.columns)
        if test_doc_multivec_encoder:
            with self.subTest('doc_multi_encoder'):
                enc_docs = model(docs_df)
                self.assertEqual(len(enc_docs), len(docs_df))
                self.assertTrue('toks' in enc_docs.columns)
                self.assertTrue('doc_embs_toks' in enc_docs.columns)
                self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
                self.assertEqual(enc_docs.toks.dtype, object)
                self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
                self.assertEqual(enc_docs.doc_embs_toks.dtype, object)
                self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs_toks))
            with self.subTest('doc_multi_encoder empty'):
                enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
                self.assertEqual(len(enc_docs_empty), 0)
                self.assertTrue('toks' in enc_docs_empty.columns)
                self.assertTrue('doc_embs_toks' in enc_docs_empty.columns)
def test_tct(self):
from pyterrier_dr import TctColBert
Expand All @@ -129,6 +169,16 @@ def test_query2query(self):
from pyterrier_dr import Query2Query
self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)


    def test_bgem3(self):
        from pyterrier_dr import BGEM3
        # create a shared BGEM3 instance for all sub-tests
        bgem3 = BGEM3(max_length=1024)

        self._base_test(bgem3.query_multi_encoder(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
        self._base_test(bgem3.doc_multi_encoder(), test_query_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

        self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
        self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)

    def setUp(self):
        import pyterrier as pt
        if not pt.started():