Skip to content

Commit 81b7a1a

Browse files
ColBERT Upstream Updates (#19)
- `IndexUpdater` class for adding/removing new documents from an existing index. - `class_factory` wrapper around `HF_ColBERT` for initializing new types of models: - AutoModel - BERT - Deberta - Electra - Roberta - XLMRoberta - README updates.
1 parent bf4df83 commit 81b7a1a

24 files changed

+980
-104
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ fast**RAG** is a research framework designed to facilitate the building of retri
2020

2121
## Updates
2222

23+
- **June 2023**: ColBERT index modification: adding/removing documents; see [IndexUpdater](libs/colbert/colbert/index_updater.py).
2324
- **May 2023**: [RAG with LLM and dynamic prompt synthesis example](examples/rag-prompt-hf.ipynb).
2425
- **April 2023**: Qdrant `DocumentStore` support.
2526

fastrag/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastrag import image_generators, kg_creators, rankers, readers, retrievers, stores
55
from fastrag.utils import add_timing_to_pipeline
66

7-
__version__ = "1.2.0"
7+
__version__ = "1.3.0"
88

99

1010
def load_pipeline(config_path: str) -> Pipeline:

libs/colbert/README.md

+20-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 🚨 **Announcements**
2+
3+
* (1/29/23) We have merged a new index updater feature and support for additional Hugging Face models! These are in beta so please give us feedback as you try them out.
4+
* (1/24/23) If you're looking for the **DSP** framework for composing ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dsp
5+
16
# ColBERT (v2)
27

38
### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.
@@ -18,7 +23,7 @@ These rich interactions allow ColBERT to surpass the quality of _single-vector_
1823
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21).
1924
* [**Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval**](https://arxiv.org/abs/2101.00436) (NeurIPS'21).
2025
* [**ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction**](https://arxiv.org/abs/2112.01488) (NAACL'22).
21-
* [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (preprint).
26+
* [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (CIKM'22).
2227

2328
----
2429

@@ -29,7 +34,7 @@ The ColBERTv1 code from the SIGIR'20 paper is in the [`colbertv1` branch](https:
2934

3035
## Installation
3136

32-
ColBERT requires Python 3.7+ and Pytorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.
37+
ColBERT requires Python 3.7+ and Pytorch 1.9+ and uses the [Hugging Face Transformers](https://github.com/huggingface/transformers) library.
3338

3439
We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).)
3540

@@ -161,6 +166,19 @@ if __name__=='__main__':
161166
print(f"Saved checkpoint to {checkpoint_path}...")
162167
```
163168

169+
## Running a lightweight ColBERTv2 server
170+
We provide a script to run a lightweight server which serves k (up to 100) results in ranked order for a given search query, in JSON format. This script can be used to power DSP programs.
171+
172+
To run the server, update the environment variables `INDEX_ROOT` and `INDEX_NAME` in the `.env` file to point to the appropriate ColBERT index. Then run the following command:
173+
```
174+
python server.py
175+
```
176+
177+
A sample query:
178+
```
179+
http://localhost:8893/api/search?query=Who won the 2022 FIFA world cup&k=25
180+
```
181+
164182
## Branches
165183

166184
### Supported branches

libs/colbert/colbert/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .index_updater import IndexUpdater
12
from .indexer import Indexer
23
from .modeling.checkpoint import Checkpoint
34
from .searcher import Searcher
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from collections import defaultdict
2+
3+
import tqdm
4+
import ujson
5+
from colbert.data import Ranking
6+
from colbert.distillation.scorer import Scorer
7+
from colbert.infra import Run
8+
from colbert.infra.provenance import Provenance
9+
from colbert.utility.utils.save_metadata import get_metadata_only
10+
from colbert.utils.utils import print_message, zipstar
11+
12+
13+
class RankingScorer:
    """Score every (qid, pid) pair of a ranking and persist the scores.

    The ranking is materialised with ``ranking.tolist()`` at construction
    time; ``run`` hands the pairs to ``scorer.launch`` and writes the
    results, grouped by query id, into the current ``Run`` directory.
    """

    def __init__(self, scorer: Scorer, ranking: Ranking):
        self.scorer = scorer
        self.ranking = ranking.tolist()
        self.__provenance = Provenance()

        pair_count = len(self.ranking)
        print_message(f"#> Loaded ranking with {pair_count} qid--pid pairs!")

    def provenance(self):
        """Return the ``Provenance`` object created in ``__init__``."""
        return self.__provenance

    def run(self):
        """Score the ranking and write the result files.

        Writes ``distillation_scores.json`` (one JSON line per query id,
        of the form ``[qid, [[score, pid], ...]]``) followed by a sibling
        ``.meta`` file holding metadata and provenance.

        Returns:
            The ``name`` of the file handle the scores were written to.
        """
        print_message("#> Starting..")

        # Only the first two columns of each ranking row are used here.
        qids, pids, *_ = zipstar(self.ranking)
        pair_scores = self.scorer.launch(qids, pids)

        grouped = defaultdict(list)
        triples = zip(qids, pids, pair_scores)
        for qid, pid, score in tqdm.tqdm(triples):
            grouped[qid].append((score, pid))

        with Run().open("distillation_scores.json", "w") as scores_file:
            for qid in tqdm.tqdm(grouped):
                record = ujson.dumps((qid, grouped[qid]))
                scores_file.write(record + "\n")

            output_path = scores_file.name
            print_message(f"#> Saved the distillation_scores to {output_path}")

        with Run().open(f"{output_path}.meta", "w") as meta_file:
            meta = {
                "metadata": get_metadata_only(),
                "provenance": self.provenance(),
            }
            meta_file.write(ujson.dumps(meta, indent=4))

        return output_path
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import tqdm
3+
from colbert.infra import Run, RunConfig
4+
from colbert.infra.launcher import Launcher
5+
from colbert.modeling.reranker.electra import ElectraReranker
6+
from colbert.utils.utils import flatten
7+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
8+
9+
DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class Scorer:
    """Score (query, passage) pairs with a sequence-classification model.

    Args:
        queries: Mapping indexed by qid that yields the query text.
        collection: Mapping/sequence indexed by pid that yields the passage text.
        model: Name or path passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSequenceClassification.from_pretrained``.
        maxlen: ``max_length`` handed to the tokenizer (inputs are truncated).
        bsize: Number of pairs scored per forward pass.
    """

    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
        self.queries = queries
        self.collection = collection
        self.model = model

        self.maxlen = maxlen
        self.bsize = bsize

    def launch(self, qids, pids):
        """Score all pairs through ``Launcher`` and return one flat list.

        Each launched process receives the full ``qids``/``pids`` and picks
        its own shard in ``_score_pairs_process``; the per-process outputs
        are concatenated with ``flatten``.
        """
        launcher = Launcher(self._score_pairs_process, return_all=True)
        outputs = launcher.launch(Run().config, qids, pids)

        return flatten(outputs)

    def _score_pairs_process(self, config, qids, pids):
        """Score the contiguous shard of pairs that belongs to ``config.rank``.

        The shard size is ``1 + len(qids) // config.nranks``, so trailing
        ranks can receive a short or even empty slice; ``_score_pairs``
        returns ``[]`` for those.
        """
        assert len(qids) == len(pids), (len(qids), len(pids))

        share = 1 + len(qids) // config.nranks
        offset = config.rank * share
        endpos = (1 + config.rank) * share

        return self._score_pairs(
            qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1)
        )

    def _score_pairs(self, qids, pids, show_progress=False):
        """Return one score per (qid, pid) pair, in input order.

        Args:
            qids: Query ids; each is looked up in ``self.queries``.
            pids: Passage ids; each is looked up in ``self.collection``.
            show_progress: Show a tqdm progress bar over the batches.

        Returns:
            A list of Python floats. An empty input yields ``[]`` without
            loading the model.

        Raises:
            AssertionError: If ``qids`` and ``pids`` differ in length.
        """
        assert len(qids) == len(pids), (len(qids), len(pids))

        # An empty shard used to crash in torch.cat([]) after needlessly
        # loading the model onto the GPU; bail out before doing either.
        if len(qids) == 0:
            return []

        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()

        scores = []

        model.eval()
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                for offset in tqdm.tqdm(
                    range(0, len(qids), self.bsize), disable=(not show_progress)
                ):
                    endpos = offset + self.bsize

                    queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    passages_ = [self.collection[pid] for pid in pids[offset:endpos]]

                    features = tokenizer(
                        queries_,
                        passages_,
                        padding="longest",
                        truncation=True,
                        return_tensors="pt",
                        max_length=self.maxlen,
                    ).to(model.device)

                    # NOTE(review): flattening assumes the model emits a single
                    # logit per pair -- confirm for non-default models.
                    scores.append(model(**features).logits.flatten())

        scores = torch.cat(scores)
        scores = scores.tolist()

        Run().print(f"Returning with {len(scores)} scores")

        return scores
73+
74+
75+
# LONG-TERM TODO: This can be sped up by sorting by length in advance.

libs/colbert/colbert/evaluation/loaders.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ def load_collection(collection_path):
176176
print(f"{line_idx // 1000 // 1000}M", end=" ", flush=True)
177177

178178
pid, passage, *rest = line.strip("\n\r ").split("\t")
179-
# id could be either "id" (the first line), a number or have the format "docNUM"
180-
assert pid == "id" or int(pid if pid.isnumeric() else pid[3:]) == line_idx
179+
assert pid == "id" or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}"
181180

182181
if len(rest) >= 1:
183182
title = rest[0]

0 commit comments

Comments
 (0)