1
1
from typing import Union , Callable , List , Optional
2
2
from tqdm import trange
3
3
4
- import numpy as np
5
4
import torch
6
5
from functools import partial
7
6
@@ -20,8 +19,8 @@ def greedy(X: TensorLikeArray,
20
19
covmat_precomputed : bool = False ,
21
20
T : Optional [Union [int , List [int ]]]= None ,
22
21
repeat : int = 10 ,
23
- device : torch .device = torch .device ('cpu' ),
24
22
batch_size : int = 1000000 ,
23
+ device : torch .device = torch .device ('cpu' ),
25
24
metric : Union [str ,Callable ]= 'o' ,
26
25
largest : bool = False ):
27
26
@@ -74,6 +73,7 @@ def greedy(X: TensorLikeArray,
74
73
best_candidate , best_score = _next_order_greedy (covmats , T , current_solution ,
75
74
metric = metric ,
76
75
largest = largest ,
76
+ batch_size = batch_size ,
77
77
device = device )
78
78
best_scores .append (best_score )
79
79
@@ -115,6 +115,7 @@ def _next_order_greedy(covmats: torch.Tensor,
115
115
initial_solution : torch .Tensor ,
116
116
metric : Union [str ,Callable ],
117
117
largest : bool ,
118
+ batch_size : int = 1000000 ,
118
119
device : torch .device = torch .device ('cpu' )):
119
120
120
121
'''
@@ -126,6 +127,7 @@ def _next_order_greedy(covmats: torch.Tensor,
126
127
- initial_solution (torch.Tensor): The initial solution with shape (batch_size, order)
127
128
- metric (Union[str,Callable]): The metric to evaluate. One of tc, dtc, o, s or a callable function
128
129
- largest (bool): A flag to indicate if the metric is to be maximized or minimized
130
+ - batch_size (int): The batch size to use for the computation. Default is 1000000.
129
131
- device (torch.device): The device to use for the computation. Default is 'cpu'
130
132
131
133
Returns:
@@ -135,36 +137,51 @@ def _next_order_greedy(covmats: torch.Tensor,
135
137
136
138
# Get parameters attributes
137
139
N = covmats .shape [1 ]
138
- batch_size , order = initial_solution .shape
140
+ total_size , order = initial_solution .shape
139
141
140
142
# Initial valid candidates to iterate one by one
141
- # |batch_size | x |N-order|
143
+ # |total_size | x |N-order|
142
144
valid_candidates = _get_valid_candidates (initial_solution , N , device )
143
-
144
- # |batch_size| x |N-order| x |order+1|
145
- all_solutions = _create_all_solutions (initial_solution , valid_candidates )
146
-
147
- # |batch_size x N-order| x |order+1|
148
- all_solutions = all_solutions .view (batch_size * (N - order ), order + 1 )
149
-
150
- # |batch_size x N-order|
151
- best_score = _evaluate_nplets (covmats , T , all_solutions , metric , device = device )
152
-
153
- # |batch_size| x |N-order|
154
- best_score = best_score .view (batch_size , N - order )
155
-
156
- if not largest :
157
- best_score = - best_score
158
-
159
- # get for each batch item the best score over the second dimention
160
-
161
- # |batch_size|
162
- max_idxs = torch .argmax (best_score , dim = 1 )
163
- best_candidates = valid_candidates [torch .arange (batch_size ), max_idxs ]
164
- best_score = best_score [torch .arange (batch_size ), max_idxs ]
165
-
166
- # If minimizing, then return score to its original sign
167
- if not largest :
168
- best_score = - best_score
169
145
170
- return best_candidates , best_score
146
+ best_candidates = []
147
+ best_scores = []
148
+
149
+ for start in range (0 , total_size , batch_size ):
150
+ end = min (start + batch_size , total_size )
151
+ batch_initial_solution = initial_solution [start :end ]
152
+ batch_valid_candidates = valid_candidates [start :end ]
153
+
154
+ # |batch_size| x |N-order| x |order+1|
155
+ all_solutions = _create_all_solutions (batch_initial_solution , batch_valid_candidates )
156
+
157
+ # |batch_size x N-order| x |order+1|
158
+ all_solutions = all_solutions .view (- 1 , order + 1 )
159
+
160
+ # |batch_size x N-order|
161
+ batch_best_score = _evaluate_nplets (covmats , T ,
162
+ all_solutions ,
163
+ metric ,
164
+ batch_size = batch_size ,
165
+ device = device )
166
+
167
+ # |batch_size| x |N-order|
168
+ batch_best_score = batch_best_score .view (end - start , N - order )
169
+
170
+ if not largest :
171
+ batch_best_score = - batch_best_score
172
+
173
+ # get for each batch item the best score over the second dimension
174
+
175
+ # |batch_size|
176
+ max_idxs = torch .argmax (batch_best_score , dim = 1 )
177
+ batch_best_candidates = batch_valid_candidates [torch .arange (end - start ), max_idxs ]
178
+ batch_best_score = batch_best_score [torch .arange (end - start ), max_idxs ]
179
+
180
+ # If minimizing, then return score to its original sign
181
+ if not largest :
182
+ batch_best_score = - batch_best_score
183
+
184
+ best_candidates .append (batch_best_candidates )
185
+ best_scores .append (batch_best_score )
186
+
187
+ return torch .cat (best_candidates ), torch .cat (best_scores )
0 commit comments