Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.4.1"
__version__ = "2.5.0"
18 changes: 9 additions & 9 deletions src/pytorch_metric_learning/losses/manifold_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
if self.lambdaC != np.inf:
F = F[:N, N:]
loss_int = F - F[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_int[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_int[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_int[
loss_int < 0
] = -np.inf # This way no loss for positive correlation with own proxy
loss_int[loss_int < 0] = (
-np.inf
) # This way no loss for positive correlation with own proxy

loss_int = torch.exp(loss_int)
loss_int = torch.log(1 + torch.sum(loss_int, dim=1))
Expand All @@ -106,9 +106,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
F_e, F_p.unsqueeze(1), dim=-1
).t()
loss_ctx += -loss_ctx[torch.arange(N), meta_classes].view(-1, 1) + self.margin
loss_ctx[
torch.arange(N), meta_classes
] = -np.inf # This way avoid numerical cancellation happening # NoQA
loss_ctx[torch.arange(N), meta_classes] = (
-np.inf
) # This way avoid numerical cancellation happening # NoQA
# instead with subtraction of margin term # NoQA
loss_ctx[loss_ctx < 0] = -np.inf

Expand Down
6 changes: 4 additions & 2 deletions src/pytorch_metric_learning/testers/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,10 @@ def test(
query_split_name,
reference_split_names,
)
self.end_of_testing_hook(self) if self.end_of_testing_hook else c_f.LOGGER.info(
self.all_accuracies
(
self.end_of_testing_hook(self)
if self.end_of_testing_hook
else c_f.LOGGER.info(self.all_accuracies)
)
del self.embeddings_and_labels
return self.all_accuracies
50 changes: 48 additions & 2 deletions src/pytorch_metric_learning/utils/loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,57 @@ def neg_pairs_from_tuple(indices_tuple):


def get_all_triplets_indices(labels, ref_labels=None):
matches, diffs = get_matches_and_diffs(labels, ref_labels)
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)

if (
all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
< torch.iinfo(torch.int32).max
):
# torch.nonzero is not supported for tensors with more than INT_MAX elements
return get_all_triplets_indices_vectorized_method(all_matches, all_diffs)

return get_all_triplets_indices_loop_method(labels, all_matches, all_diffs)


def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
return torch.where(triplets)


def get_all_triplets_indices_loop_method(labels, all_matches, all_diffs):
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()

# Find anchors with at least a positive and a negative
indices = torch.arange(0, len(labels), device=labels.device)
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]

# No triplets found
if len(indices) == 0:
return (
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
torch.tensor([], device=labels.device, dtype=labels.dtype),
)

# Compute all triplets
anchors = []
positives = []
negatives = []
for i in indices:
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
nd = len(diffs)
nm = len(matches)
matches = matches.repeat_interleave(nd)
diffs = diffs.repeat(nm)
anchors.append(
torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)
)
positives.append(matches)
negatives.append(diffs)
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)


# sample triplets, with a weighted distribution if weights is specified.
def get_random_triplet_indices(
labels, ref_labels=None, t_per_anchor=None, weights=None
Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_calculate_accuracies.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def test_accuracy_calculator(self):
"query_labels": query_labels,
"label_counts": label_counts,
"knn_labels": knn_labels,
"not_lone_query_mask": torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool),
"not_lone_query_mask": (
torch.ones(6, dtype=torch.bool)
if i == 0
else torch.zeros(6, dtype=torch.bool)
),
}

function_dict = AC.get_function_dict()
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_loss_and_miner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ def test_remove_self_comparisons_small_ref(self):
self.assertTrue(torch.equal(a1, correct_a1))
self.assertTrue(torch.equal(p, correct_p))

def test_get_all_triplets_indices(self):
torch.manual_seed(920)
for dtype in TEST_DTYPES:
for batch_size in [32, 256, 512]:
for ref_labels in [None, torch.randint(0, 5, size=(batch_size // 2,))]:
labels = torch.randint(0, 5, size=(batch_size,))

a, p, n = lmu.get_all_triplets_indices(labels, ref_labels)
matches, diffs = lmu.get_matches_and_diffs(labels, ref_labels)

a2, p2, n2 = lmu.get_all_triplets_indices_vectorized_method(
matches, diffs
)
a3, p3, n3 = lmu.get_all_triplets_indices_loop_method(
labels, matches, diffs
)
self.assertTrue(
(a == a2).all() and (p == p2).all() and (n == n2).all()
)
self.assertTrue(
(a == a3).all() and (p == p3).all() and (n == n3).all()
)


if __name__ == "__main__":
unittest.main()