Skip to content

Commit 7f1db3e

Browse files
committed
fix typing issues from dataset and embedding models
1 parent 69af5ce commit 7f1db3e

File tree

5 files changed

+26
-60
lines changed

5 files changed

+26
-60
lines changed

.vscode/settings.json

-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
},
1010
"[python]": {
1111
"editor.tabSize": 4,
12-
"editor.formatOnSave": true,
1312
"editor.defaultFormatter": "ms-python.black-formatter",
1413
},
1514
"[jsonc]": {

source/dataset/msMarco.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,12 @@ def docPrefixEmbIter(
100100
shuffle: bool,
101101
idxs: List[int],
102102
) -> Iterator[Tuple[Tensor, List[str], Tensor]]:
103+
"""
104+
@todo: fix the typing override.
105+
"""
103106
embed = embedding()
104107
idx = 0
105-
idxs = deque(sorted(idxs))
108+
idxs = deque(sorted(idxs)) # type: ignore
106109
done = False
107110
for p in range(4):
108111
path = Path(DocIterInit.base, f"partition-{p:08d}.parquet")
@@ -111,9 +114,12 @@ def docPrefixEmbIter(
111114
batches = file.iter_batches(1, columns=["text"])
112115
for i, part in enumerate(batches):
113116
if idx == idxs[0]:
114-
idxs.popleft()
117+
idxs.popleft() # type: ignore
115118
txt = part.column("text").to_pylist()
116-
vec, tokens, token_ids = embed.forward_prefix(txt)
119+
"""
120+
@todo: add forward_prefix to the Embedding interface.
121+
"""
122+
vec, tokens, token_ids = embed.forward_prefix(txt) # type: ignore
117123
yield vec, tokens, token_ids.detach().cpu().tolist()
118124
idx += 1
119125
if len(idxs) == 0:
@@ -396,7 +402,7 @@ def __init__(self) -> None:
396402
self.base.mkdir(mode=0o770, parents=True, exist_ok=True)
397403
asyncio.run(self.dispatch())
398404

399-
async def dispatch(self) -> None:
405+
async def dispatch(self):
400406
# we should have dispatched all tasks at once, but due to progress bar
401407
# constraints, only one at a time is possible. Otherwise, the progress
402408
# bar would be globally defined, and may interfere with training logs.

source/dataset/test_msMarco.py

+10 -52
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from source.embedding.bgeBase import BgeBaseEmbedding
33

44

5-
def test_didIter_1():
5+
def test_didIter():
66
"""
77
Test didIter method.
88
"""
@@ -12,7 +12,7 @@ def test_didIter_1():
1212
assert all(isinstance(i, int) for i in ids)
1313

1414

15-
def test_docIter_1():
15+
def test_docIter():
1616
"""
1717
Test docIter method.
1818
"""
@@ -22,7 +22,7 @@ def test_docIter_1():
2222
assert all(isinstance(d, str) for d in docs)
2323

2424

25-
def test_docEmbIter_1():
25+
def test_docEmbIter():
2626
"""
2727
Test docEmbIter method.
2828
"""
@@ -31,7 +31,7 @@ def test_docEmbIter_1():
3131
assert embeddings.shape == (8, BgeBaseEmbedding.size)
3232

3333

34-
def test_getDocLen_1():
34+
def test_getDocLen():
3535
"""
3636
Test getDocLen method.
3737
"""
@@ -41,119 +41,77 @@ def test_getDocLen_1():
4141
assert docLen == 8841823
4242

4343

44-
def test_qidIter_1():
44+
def test_qidIter():
4545
"""
4646
Test qidIter method.
4747
"""
4848
dataset = MsMarcoDataset()
4949
qids = next(dataset.qidIter("Train", 8))
5050
assert isinstance(qids, list) and len(qids) == 8
5151
assert all(isinstance(q, int) for q in qids)
52-
53-
54-
def test_qidIter_2():
55-
"""
56-
Test qidIter method.
57-
"""
58-
dataset = MsMarcoDataset()
5952
qids = next(dataset.qidIter("Validate", 8))
6053
assert isinstance(qids, list) and len(qids) == 8
6154
assert all(isinstance(q, int) for q in qids)
6255

6356

64-
def test_qryIter_1():
57+
def test_qryIter():
6558
"""
6659
Test qryIter method.
6760
"""
6861
dataset = MsMarcoDataset()
6962
qrys = next(dataset.qryIter("Train", 8))
7063
assert isinstance(qrys, list) and len(qrys) == 8
7164
assert all(isinstance(q, str) for q in qrys)
72-
73-
74-
def test_qryIter_2():
75-
"""
76-
Test qryIter method.
77-
"""
78-
dataset = MsMarcoDataset()
7965
qrys = next(dataset.qryIter("Validate", 8))
8066
assert isinstance(qrys, list) and len(qrys) == 8
8167
assert all(isinstance(q, str) for q in qrys)
8268

8369

84-
def test_qryEmbIter_1():
70+
def test_qryEmbIter():
8571
"""
8672
Test qryEmbIter method.
8773
"""
8874
dataset = MsMarcoDataset()
8975
embeddings = next(dataset.qryEmbIter(BgeBaseEmbedding, "Train", 8, 0, False))
9076
assert embeddings.shape == (8, BgeBaseEmbedding.size)
91-
92-
93-
def test_qryEmbIter_2():
94-
"""
95-
Test qryEmbIter method.
96-
"""
97-
dataset = MsMarcoDataset()
9877
embeddings = next(dataset.qryEmbIter(BgeBaseEmbedding, "Validate", 8, 0, False))
9978
assert embeddings.shape == (8, BgeBaseEmbedding.size)
10079

10180

102-
def test_getQryLen_1():
81+
def test_getQryLen():
10382
"""
10483
Test getQryLen method.
10584
"""
10685
dataset = MsMarcoDataset()
10786
qryLen = dataset.getQryLen("Train")
10887
assert isinstance(qryLen, int)
10988
assert qryLen == 808731
110-
111-
112-
def test_getQryLen_2():
113-
"""
114-
Test getQryLen method.
115-
"""
116-
dataset = MsMarcoDataset()
11789
qryLen = dataset.getQryLen("Validate")
11890
assert isinstance(qryLen, int)
11991
assert qryLen == 101093
12092

12193

122-
def test_mixEmbIter_1():
94+
def test_mixEmbIter():
12395
"""
12496
Test mixEmbIter method.
12597
"""
12698
dataset = MsMarcoDataset()
12799
qry, docs = next(dataset.mixEmbIter(BgeBaseEmbedding, "Train", 32, 8, 0, False))
128100
assert qry.shape == (8, BgeBaseEmbedding.size)
129101
assert docs.shape == (8, 32, BgeBaseEmbedding.size)
130-
131-
132-
def test_mixEmbIter_2():
133-
"""
134-
Test mixEmbIter method.
135-
"""
136-
dataset = MsMarcoDataset()
137102
qry, docs = next(dataset.mixEmbIter(BgeBaseEmbedding, "Validate", 32, 8, 0, False))
138103
assert qry.shape == (8, BgeBaseEmbedding.size)
139104
assert docs.shape == (8, 32, BgeBaseEmbedding.size)
140105

141106

142-
def test_getMixLen_1():
107+
def test_getMixLen():
143108
"""
144109
Test getMixLen method.
145110
"""
146111
dataset = MsMarcoDataset()
147112
mixLen = dataset.getMixLen("Train")
148113
assert isinstance(mixLen, int)
149114
assert mixLen == 808731
150-
151-
152-
def test_getMixLen_2():
153-
"""
154-
Test getMixLen method.
155-
"""
156-
dataset = MsMarcoDataset()
157115
mixLen = dataset.getMixLen("Validate")
158116
assert isinstance(mixLen, int)
159117
assert mixLen == 101093

source/embedding/bgeBase.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from torch import Tensor
33
import torch.nn as nn
44
import torch.nn.functional as F
5-
from typing import List
5+
from typing import List, Tuple, Any
66
from transformers.models.bert.modeling_bert import BertModel
77
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
88
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
@@ -38,7 +38,10 @@ def forward(self, passages: List[str]) -> Tensor:
3838
return F.normalize(hiddens[:, 0], p=2, dim=1)
3939

4040
@torch.inference_mode()
41-
def forward_prefix(self, passages: List[str]) -> Tensor:
41+
def forward_prefix(self, passages: List[str]) -> Tuple[Tensor, Any, Any]:
42+
"""
43+
@todo: fix the return type.
44+
"""
4245
kwargs = dict(padding=True, truncation=True, return_tensors="pt")
4346
encoded = self.tokenizer(passages[0], **kwargs)
4447
input_ids = encoded.input_ids[0] # Shape: [seq_len]

source/embedding/test_bgeBase.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from source.embedding.bgeBase import BgeBaseEmbedding
22

33

4-
def test_forward_1():
4+
def test_forward():
55
"""
66
Test forward method.
77
"""

0 commit comments

Comments
 (0)