@@ -20,6 +20,7 @@ def greedy(X: TensorLikeArray,
20
20
T : Optional [Union [int , List [int ]]]= None ,
21
21
repeat : int = 10 ,
22
22
batch_size : int = 1000000 ,
23
+ repeat_batch_size : int = 1000000 ,
23
24
device : torch .device = torch .device ('cpu' ),
24
25
metric : Union [str ,Callable ]= 'o' ,
25
26
largest : bool = False ):
@@ -50,7 +51,7 @@ def greedy(X: TensorLikeArray,
50
51
batch_data_collector = partial (batch_to_tensor , top_k = repeat , metric = metric , largest = largest )
51
52
batch_aggregation = partial (concat_batched_tensors , top_k = repeat , metric = None , largest = largest )
52
53
53
- # |repeat| x |initial_order|
54
+ # |repeat| x |initial_order|, |repeat|
54
55
_ , current_solution , current_scores = multi_order_measures (covmats ,
55
56
covmat_precomputed = True ,
56
57
T = T ,
@@ -60,7 +61,7 @@ def greedy(X: TensorLikeArray,
60
61
device = device ,
61
62
batch_data_collector = batch_data_collector ,
62
63
batch_aggregation = batch_aggregation )
63
-
64
+
64
65
# send current solution to the device
65
66
current_solution = current_solution .to (device ).contiguous ()
66
67
@@ -70,15 +71,21 @@ def greedy(X: TensorLikeArray,
70
71
# Iterate over the remaining orders to get the best solution for each order
71
72
best_scores = [current_scores ]
72
73
for _ in trange (initial_order , order , leave = False , desc = 'Order' ):
74
+
75
+ # |repeat|, |repeat|
73
76
best_candidate , best_score = _next_order_greedy (covmats , T , current_solution ,
74
77
metric = metric ,
75
78
largest = largest ,
76
79
batch_size = batch_size ,
80
+ repeat_batch_size = repeat_batch_size ,
77
81
device = device )
82
+ # list of |order - initial_order + 1| tensors of size |repeat| once the loop ends (initial scores included)
78
83
best_scores .append (best_score )
79
-
84
+
85
+ # |repeat| x |order|
80
86
current_solution = torch .cat ((current_solution , best_candidate .unsqueeze (1 )) , dim = 1 )
81
87
88
+ # |repeat| x |order|, |repeat| x |order - initial_order + 1|
82
89
return current_solution , torch .stack (best_scores ).T
83
90
84
91
@@ -116,6 +123,7 @@ def _next_order_greedy(covmats: torch.Tensor,
116
123
metric : Union [str ,Callable ],
117
124
largest : bool ,
118
125
batch_size : int = 1000000 ,
126
+ repeat_batch_size : int = 1000000 ,
119
127
device : torch .device = torch .device ('cpu' )):
120
128
121
129
'''
@@ -146,33 +154,33 @@ def _next_order_greedy(covmats: torch.Tensor,
146
154
best_candidates = []
147
155
best_scores = []
148
156
149
- for start in range (0 , total_size , batch_size ):
150
- end = min (start + batch_size , total_size )
157
+ for start in trange (0 , total_size , repeat_batch_size , desc = 'Batch repeat' , leave = False ):
158
+ end = min (start + repeat_batch_size , total_size )
151
159
batch_initial_solution = initial_solution [start :end ]
152
160
batch_valid_candidates = valid_candidates [start :end ]
153
161
154
- # |batch_size | x |N-order| x |order+1|
162
+ # |repeat_batch_size | x |N-order| x |order+1|
155
163
all_solutions = _create_all_solutions (batch_initial_solution , batch_valid_candidates )
156
164
157
- # |batch_size x N-order| x |order+1|
165
+ # |repeat_batch_size x N-order| x |order+1|
158
166
all_solutions = all_solutions .view (- 1 , order + 1 )
159
167
160
- # |batch_size x N-order|
168
+ # |repeat_batch_size x N-order|
161
169
batch_best_score = _evaluate_nplets (covmats , T ,
162
170
all_solutions ,
163
171
metric ,
164
172
batch_size = batch_size ,
165
173
device = device )
166
174
167
- # |batch_size | x |N-order|
175
+ # |repeat_batch_size | x |N-order|
168
176
batch_best_score = batch_best_score .view (end - start , N - order )
169
177
170
178
if not largest :
171
179
batch_best_score = - batch_best_score
172
180
173
181
# get for each batch item the best score over the second dimension
174
182
175
- # |batch_size |
183
+ # |repeat_batch_size |
176
184
max_idxs = torch .argmax (batch_best_score , dim = 1 )
177
185
batch_best_candidates = batch_valid_candidates [torch .arange (end - start ), max_idxs ]
178
186
batch_best_score = batch_best_score [torch .arange (end - start ), max_idxs ]
0 commit comments