2
2
from source .embedding .bgeBase import BgeBaseEmbedding
3
3
4
4
5
- def test_didIter_1 ():
5
+ def test_didIter ():
6
6
"""
7
7
Test didIter method.
8
8
"""
@@ -12,7 +12,7 @@ def test_didIter_1():
12
12
assert all (isinstance (i , int ) for i in ids )
13
13
14
14
15
- def test_docIter_1 ():
15
+ def test_docIter ():
16
16
"""
17
17
Test docIter method.
18
18
"""
@@ -22,7 +22,7 @@ def test_docIter_1():
22
22
assert all (isinstance (d , str ) for d in docs )
23
23
24
24
25
- def test_docEmbIter_1 ():
25
+ def test_docEmbIter ():
26
26
"""
27
27
Test docEmbIter method.
28
28
"""
@@ -31,7 +31,7 @@ def test_docEmbIter_1():
31
31
assert embeddings .shape == (8 , BgeBaseEmbedding .size )
32
32
33
33
34
- def test_getDocLen_1 ():
34
+ def test_getDocLen ():
35
35
"""
36
36
Test getDocLen method.
37
37
"""
@@ -41,119 +41,77 @@ def test_getDocLen_1():
41
41
assert docLen == 8841823
42
42
43
43
44
- def test_qidIter_1 ():
44
+ def test_qidIter ():
45
45
"""
46
46
Test qidIter method.
47
47
"""
48
48
dataset = MsMarcoDataset ()
49
49
qids = next (dataset .qidIter ("Train" , 8 ))
50
50
assert isinstance (qids , list ) and len (qids ) == 8
51
51
assert all (isinstance (q , int ) for q in qids )
52
-
53
-
54
- def test_qidIter_2 ():
55
- """
56
- Test qidIter method.
57
- """
58
- dataset = MsMarcoDataset ()
59
52
qids = next (dataset .qidIter ("Validate" , 8 ))
60
53
assert isinstance (qids , list ) and len (qids ) == 8
61
54
assert all (isinstance (q , int ) for q in qids )
62
55
63
56
64
- def test_qryIter_1 ():
57
+ def test_qryIter ():
65
58
"""
66
59
Test qryIter method.
67
60
"""
68
61
dataset = MsMarcoDataset ()
69
62
qrys = next (dataset .qryIter ("Train" , 8 ))
70
63
assert isinstance (qrys , list ) and len (qrys ) == 8
71
64
assert all (isinstance (q , str ) for q in qrys )
72
-
73
-
74
- def test_qryIter_2 ():
75
- """
76
- Test qryIter method.
77
- """
78
- dataset = MsMarcoDataset ()
79
65
qrys = next (dataset .qryIter ("Validate" , 8 ))
80
66
assert isinstance (qrys , list ) and len (qrys ) == 8
81
67
assert all (isinstance (q , str ) for q in qrys )
82
68
83
69
84
- def test_qryEmbIter_1 ():
70
+ def test_qryEmbIter ():
85
71
"""
86
72
Test qryEmbIter method.
87
73
"""
88
74
dataset = MsMarcoDataset ()
89
75
embeddings = next (dataset .qryEmbIter (BgeBaseEmbedding , "Train" , 8 , 0 , False ))
90
76
assert embeddings .shape == (8 , BgeBaseEmbedding .size )
91
-
92
-
93
- def test_qryEmbIter_2 ():
94
- """
95
- Test qryEmbIter method.
96
- """
97
- dataset = MsMarcoDataset ()
98
77
embeddings = next (dataset .qryEmbIter (BgeBaseEmbedding , "Validate" , 8 , 0 , False ))
99
78
assert embeddings .shape == (8 , BgeBaseEmbedding .size )
100
79
101
80
102
- def test_getQryLen_1 ():
81
+ def test_getQryLen ():
103
82
"""
104
83
Test getQryLen method.
105
84
"""
106
85
dataset = MsMarcoDataset ()
107
86
qryLen = dataset .getQryLen ("Train" )
108
87
assert isinstance (qryLen , int )
109
88
assert qryLen == 808731
110
-
111
-
112
- def test_getQryLen_2 ():
113
- """
114
- Test getQryLen method.
115
- """
116
- dataset = MsMarcoDataset ()
117
89
qryLen = dataset .getQryLen ("Validate" )
118
90
assert isinstance (qryLen , int )
119
91
assert qryLen == 101093
120
92
121
93
122
- def test_mixEmbIter_1 ():
94
+ def test_mixEmbIter ():
123
95
"""
124
96
Test mixEmbIter method.
125
97
"""
126
98
dataset = MsMarcoDataset ()
127
99
qry , docs = next (dataset .mixEmbIter (BgeBaseEmbedding , "Train" , 32 , 8 , 0 , False ))
128
100
assert qry .shape == (8 , BgeBaseEmbedding .size )
129
101
assert docs .shape == (8 , 32 , BgeBaseEmbedding .size )
130
-
131
-
132
- def test_mixEmbIter_2 ():
133
- """
134
- Test mixEmbIter method.
135
- """
136
- dataset = MsMarcoDataset ()
137
102
qry , docs = next (dataset .mixEmbIter (BgeBaseEmbedding , "Validate" , 32 , 8 , 0 , False ))
138
103
assert qry .shape == (8 , BgeBaseEmbedding .size )
139
104
assert docs .shape == (8 , 32 , BgeBaseEmbedding .size )
140
105
141
106
142
- def test_getMixLen_1 ():
107
+ def test_getMixLen ():
143
108
"""
144
109
Test getMixLen method.
145
110
"""
146
111
dataset = MsMarcoDataset ()
147
112
mixLen = dataset .getMixLen ("Train" )
148
113
assert isinstance (mixLen , int )
149
114
assert mixLen == 808731
150
-
151
-
152
- def test_getMixLen_2 ():
153
- """
154
- Test getMixLen method.
155
- """
156
- dataset = MsMarcoDataset ()
157
115
mixLen = dataset .getMixLen ("Validate" )
158
116
assert isinstance (mixLen , int )
159
117
assert mixLen == 101093
0 commit comments