diff --git a/README.md b/README.md
index 5e566b1..2c5dfe6 100644
--- a/README.md
+++ b/README.md
@@ -2,28 +2,43 @@
 This provides various Dense Retrieval functionality for [PyTerrier](https://github.com/terrier-org/pyterrier).
 
-
 ## Installation
 
-This repostory can be installed using pip.
+This repository can be installed using pip.
+
+```bash
+pip install pyterrier-dr
+```
+
+If you want the latest version of `pyterrier_dr`, you can install it directly from the GitHub repo:
 
 ```bash
 pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git
 ```
 
+If you want to use the BGE-M3 encoder with `pyterrier_dr`, you can install the package with the optional `bgem3` dependency:
+
+```bash
+pip install pyterrier-dr[bgem3]
+```
+
+---
+
 You'll also need to install FAISS. On Colab:
 
-    !pip install faiss-cpu
-
-On Anaconda:
+```bash
+!pip install faiss-cpu
+```
 
-    # CPU-only version
-    $ conda install -c pytorch faiss-cpu
+On Anaconda:
 
-    # GPU(+CPU) version
-    $ conda install -c pytorch faiss-gpu
+```bash
+# CPU-only version
+conda install -c pytorch faiss-cpu
 
+# GPU(+CPU) version
+conda install -c pytorch faiss-gpu
+```
 
 You can then import the package and PyTerrier in Python:
 
@@ -40,6 +55,7 @@ import pyterrier_dr
 | [`TasB`](https://arxiv.org/abs/2104.06967) | ✅ | ✅ | ✅ |
 | [`Ance`](https://arxiv.org/abs/2007.00808) | ✅ | ✅ | ✅ |
 | [`Query2Query`](https://neeva.com/blog/state-of-the-art-query2query-similarity) | ✅ | | |
+| [`BGE-M3`](https://arxiv.org/abs/2402.03216) | ✅ | ✅ | ✅ |
 
 ## Inference
 
@@ -166,6 +182,44 @@ retr_pipeline = model >> index.faiss_hnsw_retriever()
 # ...
 ```
 
+## BGE-M3 Encoder
+
+`pyterrier_dr` also supports using BGE-M3 for indexing and retrieval, through the following encoders:
+
+ 1. `query_encoder()`: Encodes queries into single-vector (dense) representations only.
+ 2. `doc_encoder()`: Encodes documents into single-vector (dense) representations only.
+ 3. `query_multi_encoder()`: Allows the user to encode queries into dense, sparse and/or multi-vector representations.
+ 4. `doc_multi_encoder()`: Allows the user to encode documents into dense, sparse and/or multi-vector representations.
+
+Which encodings are returned by `query_multi_encoder()` and `doc_multi_encoder()` is controlled by the `return_dense`, `return_sparse` and `return_colbert_vecs` parameters. By default, all three are set to `True`.
+
+### Dependencies
+
+The BGE-M3 Encoder requires the [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) library. You can install it as part of the optional `bgem3` dependency of `pyterrier_dr` (see the Installation section above).
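+
+Once installed, the encoders listed above can be applied directly to PyTerrier dataframes. A minimal sketch (the toy query and document below are illustrative; the added columns follow this package's conventions: `query_vec`/`query_toks`/`query_embs` for queries, `doc_vec`/`toks`/`doc_embs` for documents):
+
+```python
+import pandas as pd
+from pyterrier_dr import BGEM3
+
+factory = BGEM3(batch_size=32)
+
+# dense + sparse + multi-vector query encodings (all three returned by default)
+query_encoder = factory.query_multi_encoder()
+queries = pd.DataFrame([{'qid': 'q1', 'query': 'what are chemical reactions'}])
+queries = query_encoder(queries)
+# adds: query_vec (dense), query_toks (sparse term weights), query_embs (multi-vector)
+
+# dense-only document encoding
+doc_encoder = factory.doc_encoder()
+docs = pd.DataFrame([{'docno': 'd1', 'text': 'A chemical reaction is a process ...'}])
+docs = doc_encoder(docs)
+# adds: doc_vec (dense)
+```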
+
+### Indexing
+
+```python
+import pyterrier as pt
+from pyterrier_dr import BGEM3, FlexIndex
+
+factory = BGEM3(batch_size=32, max_length=1024, verbose=True)
+encoder = factory.doc_encoder()
+
+index = FlexIndex("mmarco/v2/fr_bgem3", verbose=True)
+indexing_pipeline = encoder >> index
+
+indexing_pipeline.index(pt.get_dataset("irds:mmarco/v2/fr").get_corpus_iter())
+```
+
+### Retrieval
+
+```python
+from pyterrier_dr import BGEM3, FlexIndex
+
+factory = BGEM3(batch_size=32, max_length=1024)
+encoder = factory.query_encoder()
+
+index = FlexIndex("mmarco/v2/fr_bgem3", verbose=True)
+
+pipeline = encoder >> index.np_retriever()
+```
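+
+The retrieval pipeline behaves like any other PyTerrier transformer, so it can, for example, be evaluated with `pt.Experiment`. A sketch, assuming a benchmark with topics and qrels is available (the dataset identifier and measures below are illustrative):
+
+```python
+import pyterrier as pt
+from pyterrier.measures import nDCG, RR
+
+# any ir_datasets benchmark with topics and qrels will do
+dataset = pt.get_dataset('irds:mmarco/v2/fr/dev/small')
+
+pt.Experiment(
+    [pipeline],
+    dataset.get_topics(),
+    dataset.get_qrels(),
+    [nDCG@10, RR@10],
+    names=['BGE-M3 (dense)'],
+)
+```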
+
 
 ## References
 
 - PyTerrier: PyTerrier: Declarative Experimentation in Python from BM25 to Dense Retrieval (Macdonald et al, CIKM 2021)
diff --git a/pyterrier_dr/__init__.py b/pyterrier_dr/__init__.py
index 8904be3..2d053c4 100644
--- a/pyterrier_dr/__init__.py
+++ b/pyterrier_dr/__init__.py
@@ -8,4 +8,5 @@
 from .sbert_models import SBertBiEncoder, Ance, Query2Query, GTR
 from .tctcolbert_model import TctColBert
 from .electra import ElectraScorer
+from .bge_m3 import BGEM3, BGEM3QueryEncoder, BGEM3DocEncoder
 from .cde import CDE, CDECache
diff --git a/pyterrier_dr/bge_m3.py b/pyterrier_dr/bge_m3.py
new file mode 100644
index 0000000..8e79ad5
--- /dev/null
+++ b/pyterrier_dr/bge_m3.py
@@ -0,0 +1,140 @@
+from tqdm import tqdm
+import pyterrier as pt
+import pandas as pd
+import numpy as np
+import torch
+from .biencoder import BiEncoder
+
+class BGEM3(BiEncoder):
+    def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
+        super().__init__(batch_size, text_field, verbose)
+        self.model_name = model_name
+        self.use_fp16 = use_fp16
+        self.max_length = max_length
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = torch.device(device)
+        try:
+            from FlagEmbedding import BGEM3FlagModel
+        except ImportError as e:
+            raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'") from e
+
+        self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)
+
+    def __repr__(self):
+        return f'BGEM3({repr(self.model_name)})'
+
+    def encode_queries(self, texts, batch_size=None):
+        return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
+                                 return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']
+
+    def encode_docs(self, texts, batch_size=None):
+        return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
+                                 return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']
+
+    # Only performs dense (single-vector) encoding
+    def query_encoder(self, verbose=None, batch_size=None):
+        return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)
+
+    def doc_encoder(self, verbose=None, batch_size=None):
+        return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)
+
+    # Performs all three BGE-M3 encodings: dense, sparse and colbert (multi-vector)
+    def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
+        return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)
+
+    def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
+        return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)
+
+class BGEM3QueryEncoder(pt.Transformer):
+    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
+        self.bge_factory = bge_factory
+        self.verbose = verbose if verbose is not None else bge_factory.verbose
+        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
+        self.max_length = max_length if max_length is not None else bge_factory.max_length
+
+        self.dense = return_dense
+        self.sparse = return_sparse
+        self.multivecs = return_colbert_vecs
+
+    def encode(self, texts):
+        return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
+                                             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)
+
+    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
+        assert all(c in inp.columns for c in ['query'])
+
+        # check if inp is empty
+        if len(inp) == 0:
+            if self.dense:
+                inp = inp.assign(query_vec=[])
+            if self.sparse:
+                inp = inp.assign(query_toks=[])
+            if self.multivecs:
+                inp = inp.assign(query_embs=[])
+            return inp
+
+        it = inp['query'].values
+        it, inv = np.unique(it, return_inverse=True)
+        if self.verbose:
+            it = pt.tqdm(it, desc='Encoding Queries', unit='query')
+        bgem3_results = self.encode(it)
+
+        if self.dense:
+            inp = inp.assign(query_vec=[bgem3_results['dense_vecs'][i] for i in inv])
+        if self.sparse:
+            # for sparse, convert ids to the actual tokens
+            query_toks = self.bge_factory.model.convert_id_to_token(bgem3_results['lexical_weights'])
+            inp = inp.assign(query_toks=[query_toks[i] for i in inv])
+        if self.multivecs:
+            inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inv])
+        return inp
+
+    def __repr__(self):
+        return f'{repr(self.bge_factory)}.query_encoder()'
+
+class BGEM3DocEncoder(pt.Transformer):
+    def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
+        self.bge_factory = bge_factory
+        self.verbose = verbose if verbose is not None else bge_factory.verbose
+        self.batch_size = batch_size if batch_size is not None else bge_factory.batch_size
+        self.max_length = max_length if max_length is not None else bge_factory.max_length
+
+        self.dense = return_dense
+        self.sparse = return_sparse
+        self.multivecs = return_colbert_vecs
+
+    def encode(self, texts):
+        return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
+                                             return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)
+
+    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
+        # check if the input dataframe contains the field(s) specified in text_field
+        assert all(c in inp.columns for c in [self.bge_factory.text_field])
+        # check if inp is empty
+        if len(inp) == 0:
+            if self.dense:
+                inp = inp.assign(doc_vec=[])
+            if self.sparse:
+                inp = inp.assign(toks=[])
+            if self.multivecs:
+                inp = inp.assign(doc_embs=[])
+            return inp
+
+        it = inp[self.bge_factory.text_field]
+        if self.verbose:
+            it = pt.tqdm(it, desc='Encoding Documents', unit='doc')
+        bgem3_results = self.encode(it)
+
+        if self.dense:
+            inp = inp.assign(doc_vec=list(bgem3_results['dense_vecs']))
+        if self.sparse:
+            toks = bgem3_results['lexical_weights']
+            # for sparse, convert ids to the actual tokens
+            toks = self.bge_factory.model.convert_id_to_token(toks)
+            inp = inp.assign(toks=toks)
+        if self.multivecs:
+            inp = inp.assign(doc_embs=list(bgem3_results['colbert_vecs']))
+        return inp
+
+    def __repr__(self):
+        return f'{repr(self.bge_factory)}.doc_encoder()'
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9f84e1a..ae6b446 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,3 +2,4 @@ pytest
 pytest-subtests
 git+https://github.com/terrierteam/pyterrier_adaptive
 voyager
+FlagEmbedding
diff --git a/setup.py b/setup.py
index 85598b3..ff35d0e 100644
--- a/setup.py
+++ b/setup.py
@@ -33,6 +33,9 @@ def get_version(rel_path):
     long_description_content_type="text/markdown",
     packages=setuptools.find_packages(),
     install_requires=requirements,
+    extras_require={
+        'bgem3': ['FlagEmbedding'],
+    },
     python_requires='>=3.6',
     entry_points={
         'pyterrier.artifact': [
diff --git a/tests/test_models.py b/tests/test_models.py
index 58156b9..738bae6 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -104,6 +104,46 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test
         self.assertTrue('docno' in retr_res.columns)
         self.assertTrue('score' in retr_res.columns)
         self.assertTrue('rank' in retr_res.columns)
+
+    def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
+        dataset = pt.get_dataset('irds:vaswani')
+
+        docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
+        docs_df = pd.DataFrame(docs)
+
+        if test_query_multivec_encoder:
+            with self.subTest('query_multivec_encoder'):
+                topics = dataset.get_topics()
+                enc_topics = model(topics)
+                self.assertEqual(len(enc_topics), len(topics))
+                self.assertTrue('query_toks' in enc_topics.columns)
+                self.assertTrue('query_embs' in enc_topics.columns)
+                self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
+                self.assertEqual(enc_topics.query_toks.dtype, object)
+                self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
+                self.assertEqual(enc_topics.query_embs.dtype, object)
+                self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs))
+            with self.subTest('query_multivec_encoder empty'):
+                enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
+                self.assertEqual(len(enc_topics_empty), 0)
+                self.assertTrue('query_toks' in enc_topics_empty.columns)
+                self.assertTrue('query_embs' in enc_topics_empty.columns)
+        if test_doc_multivec_encoder:
+            with self.subTest('doc_multi_encoder'):
+                enc_docs = model(pd.DataFrame(docs_df))
+                self.assertEqual(len(enc_docs), len(docs_df))
+                self.assertTrue('toks' in enc_docs.columns)
+                self.assertTrue('doc_embs' in enc_docs.columns)
+                self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
+                self.assertEqual(enc_docs.toks.dtype, object)
+                self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
+                self.assertEqual(enc_docs.doc_embs.dtype, object)
+                self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs))
+            with self.subTest('doc_multi_encoder empty'):
+                enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
+                self.assertEqual(len(enc_docs_empty), 0)
+                self.assertTrue('toks' in enc_docs_empty.columns)
+                self.assertTrue('doc_embs' in enc_docs_empty.columns)
 
     def test_tct(self):
         from pyterrier_dr import TctColBert
@@ -129,6 +169,16 @@ def test_query2query(self):
         from pyterrier_dr import Query2Query
         self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
 
+    def test_bgem3(self):
+        from pyterrier_dr import BGEM3
+        # create BGEM3 instance
+        bgem3 = BGEM3(max_length=1024)
+
+        self._base_test(bgem3.query_multi_encoder(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
+        self._base_test(bgem3.doc_multi_encoder(), test_query_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
+
+        self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
+        self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)
 
     def setUp(self):
         import pyterrier as pt
         if not pt.started():