Documentation for FlexIndex (#31)
* flexindex documentation

* removed unused import

* name conflict

* copy-paste error

* faster tests

* misc

* misc

* : in directives mean something else, changed to .

* some model documentation

* bind staticmethod

* default model_name flow through

* biencoder documentation

* good enough, I just need to be done with this
seanmacavaney authored Dec 4, 2024
1 parent 9c86daa commit 4dae8e6
Showing 24 changed files with 988 additions and 215 deletions.
2 changes: 1 addition & 1 deletion pyterrier_dr/bge_m3.py
@@ -7,7 +7,7 @@

class BGEM3(BiEncoder):
def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
-        super().__init__(batch_size, text_field, verbose)
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
self.model_name = model_name
self.use_fp16 = use_fp16
self.max_length = max_length
Expand Down
68 changes: 59 additions & 9 deletions pyterrier_dr/biencoder.py
@@ -1,3 +1,5 @@
from typing import List, Optional
from abc import abstractmethod
import numpy as np
import pyterrier as pt
import pandas as pd
@@ -6,18 +8,33 @@


class BiEncoder(pt.Transformer):
"""Represents a single-vector dense bi-encoder.
A ``BiEncoder`` encodes the text of a query or document into a dense vector.
This class functions as a transformer factory:
- Query encoding using :meth:`query_encoder`
- Document encoding using :meth:`doc_encoder`
- Text scoring (re-ranking) using :meth:`text_scorer`
It can also be used as a transformer directly. It infers which transformer to use
based on columns present in the input frame.
Note that in most cases, you will want to use a ``BiEncoder`` as part of a pipeline
with a :class:`~pyterrier_dr.FlexIndex` to perform dense indexing and retrieval.
"""
-    def __init__(self, batch_size=32, text_field='text', verbose=False):
+    def __init__(self, *, batch_size=32, text_field='text', verbose=False):
"""
Args:
batch_size: The default batch size to use for query/document encoding
text_field: The field in the input dataframe that contains the document text
verbose: Whether to show progress bars
"""
super().__init__()
self.batch_size = batch_size
self.text_field = text_field
self.verbose = verbose

-    def encode_queries(self, texts, batch_size=None) -> np.array:
-        raise NotImplementedError()

-    def encode_docs(self, texts, batch_size=None) -> np.array:
-        raise NotImplementedError()

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
with pta.validate.any(inp) as v:
v.columns(includes=['query', self.text_field], mode='scorer')
@@ -46,12 +63,15 @@ def doc_encoder(self, verbose=None, batch_size=None) -> pt.Transformer:
"""
return BiDocEncoder(self, verbose=verbose, batch_size=batch_size)

-    def scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer:
+    def text_scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer:
        """
-        Scoring (re-ranking)
+        Text Scoring (re-ranking)
        """
return BiScorer(self, verbose=verbose, batch_size=batch_size, sim_fn=sim_fn)

def scorer(self, verbose=None, batch_size=None, sim_fn=None) -> pt.Transformer:
return self.text_scorer(verbose=verbose, batch_size=batch_size, sim_fn=sim_fn)

@property
def sim_fn(self) -> SimFn:
"""
@@ -61,6 +81,36 @@ def sim_fn(self) -> SimFn:
return SimFn(self.config.sim_fn)
return SimFn.dot # default

@abstractmethod
def encode_queries(self, texts: List[str], batch_size: Optional[int] = None) -> np.array:
"""Abstract method to encode a list of query texts into dense vectors.
This function is used by the transformer returned by :meth:`query_encoder`.
Args:
texts: A list of query texts
batch_size: The batch size to use for encoding
Returns:
np.array: A numpy array of shape (n_queries, n_dims)
"""
raise NotImplementedError()

@abstractmethod
def encode_docs(self, texts: List[str], batch_size: Optional[int] = None) -> np.array:
"""Abstract method to encode a list of document texts into dense vectors.
This function is used by the transformer returned by :meth:`doc_encoder`.
Args:
texts: A list of document texts
batch_size: The batch size to use for encoding
Returns:
np.array: A numpy array of shape (n_docs, n_dims)
"""
raise NotImplementedError()


class BiQueryEncoder(pt.Transformer):
def __init__(self, bi_encoder_model: BiEncoder, verbose=None, batch_size=None):
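To make the new ``BiEncoder`` contract concrete, here is a minimal sketch of a custom bi-encoder (assuming ``BiEncoder`` is importable from the package top level, like the other encoders). ``ToyEncoder`` and its token-hashing "model" are hypothetical stand-ins for a real neural encoder; only ``encode_queries`` and ``encode_docs`` need implementing, and the factory transformers documented above come for free:

from typing import List, Optional
import numpy as np
from pyterrier_dr import BiEncoder

class ToyEncoder(BiEncoder):
    """Hypothetical bi-encoder that hashes tokens into a fixed-size vector."""
    def __init__(self, dim: int = 64, *, batch_size=32, text_field='text', verbose=False):
        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
        self.dim = dim

    def _encode(self, texts: List[str]) -> np.array:
        # one row per input text; each token increments a hashed dimension
        out = np.zeros((len(texts), self.dim), dtype=np.float32)
        for i, text in enumerate(texts):
            for tok in text.lower().split():
                out[i, hash(tok) % self.dim] += 1.0
        return out

    def encode_queries(self, texts: List[str], batch_size: Optional[int] = None) -> np.array:
        return self._encode(texts)

    def encode_docs(self, texts: List[str], batch_size: Optional[int] = None) -> np.array:
        return self._encode(texts)

encoder = ToyEncoder()
q_enc = encoder.query_encoder()  # transformer that adds a query_vec column
d_enc = encoder.doc_encoder()    # transformer that adds a doc_vec column
scorer = encoder.text_scorer()   # transformer that scores (query, text) pairs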
2 changes: 1 addition & 1 deletion pyterrier_dr/cde.py
@@ -13,7 +13,7 @@

class CDE(BiEncoder):
def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = None, batch_size=32, text_field='text', verbose=False, device=None):
-        super().__init__(batch_size, text_field, verbose)
+        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
self.model_name = model_name
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
205 changes: 165 additions & 40 deletions pyterrier_dr/flex/core.py
@@ -1,8 +1,8 @@
from typing import Union, Iterable, Dict
import shutil
import itertools
import json
from pathlib import Path
from warnings import warn
import numpy as np
import more_itertools
import pandas as pd
@@ -20,12 +20,32 @@ class IndexingMode(Enum):


class FlexIndex(pta.Artifact, pt.Indexer):
""" Represents a FLexible EXecution (FLEX) Index, which is a dense index format.
FLEX allows for a variety of retrieval implementations (NumPy, FAISS, etc.) and algorithms (exhaustive, HNSW, etc.)
to be tested. In most cases, the same vector storage can be used across implementations and algorithms, saving
considerably on disk space.
"""

ARTIFACT_TYPE = 'dense_index'
ARTIFACT_FORMAT = 'flex'

-    def __init__(self, index_path, num_results=1000, sim_fn=SimFn.dot, indexing_mode=IndexingMode.create, verbose=True):
+    def __init__(self,
+        path: str,
+        *,
+        sim_fn: Union[SimFn, str] = SimFn.dot,
+        verbose: bool = True
+    ):
"""
Args:
path: The path to the index directory
sim_fn: The similarity function to use
verbose: Whether to display verbose output (e.g., progress bars)
"""
-        super().__init__(index_path)
-        self.index_path = Path(index_path)
-        self.num_results = num_results
+        super().__init__(path)
+        self.index_path = Path(path)
self.sim_fn = SimFn(sim_fn)
-        self.indexing_mode = IndexingMode(indexing_mode)
self.verbose = verbose
self._meta = None
self._docnos = None
@@ -53,44 +73,88 @@ def __len__(self):
meta, = self.payload(return_dvecs=False, return_docnos=False)
return meta['doc_count']

-    def index(self, inp):
-        if isinstance(inp, pd.DataFrame):
-            inp = inp.to_dict(orient="records")
-        inp = more_itertools.peekable(inp)
-        path = Path(self.index_path)
-        if path.exists():
-            if self.indexing_mode == IndexingMode.overwrite:
-                shutil.rmtree(path)
-            else:
-                raise RuntimeError(f'Index already exists at {self.index_path}. If you want to delete and re-create an existing index, you can pass indexing_mode=IndexingMode.overwrite')
-        path.mkdir(parents=True, exist_ok=True)
-        vec_size = None
-        count = 0
-        if self.verbose:
-            inp = pt.tqdm(inp, desc='indexing', unit='dvec')
-        with open(path/'vecs.f4', 'wb') as fout, Lookup.builder(path/'docnos.npids') as docnos:
-            for d in inp:
-                vec = d['doc_vec']
-                if vec_size is None:
-                    vec_size = vec.shape[0]
-                elif vec_size != vec.shape[0]:
-                    raise ValueError(f'Inconsistent vector shapes detected (expected {vec_size} but found {vec.shape[0]})')
-                vec = vec.astype(np.float32)
-                fout.write(vec.tobytes())
-                docnos.add(d['docno'])
-                count += 1
-        with open(path/'pt_meta.json', 'wt') as f_meta:
-            json.dump({"type": "dense_index", "format": "flex", "vec_size": vec_size, "doc_count": count}, f_meta)
def index(self, inp: Iterable[Dict]) -> pta.Artifact:
"""Index the given input data stream to a new index at this location.
Each record in ``inp`` is expected to be a dictionary containing at least two keys: ``docno`` (a unique document
identifier) and ``doc_vec`` (a dense vector representation of the document).
Typically this method will be used in a pipeline of operations, where the input data is first transformed by a
document encoder to add the ``doc_vec`` values before it is indexed. For example:
.. code-block:: python
:caption: Index documents into a :class:`~pyterrier_dr.FlexIndex` using a :class:`~pyterrier_dr.TasB` encoder.
from pyterrier_dr import TasB, FlexIndex
encoder = TasB.dot()
index = FlexIndex('my_index')
pipeline = encoder >> index
pipeline.index([
{'docno': 'doc1', 'text': 'hello'},
{'docno': 'doc2', 'text': 'world'},
])
Args:
inp: An iterable of dictionaries to index.
Returns:
:class:`pyterrier_alpha.Artifact`: A reference back to this index (``self``).
Raises:
RuntimeError: If the index is already built.
"""
return self.indexer().index(inp)

def indexer(self, *, mode: Union[IndexingMode, str] = IndexingMode.create) -> 'FlexIndexer':
"""Return an indexer for this index with the specified options.
This transformer gives more fine-grained control over the indexing process, allowing you to specify whether
to create a new index or overwrite an existing one.
Similar to :meth:`index`, this method will typically be used in a pipeline of operations, where the input data
is first transformed by a document encoder to add the ``doc_vec`` values before it is indexed. For example:
.. code-block:: python
:caption: Overwrite a :class:`~pyterrier_dr.FlexIndex` using a :class:`~pyterrier_dr.TasB` encoder.
from pyterrier_dr import TasB, FlexIndex
encoder = TasB.dot()
index = FlexIndex('my_index')
pipeline = encoder >> index.indexer(mode='overwrite')
pipeline.index([
{'docno': 'doc1', 'text': 'hello'},
{'docno': 'doc2', 'text': 'world'},
])
Args:
mode: The indexing mode to use (``create`` or ``overwrite``).
Returns:
:class:`~pyterrier.Indexer`: A new indexer instance.
"""
return FlexIndexer(self, mode=mode)

def transform(self, inp):
with pta.validate.any(inp) as v:
-            v.query_frame(extra_columns=['query_vec'], mode='np_retriever')
+            v.query_frame(extra_columns=['query_vec'], mode='retriever')
v.result_frame(extra_columns=['query_vec'], mode='scorer')

if v.mode == 'retriever':
return self.retriever()(inp)
if v.mode == 'scorer':
return self.scorer()(inp)

-        if v.mode == 'np_retriever':
-            warn("performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster")
-            return self.np_retriever()(inp)

-    def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True):
+    def get_corpus_iter(self, start_idx=None, stop_idx=None, verbose=True) -> Iterable[Dict]:
"""Iterate over the documents in the index.
Args:
start_idx: The index of the first document to return (or ``None`` to start at the first document).
stop_idx: The index of the last document to return (or ``None`` to end on the last document).
verbose: Whether to display a progress bar.
Yields:
Dict[str,Any]: A dictionary with keys ``docno`` and ``doc_vec``.
"""
docnos, dvecs, meta = self.payload()
docno_iter = iter(docnos)
if start_idx is not None or stop_idx is not None:
@@ -111,9 +175,70 @@ def _load_docids(self, inp):
docnos, config = self.payload(return_dvecs=False)
return docnos.inv[inp['docno'].values] # look up docids from docnos

-    def built(self):
+    def built(self) -> bool:
"""Check if the index has been built.
Returns:
bool: ``True`` if the index has been built, otherwise ``False``.
"""
return self.index_path.exists()

def docnos(self) -> Lookup:
"""Return the document identifier (docno) lookup data structure.
Returns:
:class:`npids.Lookup`: The document number lookup.
"""
docnos, meta = self.payload(return_dvecs=False)
return docnos


class FlexIndexer(pt.Indexer):
def __init__(self, index: FlexIndex, mode: Union[IndexingMode, str] = IndexingMode.create):
self._index = index
self.mode = IndexingMode(mode)

def __repr__(self):
return f'{self._index}.indexer(mode={self.mode!r})'

def transform(self, inp):
raise RuntimeError("FlexIndexer cannot be used as a transformer, use .index() instead")

def index(self, inp):
if isinstance(inp, pd.DataFrame):
inp = inp.to_dict(orient="records")
inp = more_itertools.peekable(inp)
path = Path(self._index.index_path)
if path.exists():
if self.mode == IndexingMode.overwrite:
shutil.rmtree(path)
else:
raise RuntimeError(f'Index already exists at {self._index.index_path}. If you want to delete and re-create an existing index, you can pass index.indexer(mode="overwrite")')
path.mkdir(parents=True, exist_ok=True)
vec_size = None
count = 0
if self._index.verbose:
inp = pt.tqdm(inp, desc='indexing', unit='dvec')
with open(path/'vecs.f4', 'wb') as fout, Lookup.builder(path/'docnos.npids') as docnos:
for d in inp:
vec = d['doc_vec']
if vec_size is None:
vec_size = vec.shape[0]
elif vec_size != vec.shape[0]:
raise ValueError(f'Inconsistent vector shapes detected (expected {vec_size} but found {vec.shape[0]})')
vec = vec.astype(np.float32)
fout.write(vec.tobytes())
docnos.add(d['docno'])
count += 1
with open(path/'pt_meta.json', 'wt') as f_meta:
json.dump({
"type": self._index.ARTIFACT_TYPE,
"format": self._index.ARTIFACT_FORMAT,
"vec_size": vec_size,
"doc_count": count
}, f_meta)
return self._index


def _load_dvecs(flex_index, inp):
dvecs, config = flex_index.payload(return_docnos=False)
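The ``FlexIndexer.index`` implementation above pins down the FLEX on-disk layout: a flat ``vecs.f4`` buffer of native-order float32 values, a ``docnos.npids`` lookup, and a ``pt_meta.json`` descriptor holding ``vec_size`` and ``doc_count``. A short sketch of inspecting that layout directly, assuming an index was already built at the hypothetical path ``my_index``:

import json
import numpy as np
from pyterrier_dr import FlexIndex

path = 'my_index'  # hypothetical path to an index built by a pipeline like the ones above

# pt_meta.json records the artifact type/format plus the vector geometry
with open(f'{path}/pt_meta.json') as f:
    meta = json.load(f)

# vecs.f4 holds doc_count * vec_size float32 values, written contiguously
vecs = np.memmap(f'{path}/vecs.f4', dtype=np.float32, mode='r',
                 shape=(meta['doc_count'], meta['vec_size']))

# The same data is available through the documented API:
index = FlexIndex(path)
for record in index.get_corpus_iter(stop_idx=5):
    print(record['docno'], record['doc_vec'][:4])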
17 changes: 16 additions & 1 deletion pyterrier_dr/flex/corpus_graph.py
@@ -10,7 +10,22 @@
from . import FlexIndex


-def _corpus_graph(self, k=16, batch_size=8192):
+def _corpus_graph(self, k: int = 16, *, batch_size: int = 8192):
"""Return the corpus graph (neighborhood graph) for the index.
The corpus graph is a directed graph where each node represents a document and each edge represents a
connection between two documents. The graph is built by computing the cosine similarity between each
pair of documents and storing the k-nearest neighbors for each document.
If the corpus graph has not been built yet, it will be built using the given k and batch size.
Args:
k: The number of neighbors to store for each document.
batch_size: The number of vectors to process in each batch.
Returns:
:class:`pyterrier_adaptive.CorpusGraph`: The corpus graph for the index.
"""
from pyterrier_adaptive import CorpusGraph
key = ('corpus_graph', k)
if key not in self._cache:
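A short usage sketch for the corpus graph helper. It assumes ``_corpus_graph`` is exposed as ``FlexIndex.corpus_graph`` (consistent with the commit's "bind staticmethod" note) and that ``pyterrier_adaptive`` is installed; the ``neighbours`` lookup comes from ``pyterrier_adaptive``'s ``CorpusGraph``:

from pyterrier_dr import FlexIndex

index = FlexIndex('my_index')  # hypothetical existing index with document vectors
graph = index.corpus_graph(k=16)  # built and cached on the first call

# Look up the 16 stored nearest neighbours of the document with docid 0
print(graph.neighbours(0))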
