-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
370 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,366 @@ | ||
{ | ||
"cells": [ | ||
{ | |
"cell_type": "markdown", | |
"id": "8c3da078-f7fc-4d37-904c-532bb26d4321", | |
"metadata": {}, | |
"source": [ | |
"# BM25 + ColBERT Re-Ranking Pipeline (PyTerrier, TIRA)" | |
] | |
}, | |
{ | ||
"cell_type": "markdown", | ||
"id": "66fd2911-c97a-4f91-af28-8c7e381573b6", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 1: Import everything and load variables" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "7ae3c54f-aba1-45bf-b074-e78a99f6405f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"I will use a small hardcoded example located in ./sample-input-full-rank.\n", | ||
"The output directory is /tmp/\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import pyterrier as pt\n", | ||
"import pandas as pd\n", | ||
"from tira.third_party_integrations import ensure_pyterrier_is_loaded, get_input_directory_and_output_directory, persist_and_normalize_run\n", | ||
"import json\n", | ||
"from tqdm import tqdm\n", | ||
"import os\n", | ||
"\n", | ||
"ensure_pyterrier_is_loaded()\n", | ||
"input_directory, output_directory = get_input_directory_and_output_directory('./sample-input-full-rank')\n", | ||
"from pyterrier_colbert.ranking import ColBERTFactory\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8c563b0e-97ac-44a2-ba2f-18858f1506bb", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 2: Load the Data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "e35230af-66ec-4607-a97b-127bd890fa59", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Step 2: Load the data.\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"print('Step 2: Load the data.')\n", | |
"\n", | |
"queries = pt.io.read_topics(input_directory + '/queries.xml', format='trecxml')\n", | |
"\n", | |
"# Parse one JSON document per line. The context manager closes the file\n", | |
"# handle as soon as parsing is done; blank lines (e.g. a trailing newline\n", | |
"# at the end of the file) are skipped instead of crashing json.loads.\n", | |
"with open(input_directory + '/documents.jsonl', 'r') as documents_file:\n", | |
"    documents = [json.loads(line) for line in documents_file if line.strip()]\n" | |
] | |
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d2c3108e", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 3: Create the Index" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "d6141df7", | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Step 3: Create the Index.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.51it/s]\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"print('Step 3: Create the Index.')\n", | |
"\n", | |
"# Remove any ./index directory left over from a previous run.\n", | |
"!rm -Rf ./index\n", | |
"\n", | |
"# Metadata fields to store with each document, mapped to the length\n", | |
"# reserved for each field (see the PyTerrier indexing documentation).\n", | |
"meta_field_lengths = {'docno': 100, 'text': 10240}\n", | |
"iter_indexer = pt.IterDictIndexer(\"./index\", meta=meta_field_lengths)\n", | |
"index_ref = iter_indexer.index(tqdm(documents))\n" | |
] | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "bdf81496", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing ColBERT: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']\n", | ||
"- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | ||
"- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | ||
"Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']\n", | ||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[May 03, 20:02:45] #> Loading model checkpoint.\n", | ||
"[May 03, 20:02:45] #> Loading checkpoint http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/lib/python3.7/site-packages/torch/hub.py:647: UserWarning: Falling back to the old format < 1.6. This support will be deprecated in favor of default zipfile format introduced in 1.6. Please redo torch.save() to save it in the new zipfile format.\n", | ||
" warnings.warn('Falling back to the old format < 1.6. This support will be '\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[May 03, 20:03:00] #> checkpoint['epoch'] = 0\n", | ||
"[May 03, 20:03:00] #> checkpoint['batch'] = 44500\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"# First stage: BM25 retrieval over index_ref, returning the 'docno' and\n", | |
"# 'text' metadata with every result.\n", | |
"bm25 = pt.BatchRetrieve(index_ref, wmodel=\"BM25\", metadata=['docno', 'text'])\n", | |
"\n", | |
"# The ColBERT model is selected through the MODEL_NAME environment variable.\n", | |
"pytcolbert = ColBERTFactory(os.environ['MODEL_NAME'], \"/index\", \"index\")\n", | |
"\n", | |
"# Cut the BM25 ranking off at 1000 results, then score those with ColBERT.\n", | |
"bm25_top_1000 = bm25 % 1000\n", | |
"colbert_scorer = pytcolbert.text_scorer(verbose=True)\n", | |
"pipeline = bm25_top_1000 >> colbert_scorer" | |
] | |
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "806c4638-ccee-4470-a74c-2a85d9ee2cfc", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 4: Create Run" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "a191f396-e896-4792-afaf-574e452640f5", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Step 4: Create Run.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 2.99q/s]\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"print('Step 4: Create Run.')\n", | |
"\n", | |
"# Apply the retrieval pipeline to the loaded queries.\n", | |
"run = pipeline.transform(queries)" | |
] | |
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "c0e07fca-de98-4de2-b6a7-abfd516c652c", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"<div>\n", | ||
"<style scoped>\n", | ||
" .dataframe tbody tr th:only-of-type {\n", | ||
" vertical-align: middle;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe tbody tr th {\n", | ||
" vertical-align: top;\n", | ||
" }\n", | ||
"\n", | ||
" .dataframe thead th {\n", | ||
" text-align: right;\n", | ||
" }\n", | ||
"</style>\n", | ||
"<table border=\"1\" class=\"dataframe\">\n", | ||
" <thead>\n", | ||
" <tr style=\"text-align: right;\">\n", | ||
" <th></th>\n", | ||
" <th>qid</th>\n", | ||
" <th>query</th>\n", | ||
" <th>docno</th>\n", | ||
" <th>score</th>\n", | ||
" <th>rank</th>\n", | ||
" </tr>\n", | ||
" </thead>\n", | ||
" <tbody>\n", | ||
" <tr>\n", | ||
" <th>0</th>\n", | ||
" <td>1</td>\n", | ||
" <td>fox jumps above animal</td>\n", | ||
" <td>pangram-04</td>\n", | ||
" <td>21.435955</td>\n", | ||
" <td>0</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>1</th>\n", | ||
" <td>1</td>\n", | ||
" <td>fox jumps above animal</td>\n", | ||
" <td>pangram-02</td>\n", | ||
" <td>20.018982</td>\n", | ||
" <td>1</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>2</th>\n", | ||
" <td>1</td>\n", | ||
" <td>fox jumps above animal</td>\n", | ||
" <td>pangram-03</td>\n", | ||
" <td>14.482786</td>\n", | ||
" <td>2</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>3</th>\n", | ||
" <td>1</td>\n", | ||
" <td>fox jumps above animal</td>\n", | ||
" <td>pangram-01</td>\n", | ||
" <td>12.042614</td>\n", | ||
" <td>3</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>4</th>\n", | ||
" <td>2</td>\n", | ||
" <td>multiple animals including a zebra</td>\n", | ||
" <td>pangram-03</td>\n", | ||
" <td>20.021858</td>\n", | ||
" <td>0</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>5</th>\n", | ||
" <td>2</td>\n", | ||
" <td>multiple animals including a zebra</td>\n", | ||
" <td>pangram-01</td>\n", | ||
" <td>16.255941</td>\n", | ||
" <td>1</td>\n", | ||
" </tr>\n", | ||
" <tr>\n", | ||
" <th>6</th>\n", | ||
" <td>2</td>\n", | ||
" <td>multiple animals including a zebra</td>\n", | ||
" <td>pangram-05</td>\n", | ||
" <td>16.099251</td>\n", | ||
" <td>2</td>\n", | ||
" </tr>\n", | ||
" </tbody>\n", | ||
"</table>\n", | ||
"</div>" | ||
], | ||
"text/plain": [ | ||
" qid query docno score rank\n", | ||
"0 1 fox jumps above animal pangram-04 21.435955 0\n", | ||
"1 1 fox jumps above animal pangram-02 20.018982 1\n", | ||
"2 1 fox jumps above animal pangram-03 14.482786 2\n", | ||
"3 1 fox jumps above animal pangram-01 12.042614 3\n", | ||
"4 2 multiple animals including a zebra pangram-03 20.021858 0\n", | ||
"5 2 multiple animals including a zebra pangram-01 16.255941 1\n", | ||
"6 2 multiple animals including a zebra pangram-05 16.099251 2" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | |
"source": [ | |
"# Preview only the first rows so a full-size run does not flood the output.\n", | |
"run.head(10)" | |
] | |
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "28c40a2e-0f96-4ae8-aa5e-55a5e7ef9dee", | ||
"metadata": {}, | ||
"source": [ | ||
"### Step 5: Persist Run" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "12e5bb42-ed1f-41ba-b7a5-cb43ebca96f6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Step 5: Persist Run.\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"print('Step 5: Persist Run.')\n", | |
"\n", | |
"# Hand the run to TIRA's helper, targeting output_directory, with the\n", | |
"# system name 'colbert' and a depth of 1000.\n", | |
"persist_and_normalize_run(\n", | |
"    run,\n", | |
"    system_name='colbert',\n", | |
"    output_file=output_directory,\n", | |
"    depth=1000,\n", | |
")" | |
] | |
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |