Skip to content

Commit

Permalink
Add colbert as jupyter notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
mam10eks committed May 3, 2023
1 parent 7f9ba0c commit 1ba9853
Show file tree
Hide file tree
Showing 2 changed files with 370 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tira-ir-starters/pyterrier-colbert/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ RUN python3 -c "import pandas as pd; from tira.third_party_integrations import e

# Copy the re-ranking script to the image root.
COPY pyterrier-colbert/reranking.py /reranking.py

# Copy the BM25 + ColBERT notebook into /workspace/.
COPY pyterrier-colbert/bm25-colbert.ipynb /workspace/

# Sign every notebook in /workspace/ so that Jupyter treats it as trusted.
RUN jupyter trust /workspace/*.ipynb

366 changes: 366 additions & 0 deletions tira-ir-starters/pyterrier-colbert/bm25-colbert.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,366 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8c3da078-f7fc-4d37-904c-532bb26d4321",
"metadata": {},
"source": [
"# This is my cool Pipeline"
]
},
{
"cell_type": "markdown",
"id": "66fd2911-c97a-4f91-af28-8c7e381573b6",
"metadata": {},
"source": [
"### Step 1: Import everything and load variables"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7ae3c54f-aba1-45bf-b074-e78a99f6405f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I will use a small hardcoded example located in ./sample-input-full-rank.\n",
"The output directory is /tmp/\n"
]
}
],
"source": [
"import pyterrier as pt\n",
"import pandas as pd\n",
"from tira.third_party_integrations import ensure_pyterrier_is_loaded, get_input_directory_and_output_directory, persist_and_normalize_run\n",
"import json\n",
"from tqdm import tqdm\n",
"import os\n",
"\n",
"ensure_pyterrier_is_loaded()\n",
"input_directory, output_directory = get_input_directory_and_output_directory('./sample-input-full-rank')\n",
"from pyterrier_colbert.ranking import ColBERTFactory\n"
]
},
{
"cell_type": "markdown",
"id": "8c563b0e-97ac-44a2-ba2f-18858f1506bb",
"metadata": {},
"source": [
"### Step 2: Load the Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e35230af-66ec-4607-a97b-127bd890fa59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 2: Load the data.\n"
]
}
],
"source": [
"print('Step 2: Load the data.')\n",
"\n",
"queries = pt.io.read_topics(input_directory + '/queries.xml', format='trecxml')\n",
"\n",
"documents = [json.loads(i) for i in open(input_directory + '/documents.jsonl', 'r')]\n"
]
},
{
"cell_type": "markdown",
"id": "d2c3108e",
"metadata": {},
"source": [
"### Step 3: Create the Index"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d6141df7",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 3: Create the Index.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.51it/s]\n"
]
}
],
"source": [
"print('Step 3: Create the Index.')\n",
"\n",
"!rm -Rf ./index\n",
"iter_indexer = pt.IterDictIndexer(\"./index\", meta={'docno' : 100, 'text': 10240})\n",
"index_ref = iter_indexer.index(tqdm(documents))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bdf81496",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing ColBERT: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[May 03, 20:02:45] #> Loading model checkpoint.\n",
"[May 03, 20:02:45] #> Loading checkpoint http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/torch/hub.py:647: UserWarning: Falling back to the old format < 1.6. This support will be deprecated in favor of default zipfile format introduced in 1.6. Please redo torch.save() to save it in the new zipfile format.\n",
" warnings.warn('Falling back to the old format < 1.6. This support will be '\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[May 03, 20:03:00] #> checkpoint['epoch'] = 0\n",
"[May 03, 20:03:00] #> checkpoint['batch'] = 44500\n"
]
}
],
"source": [
"bm25 = pt.BatchRetrieve(index_ref, wmodel=\"BM25\", metadata=['docno', 'text'])\n",
"\n",
" \n",
"pytcolbert = ColBERTFactory(os.environ['MODEL_NAME'], \"/index\", \"index\")\n",
"\n",
"pipeline = bm25 % 1000 >> pytcolbert.text_scorer(verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "806c4638-ccee-4470-a74c-2a85d9ee2cfc",
"metadata": {},
"source": [
"### Step 4: Create Run"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a191f396-e896-4792-afaf-574e452640f5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 4: Create Run.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.99q/s]\n"
]
}
],
"source": [
"print('Step 4: Create Run.')\n",
"run = pipeline(queries)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c0e07fca-de98-4de2-b6a7-abfd516c652c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>qid</th>\n",
" <th>query</th>\n",
" <th>docno</th>\n",
" <th>score</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>fox jumps above animal</td>\n",
" <td>pangram-04</td>\n",
" <td>21.435955</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>fox jumps above animal</td>\n",
" <td>pangram-02</td>\n",
" <td>20.018982</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>fox jumps above animal</td>\n",
" <td>pangram-03</td>\n",
" <td>14.482786</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>fox jumps above animal</td>\n",
" <td>pangram-01</td>\n",
" <td>12.042614</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2</td>\n",
" <td>multiple animals including a zebra</td>\n",
" <td>pangram-03</td>\n",
" <td>20.021858</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2</td>\n",
" <td>multiple animals including a zebra</td>\n",
" <td>pangram-01</td>\n",
" <td>16.255941</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>2</td>\n",
" <td>multiple animals including a zebra</td>\n",
" <td>pangram-05</td>\n",
" <td>16.099251</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" qid query docno score rank\n",
"0 1 fox jumps above animal pangram-04 21.435955 0\n",
"1 1 fox jumps above animal pangram-02 20.018982 1\n",
"2 1 fox jumps above animal pangram-03 14.482786 2\n",
"3 1 fox jumps above animal pangram-01 12.042614 3\n",
"4 2 multiple animals including a zebra pangram-03 20.021858 0\n",
"5 2 multiple animals including a zebra pangram-01 16.255941 1\n",
"6 2 multiple animals including a zebra pangram-05 16.099251 2"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"run"
]
},
{
"cell_type": "markdown",
"id": "28c40a2e-0f96-4ae8-aa5e-55a5e7ef9dee",
"metadata": {},
"source": [
"### Step 5: Persist Run"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "12e5bb42-ed1f-41ba-b7a5-cb43ebca96f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 5: Persist Run.\n"
]
}
],
"source": [
"print('Step 5: Persist Run.')\n",
"\n",
"persist_and_normalize_run(run, output_file=output_directory, system_name='colbert', depth=1000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 1ba9853

Please sign in to comment.