Skip to content

Commit 10243c2

Browse files
committed
multiple bug fixes and added repeat_batch_size to greedy
1 parent 98e2268 commit 10243c2

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

thoi/heuristics/greedy.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def greedy(X: TensorLikeArray,
2020
T: Optional[Union[int, List[int]]]=None,
2121
repeat: int=10,
2222
batch_size: int=1000000,
23+
repeat_batch_size: int=1000000,
2324
device: torch.device=torch.device('cpu'),
2425
metric: Union[str,Callable]='o',
2526
largest: bool=False):
@@ -50,7 +51,7 @@ def greedy(X: TensorLikeArray,
5051
batch_data_collector = partial(batch_to_tensor, top_k=repeat, metric=metric, largest=largest)
5152
batch_aggregation = partial(concat_batched_tensors, top_k=repeat, metric=None, largest=largest)
5253

53-
# |repeat| x |initial_order|
54+
# |repeat| x |initial_order|, |repeat|
5455
_, current_solution, current_scores = multi_order_measures(covmats,
5556
covmat_precomputed=True,
5657
T=T,
@@ -60,7 +61,7 @@ def greedy(X: TensorLikeArray,
6061
device=device,
6162
batch_data_collector=batch_data_collector,
6263
batch_aggregation=batch_aggregation)
63-
64+
6465
# send current solution to the device
6566
current_solution = current_solution.to(device).contiguous()
6667

@@ -70,15 +71,21 @@ def greedy(X: TensorLikeArray,
7071
# Iterate over the remaining orders to get the best solution for each order
7172
best_scores = [current_scores]
7273
for _ in trange(initial_order, order, leave=False, desc='Order'):
74+
75+
# |repeat|, |repeat|
7376
best_candidate, best_score = _next_order_greedy(covmats, T, current_solution,
7477
metric=metric,
7578
largest=largest,
7679
batch_size=batch_size,
80+
repeat_batch_size=repeat_batch_size,
7781
device=device)
82+
# |order - initial_order| x |repeat|
7883
best_scores.append(best_score)
79-
84+
85+
# |repeat| x |order|
8086
current_solution = torch.cat((current_solution, best_candidate.unsqueeze(1)) , dim=1)
8187

88+
# |repeat| x |order|, |repeat| x |order - initial_order|
8289
return current_solution, torch.stack(best_scores).T
8390

8491

@@ -116,6 +123,7 @@ def _next_order_greedy(covmats: torch.Tensor,
116123
metric: Union[str,Callable],
117124
largest: bool,
118125
batch_size: int=1000000,
126+
repeat_batch_size: int=1000000,
119127
device: torch.device=torch.device('cpu')):
120128

121129
'''
@@ -146,33 +154,33 @@ def _next_order_greedy(covmats: torch.Tensor,
146154
best_candidates = []
147155
best_scores = []
148156

149-
for start in range(0, total_size, batch_size):
150-
end = min(start + batch_size, total_size)
157+
for start in trange(0, total_size, repeat_batch_size, desc='Batch repeat', leave=False):
158+
end = min(start + repeat_batch_size, total_size)
151159
batch_initial_solution = initial_solution[start:end]
152160
batch_valid_candidates = valid_candidates[start:end]
153161

154-
# |batch_size| x |N-order| x |order+1|
162+
# |repeat_batch_size| x |N-order| x |order+1|
155163
all_solutions = _create_all_solutions(batch_initial_solution, batch_valid_candidates)
156164

157-
# |batch_size x N-order| x |order+1|
165+
# |repeat_batch_size x N-order| x |order+1|
158166
all_solutions = all_solutions.view(-1, order+1)
159167

160-
# |batch_size x N-order|
168+
# |repeat_batch_size x N-order|
161169
batch_best_score = _evaluate_nplets(covmats, T,
162170
all_solutions,
163171
metric,
164172
batch_size=batch_size,
165173
device=device)
166174

167-
# |batch_size| x |N-order|
175+
# |repeat_batch_size| x |N-order|
168176
batch_best_score = batch_best_score.view(end - start, N - order)
169177

170178
if not largest:
171179
batch_best_score = -batch_best_score
172180

173181
# get for each batch item the best score over the second dimension
174182

175-
# |batch_size|
183+
# |repeat_batch_size|
176184
max_idxs = torch.argmax(batch_best_score, dim=1)
177185
batch_best_candidates = batch_valid_candidates[torch.arange(end - start), max_idxs]
178186
batch_best_score = batch_best_score[torch.arange(end - start), max_idxs]

thoi/heuristics/simulated_annealing.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
4444
covmats, D, N, T = _normalize_input_data(X, covmat_precomputed, T, device)
4545

4646
# Compute current solution
47-
# |batch_size| x |order|
47+
# |repeat| x |order|
4848
if initial_solution is None:
4949
current_solution = random_sampler(N, order, repeat, device)
5050
else:
5151
current_solution = initial_solution.to(device).contiguous()
5252

53-
# |batch_size|
53+
# |repeat|
5454
current_energy = _evaluate_nplets(covmats, T,
5555
current_solution,
5656
metric,
@@ -61,21 +61,21 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
6161
current_energy = -current_energy
6262

6363
# Initial valid candidates
64-
# |batch_size| x |N-order|
64+
# |repeat| x |N-order|
6565
valid_candidates = _get_valid_candidates(current_solution, N, device)
6666

6767
# Set initial temperature
6868
temp = initial_temp
6969

7070
# Best solution found
71-
# |batch_size| x |order|
71+
# |repeat| x |order|
7272
best_solution = current_solution.clone()
73-
# |batch_size|
73+
# |repeat|
7474
best_energy = current_energy.clone()
7575

7676
# Repeat tensor for indexing the current_solution
7777
# |repeat| x |1|
78-
i_repeat = torch.arange(repeat)
78+
i_repeat = torch.arange(repeat, device=device)
7979

8080
no_progress_count = 0
8181
pbar = trange(max_iterations, leave=False)
@@ -88,7 +88,7 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
8888

8989
# Generate new solution by modifying the current solution.
9090
# Generate the random indexes to change.
91-
# |batch_size| x |order|
91+
# |repeat| x |order|
9292
i_sol = torch.randint(0, current_solution.shape[1], (repeat,), device=device)
9393
i_cand = torch.randint(0, valid_candidates.shape[1], (repeat,), device=device)
9494

@@ -99,7 +99,7 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
9999
current_solution[i_repeat, i_sol] = new_candidates
100100

101101
# Calculate energy of new solution
102-
# |batch_size|
102+
# |repeat|
103103
new_energy = _evaluate_nplets(covmats, T,
104104
current_solution,
105105
metric,
@@ -111,33 +111,33 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
111111

112112
# Calculate change in energy
113113
# delta_energy > 0 means new_energy is bigger (more optimal) than current_energy
114-
# |batch_size|
114+
# |repeat|
115115
delta_energy = new_energy - current_energy
116116

117117
# Determine if we should accept the new solution
118-
# |batch_size|
118+
# |repeat|
119119
temp_probas = torch.rand(repeat, device=device) < torch.exp(delta_energy / temp)
120120
improves = delta_energy > 0
121121
accept_new_solution = torch.logical_or(improves, temp_probas)
122122

123123
# Restore original values for rejected candidates
124-
# |batch_size| x |order|
124+
# |repeat| x |order|
125125
current_solution[i_repeat[~accept_new_solution], i_sol[~accept_new_solution]] = current_candidates[~accept_new_solution]
126126

127127
# Update valid_candidates for the accepted answers as they are no longer valid candidates
128-
# |batch_size| x |N-order|
128+
# |repeat| x |N-order|
129129
valid_candidates[i_repeat[accept_new_solution], i_cand[accept_new_solution]] = current_candidates[accept_new_solution]
130130

131131
# Update current energy for the accepted solutions
132-
# |batch_size|
132+
# |repeat|
133133
current_energy[accept_new_solution] = new_energy[accept_new_solution]
134134

135135
new_global_maxima = (new_energy > best_energy)
136136

137-
# |batch_size| x |order|
137+
# |repeat| x |order|
138138
best_solution[new_global_maxima] = current_solution[new_global_maxima]
139139

140-
# |batch_size|
140+
# |repeat|
141141
best_energy[new_global_maxima] = new_energy[new_global_maxima]
142142

143143
# Cool down

thoi/measures/gaussian_copula.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def multi_order_measures(X: TensorLikeArray,
502502
assert min_order <= max_order, f"min_order must be lower or equal than max_order. {min_order} > {max_order}"
503503

504504
# Ensure that final batch_size is smaller than the original batch_size
505-
batch_size = batch_size // D
505+
batch_size = max(batch_size // D, 1)
506506

507507
# To compute using pytorch, we need to compute each order separately
508508
batched_data = []

0 commit comments

Comments
 (0)