Skip to content

Commit 361be00

Browse files
committed
add retrieval for decoded and latent
1 parent 40f69b1 commit 361be00

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

script/acquire.sh

+12-4
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,17 @@
11
#!/usr/bin/bash
22

3-
# Acquire a compute node with 32 CPUs, 96GB RAM, and 1 A6000 GPU.
3+
# Validate the arguments
4+
if [ -z "$1" ]; then
5+
echo "Usage: $0 <num_gpu>"
6+
exit 1
7+
fi
8+
9+
NUM_GPU=$1
10+
11+
# Acquire a compute node with 32 CPUs, 96GB RAM, and the specified number of A6000 GPUs.
412
# The node will be acquired for 7 days, and the session will be interactive.
513
# Please run this script with tmux to avoid losing the session.
614
srun \
7-
--partition=long --time=07-00:00:00 \
8-
--cpus-per-task=32 --mem=96GB --gres=gpu:A6000:1 \
9-
--pty bash
15+
--partition=long --time=07-00:00:00 \
16+
--cpus-per-task=32 --mem=96GB --gres=gpu:A6000:$NUM_GPU \
17+
--pty bash

source/interpret/retrieval/reconstruct.py → source/interpret/retrieval/decoded.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def main(dataset: Dataset, embedding: Type[Embedding], version: str):
3333
)
3434

3535
# define where to save results
36-
saveBase = Path(workspace, "reconstruct")
36+
saveBase = Path(workspace, "decoded")
3737
saveBase.mkdir(mode=0o770, parents=True, exist_ok=True)
3838
qresFile = Path(saveBase, f"{version}.qres")
3939
evalFile = Path(saveBase, f"{version}.eval")
@@ -54,7 +54,7 @@ def main(dataset: Dataset, embedding: Type[Embedding], version: str):
5454
D, I = gpuIndex.search(bQrys, 100)
5555
for qid, sims, dnos in zip(bQids, D, I):
5656
for s, d in zip(sims, dnos):
57-
f.write(f"{qid}\tQ0\t{dids[d]}\t0\t{s}\tReconstruct\n")
57+
f.write(f"{qid}\tQ0\t{dids[d]}\t0\t{s}\tDecoded\n")
5858
f.flush()
5959
p.update(t, advance=bQrys.shape[0])
6060
p.remove_task(t)

source/interpret/retrieval/latent.py

+2-6
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
from source.dataset import MsMarcoDataset
1515
from source.embedding import BgeBaseEmbedding
1616
from source.interface import Dataset, Embedding
17-
from source.interpret import esUser, esAuth, esCert
1817

1918

2019
def main(
@@ -60,10 +59,7 @@ def main(
6059

6160
# create connection to Elasticsearch
6261
es = Elasticsearch(
63-
hosts=[{"host": esHost, "port": esPort}],
64-
http_auth=(esUser, esAuth),
65-
ca_certs=esCert,
66-
scheme="https",
62+
hosts=[{"host": esHost, "port": esPort, "scheme": "http"}],
6763
)
6864
es.indices.create(
6965
index=f"{version}.latent".lower(),
@@ -163,7 +159,7 @@ def main(
163159
parser.add_argument("version", type=str)
164160
parser.add_argument("--dataset", type=str, default="MsMarco", choices=["MsMarco"])
165161
parser.add_argument("--embedding", type=str, default="BgeBase", choices=["BgeBase"])
166-
parser.add_argument("--esHost", type=str, default="172.16.1.166")
162+
parser.add_argument("--esHost", type=str, default="localhost")
167163
parser.add_argument("--esPort", type=int, default=9200)
168164
args = parser.parse_args()
169165

0 commit comments

Comments
 (0)