Commit 16ad17c

implemented batch_size on nplet_measures/_hot_encoded (#18)
* implemented batch_size on nplet_measures/_hot_encoded
* implemented batch_size in greedy
* changed batch_size parameter order
* added batch_size to simulated annealing
* implemented batch size in simulated annealing
1 parent ca4e180 commit 16ad17c

7 files changed: +182 −69 lines
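For context, this is how the new parameter is meant to be used from the caller's side. A minimal sketch, assuming nplets_measures is importable from thoi.measures.gaussian_copula as laid out in this commit; the random data and tolerance are illustrative, not from the repo:

import torch
from itertools import combinations
from thoi.measures.gaussian_copula import nplets_measures

# illustrative data: T=200 samples of N=8 variables
X = torch.randn(200, 8)
# all C(8,3) = 56 triplets
nplets = torch.tensor(list(combinations(range(8), 3)))

res_single = nplets_measures(X, nplets)                  # one batch (default batch_size=1000000)
res_chunked = nplets_measures(X, nplets, batch_size=16)  # same values, lower peak memory
assert torch.allclose(res_single, res_chunked, atol=1e-12)

batch_size only bounds how many n-plets are materialized at once; the returned tensor of (tc, dtc, o, s) values is identical, which is exactly what the new tests below assert.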

tests/test_nplet_measures.py

+46
@@ -93,6 +93,28 @@ def test_multiple_times_same_datasets_precomputed(self):
         nplets = torch.tensor(list(combinations(full_nplet, order)))
         res = nplets_measures([self.covmat, self.covmat], nplets, covmat_precomputed=True, T=self.X.shape[0])
         self._validate_same_results_for_repeated_datasets(res, nplets, rtol=1e-16, atol=1e-7)
+
+    def test_batch_size_does_not_change_result(self):
+        full_nplet = range(self.X.shape[1])
+
+        nplets = torch.tensor([list(c) for i, c in enumerate(combinations(full_nplet, 3)) if i < 100000])
+
+        # test for different batch sizes
+        res = nplets_measures(self.X, nplets)
+        res2 = nplets_measures(self.X, nplets, batch_size=10)
+        res3 = nplets_measures(self.X, nplets, batch_size=100)
+        res4 = nplets_measures(self.X, nplets, batch_size=1000)
+        res5 = nplets_measures(self.X, nplets, batch_size=10000)
+        res6 = nplets_measures(self.X, nplets, batch_size=100000)
+        res7 = nplets_measures(self.X, nplets, batch_size=1000000) # this should do a single batch
+
+        # check that the results are the same
+        self.assertTrue(torch.allclose(res, res2, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res3, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res4, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res5, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res6, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res7, rtol=1e-16, atol=1e-12))
 
     def test_nplets_measures_timeseries_hot_encoded(self):
         N = self.X.shape[1]
@@ -142,6 +164,30 @@ def test_multiple_times_same_dataset_precomputed_hot_encoded(self):
         res = nplets_measures_hot_encoded([self.covmat, self.covmat], nplets_hot_encoded, covmat_precomputed=True, T=self.X.shape[0])
         self._validate_same_results_for_repeated_datasets(res, nplets, rtol=1e-8, atol=1e-4)
 
+    def test_batch_size_does_not_change_result_hot_encoded(self):
+        full_nplet = range(self.X.shape[1])
+
+        nplets = torch.tensor([list(c) for i, c in enumerate(combinations(full_nplet, 3)) if i < 100000])
+        nplets_hot_encoded = torch.zeros((nplets.shape[0], self.X.shape[1]), dtype=torch.int)
+        nplets_hot_encoded[torch.arange(0,nplets.shape[0], dtype=int).view(-1,1), nplets] = 1
+
+        # test for different batch sizes
+        res = nplets_measures_hot_encoded(self.X, nplets_hot_encoded)
+        res2 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=10)
+        res3 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=100)
+        res4 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=1000)
+        res5 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=10000)
+        res6 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=100000)
+        res7 = nplets_measures_hot_encoded(self.X, nplets_hot_encoded, batch_size=1000000) # this should do a single batch
+
+        # check that the results are the same
+        self.assertTrue(torch.allclose(res, res2, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res3, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res4, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res5, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res6, rtol=1e-16, atol=1e-12))
+        self.assertTrue(torch.allclose(res, res7, rtol=1e-16, atol=1e-12))
 
+
 if __name__ == '__main__':
     unittest.main()

thoi/heuristics/greedy.py

+48 −31
@@ -1,7 +1,6 @@
 from typing import Union, Callable, List, Optional
 from tqdm import trange
 
-import numpy as np
 import torch
 from functools import partial
 

@@ -20,8 +19,8 @@ def greedy(X: TensorLikeArray,
            covmat_precomputed: bool=False,
            T: Optional[Union[int, List[int]]]=None,
            repeat: int=10,
-           device: torch.device=torch.device('cpu'),
            batch_size: int=1000000,
+           device: torch.device=torch.device('cpu'),
            metric: Union[str,Callable]='o',
            largest: bool=False):
 
@@ -74,6 +73,7 @@ def greedy(X: TensorLikeArray,
         best_candidate, best_score = _next_order_greedy(covmats, T, current_solution,
                                                         metric=metric,
                                                         largest=largest,
+                                                        batch_size=batch_size,
                                                         device=device)
         best_scores.append(best_score)
 
@@ -115,6 +115,7 @@ def _next_order_greedy(covmats: torch.Tensor,
                        initial_solution: torch.Tensor,
                        metric: Union[str,Callable],
                        largest: bool,
+                       batch_size: int=1000000,
                        device: torch.device=torch.device('cpu')):
 
     '''
@@ -126,6 +127,7 @@ def _next_order_greedy(covmats: torch.Tensor,
     - initial_solution (torch.Tensor): The initial solution with shape (batch_size, order)
     - metric (Union[str,Callable]): The metric to evaluate. One of tc, dtc, o, s or a callable function
     - largest (bool): A flag to indicate if the metric is to be maximized or minimized
+    - batch_size (int): The batch size to use for the computation. Default is 1000000.
     - device (torch.device): The device to use for the computation. Default is 'cpu'
 
     Returns:
@@ -135,36 +137,51 @@ def _next_order_greedy(covmats: torch.Tensor,
 
     # Get parameters attributes
     N = covmats.shape[1]
-    batch_size, order = initial_solution.shape
+    total_size, order = initial_solution.shape
 
     # Initial valid candidates to iterate one by one
-    # |batch_size| x |N-order|
+    # |total_size| x |N-order|
     valid_candidates = _get_valid_candidates(initial_solution, N, device)
-
-    # |batch_size| x |N-order| x |order+1|
-    all_solutions = _create_all_solutions(initial_solution, valid_candidates)
-
-    # |batch_size x N-order| x |order+1|
-    all_solutions = all_solutions.view(batch_size*(N-order), order+1)
-
-    # |batch_size x N-order|
-    best_score = _evaluate_nplets(covmats, T, all_solutions, metric, device=device)
-
-    # |batch_size| x |N-order|
-    best_score = best_score.view(batch_size, N-order)
-
-    if not largest:
-        best_score = -best_score
-
-    # get for each batch item the best score over the second dimention
-
-    # |batch_size|
-    max_idxs = torch.argmax(best_score, dim=1)
-    best_candidates = valid_candidates[torch.arange(batch_size), max_idxs]
-    best_score = best_score[torch.arange(batch_size), max_idxs]
-
-    # If minimizing, then return score to its original sign
-    if not largest:
-        best_score = -best_score
 
-    return best_candidates, best_score
+    best_candidates = []
+    best_scores = []
+
+    for start in range(0, total_size, batch_size):
+        end = min(start + batch_size, total_size)
+        batch_initial_solution = initial_solution[start:end]
+        batch_valid_candidates = valid_candidates[start:end]
+
+        # |batch_size| x |N-order| x |order+1|
+        all_solutions = _create_all_solutions(batch_initial_solution, batch_valid_candidates)
+
+        # |batch_size x N-order| x |order+1|
+        all_solutions = all_solutions.view(-1, order+1)
+
+        # |batch_size x N-order|
+        batch_best_score = _evaluate_nplets(covmats, T,
+                                            all_solutions,
+                                            metric,
+                                            batch_size=batch_size,
+                                            device=device)
+
+        # |batch_size| x |N-order|
+        batch_best_score = batch_best_score.view(end - start, N - order)
+
+        if not largest:
+            batch_best_score = -batch_best_score
+
+        # get for each batch item the best score over the second dimension
+
+        # |batch_size|
+        max_idxs = torch.argmax(batch_best_score, dim=1)
+        batch_best_candidates = batch_valid_candidates[torch.arange(end - start), max_idxs]
+        batch_best_score = batch_best_score[torch.arange(end - start), max_idxs]
+
+        # If minimizing, then return score to its original sign
+        if not largest:
+            batch_best_score = -batch_best_score
+
+        best_candidates.append(batch_best_candidates)
+        best_scores.append(batch_best_score)
+
+    return torch.cat(best_candidates), torch.cat(best_scores)
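The rewritten loop above follows a chunk-evaluate-concatenate pattern: slice the candidate solutions, score each slice, pick the per-row argmax, and stitch the picks back together. A self-contained sketch of that pattern, detached from the THOI internals (scores_fn and the shapes are placeholders, not library API):

import torch
from typing import Callable

def chunked_best(scores_fn: Callable[[torch.Tensor], torch.Tensor],
                 items: torch.Tensor,
                 batch_size: int,
                 largest: bool) -> torch.Tensor:
    """Pick the best candidate column per row, scoring rows in chunks."""
    picks = []
    for start in range(0, len(items), batch_size):
        # |<=batch_size| x |candidates|
        chunk_scores = scores_fn(items[start:start + batch_size])
        if not largest:  # negate so a single argmax path also handles minimization
            chunk_scores = -chunk_scores
        picks.append(torch.argmax(chunk_scores, dim=1))
    # |len(items)|
    return torch.cat(picks)

The sign flip mirrors the paired "if not largest" blocks in the diff: scores are negated before the argmax and negated back before being returned, so one code path serves both maximization and minimization.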

thoi/heuristics/scoring.py

+6 −1
@@ -12,12 +12,14 @@ def _evaluate_nplets(covmats: torch.Tensor,
                      T: Optional[List[int]],
                      batched_nplets: torch.Tensor,
                      metric: Union[str, Callable],
+                     batch_size: int,
                      device: torch.device):
     """
     - covmats (torch.Tensor): The covariance matrix or matrixes with shape (N, N) or (D, N, N)
     - T (Optional[List[int]]): The number of samples for each multivariate series or None
-    - batched_nplets (torch.Tensor): The nplets to calculate the inverse of the oinformation with shape (batch_size, order)
+    - batched_nplets (torch.Tensor): The nplets to calculate the inverse of the oinformation with shape (total_size, order)
     - metric (str): The metric to evaluate. One of tc, dtc, o, s or Callable
+    - batch_size (int): The batch size to use for the computation
     - device (torch.device): The device to use
     """
 
@@ -31,6 +33,7 @@ def _evaluate_nplets(covmats: torch.Tensor,
                                nplets=batched_nplets,
                                T=T,
                                covmat_precomputed=True,
+                               batch_size=batch_size,
                                device=device)
 
     # |batch_size|
@@ -41,6 +44,7 @@ def _evaluate_nplet_hot_encoded(covmats: torch.Tensor,
                                 T: int,
                                 batched_nplets: torch.Tensor,
                                 metric: str,
+                                batch_size: int,
                                 device: torch.device):
 
     """
@@ -60,6 +64,7 @@ def _evaluate_nplet_hot_encoded(covmats: torch.Tensor,
                                nplets=batched_nplets,
                                T=T,
                                covmat_precomputed=True,
+                               batch_size=batch_size,
                                device=device)
 
     # |batch_size|

thoi/heuristics/simulated_annealing.py

+11 −2
@@ -26,6 +26,7 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
                         T: Optional[Union[int, List[int]]]=None,
                         initial_solution: Optional[torch.Tensor] = None,
                         repeat: int = 10,
+                        batch_size: int = 1000000,
                         device: torch.device = torch.device('cpu'),
                         max_iterations: int = 1000,
                         early_stop: int = 100,
@@ -50,7 +51,11 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
     current_solution = initial_solution.to(device).contiguous()
 
     # |batch_size|
-    current_energy = _evaluate_nplets(covmats, T, current_solution, metric, device=device)
+    current_energy = _evaluate_nplets(covmats, T,
+                                      current_solution,
+                                      metric,
+                                      batch_size=batch_size,
+                                      device=device)
 
     if not largest:
         current_energy = -current_energy
@@ -95,7 +100,11 @@ def simulated_annealing(X: Union[np.ndarray, torch.Tensor, List[np.ndarray], Lis
 
         # Calculate energy of new solution
         # |batch_size|
-        new_energy = _evaluate_nplets(covmats, T, current_solution, metric, device=device)
+        new_energy = _evaluate_nplets(covmats, T,
+                                      current_solution,
+                                      metric,
+                                      batch_size=batch_size,
+                                      device=device)
 
         if not largest:
             new_energy = -new_energy

thoi/heuristics/simulated_annealing_multi_order.py

+11 −2
@@ -34,6 +34,7 @@ def simulated_annealing_multi_order(X: Union[np.ndarray, torch.Tensor, List[np.n
                                    T: Optional[Union[int, List[int]]]=None,
                                    initial_solution: Optional[torch.Tensor] = None,
                                    repeat: int = 10,
+                                   batch_size: int = 1000000,
                                    device: torch.device = torch.device('cpu'),
                                    max_iterations: int = 1000,
                                    early_stop: int = 100,
@@ -59,7 +60,11 @@ def simulated_annealing_multi_order(X: Union[np.ndarray, torch.Tensor, List[np.n
     current_solution = initial_solution.to(device).contiguous()
 
     # |batch_size|
-    current_energy = _evaluate_nplet_hot_encoded(covmats, T, current_solution, metric, device=device)
+    current_energy = _evaluate_nplet_hot_encoded(covmats, T,
+                                                 current_solution,
+                                                 metric,
+                                                 batch_size=batch_size,
+                                                 device=device)
 
     if not largest:
         current_energy = -current_energy
@@ -97,7 +102,11 @@ def simulated_annealing_multi_order(X: Union[np.ndarray, torch.Tensor, List[np.n
 
         # Calculate energy of new solution
         # |batch_size|
-        new_energy = _evaluate_nplet_hot_encoded(covmats, T, current_solution, metric, device=device)
+        new_energy = _evaluate_nplet_hot_encoded(covmats, T,
+                                                 current_solution,
+                                                 metric,
+                                                 batch_size=batch_size,
+                                                 device=device)
 
         if not largest:
             new_energy = -new_energy

thoi/measures/gaussian_copula.py

+41 −26
@@ -163,7 +163,8 @@ def nplets_measures(X: Union[TensorLikeArray],
                     covmat_precomputed: bool = False,
                     T: Optional[Union[int, List[int]]] = None,
                     device: torch.device = torch.device('cpu'),
-                    verbose: int = logging.INFO):
+                    verbose: int = logging.INFO,
+                    batch_size: int = 1000000):
 
     """
     Compute higher-order measures (TC, DTC, O, S) for specified n-plets in the given data matrices X.
@@ -202,6 +203,9 @@ def nplets_measures(X: Union[TensorLikeArray],
     verbose : int, optional
         Logging verbosity level. Default is `logging.INFO`.
 
+    batch_size : int, optional
+        Batch size for processing n-plets. Default is 1,000,000.
+
     Returns
     -------
     torch.Tensor
@@ -310,7 +314,8 @@ def nplets_measures(X: Union[TensorLikeArray],
 
     # nplets must be a batched tensor
     assert len(nplets.shape) == 2, 'nplets must be a batched tensor with shape (batch_size, order)'
-    batch_size, order = nplets.shape
+    batch_size = min(batch_size, len(nplets))
+    order = nplets.shape[1]
 
     # Create marginal indexes
     # |N| x |N-1|
@@ -320,30 +325,40 @@ def nplets_measures(X: Union[TensorLikeArray],
     # |batch_size x D|, |batch_size x D|, |batch_size x D|
     bc1, bcN, bcNmin1 = _get_bias_correctors(T, order, batch_size, D, device)
 
-    # Create the covariance matrices for each nplet in the batch
-    # |batch_size| x |D| x |N| x |N|
-    nplets_covmats = _generate_nplets_covmants(covmats, nplets)
-
-    # Pack covmat in a single batch
-    # |batch_size x D| x |order| x |order|
-    nplets_covmats = nplets_covmats.view(batch_size*D, order, order)
-
-    # Batch process all nplets at once
-    measures = _get_tc_dtc_from_batched_covmat(nplets_covmats,
-                                               allmin1,
-                                               bc1,
-                                               bcN,
-                                               bcNmin1)
-
-    # Unpack results
-    # |batch_size x D|, |batch_size x D|, |batch_size x D|, |batch_size x D|
-    nplets_tc, nplets_dtc, nplets_o, nplets_s = measures
-
-    # |batch_size| x |D| x |4 = (tc, dtc, o, s)|
-    return torch.stack([nplets_tc.view(batch_size, D),
-                        nplets_dtc.view(batch_size, D),
-                        nplets_o.view(batch_size, D),
-                        nplets_s.view(batch_size, D)], dim=-1)
+    # Create DataLoader for nplets
+    dataloader = DataLoader(nplets, batch_size=batch_size, shuffle=False)
+
+    results = []
+    for nplet_batch in tqdm(dataloader, desc='Processing n-plets', leave=False):
+        curr_batch_size = nplet_batch.shape[0]
+
+        # Create the covariance matrices for each nplet in the batch
+        # |curr_batch_size| x |D| x |order| x |order|
+        nplets_covmats = _generate_nplets_covmants(covmats, nplet_batch)
+
+        # Pack covmats in a single batch
+        # |curr_batch_size x D| x |order| x |order|
+        nplets_covmats = nplets_covmats.view(curr_batch_size * D, order, order)
+
+        # Batch process all nplets at once
+        measures = _get_tc_dtc_from_batched_covmat(nplets_covmats,
+                                                   allmin1,
+                                                   bc1[:curr_batch_size * D],
+                                                   bcN[:curr_batch_size * D],
+                                                   bcNmin1[:curr_batch_size * D])
+
+        # Unpack results
+        # |curr_batch_size x D|, |curr_batch_size x D|, |curr_batch_size x D|, |curr_batch_size x D|
+        nplets_tc, nplets_dtc, nplets_o, nplets_s = measures
+
+        # Collect results
+        results.append(torch.stack([nplets_tc.view(curr_batch_size, D),
+                                    nplets_dtc.view(curr_batch_size, D),
+                                    nplets_o.view(curr_batch_size, D),
+                                    nplets_s.view(curr_batch_size, D)], dim=-1))
+
+    # Concatenate all results
+    return torch.cat(results, dim=0)
 
 @torch.no_grad()
 def multi_order_measures(X: TensorLikeArray,
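A note on the DataLoader introduced above: wrapping a 2-D tensor in torch.utils.data.DataLoader with shuffle=False yields contiguous row batches in order, with a final short batch, which is why the bias correctors can simply be sliced to curr_batch_size * D per iteration. A small sketch (shapes illustrative):

import torch
from torch.utils.data import DataLoader

nplets = torch.arange(12).view(6, 2)  # 6 n-plets of order 2
for batch in DataLoader(nplets, batch_size=4, shuffle=False):
    print(batch.shape)  # torch.Size([4, 2]) then torch.Size([2, 2])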
