
Commit df53e74

Merge pull request #423 from KevinMusgrave/dev
v1.1.1
2 parents 537a5f5 + b6bfcff

File tree

14 files changed (+137 −28 lines)


conda_build/pytorch-metric-learning/meta.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,13 +1,13 @@
 {% set name = "pytorch-metric-learning" %}
-{% set version = "1.1.0" %}
+{% set version = "1.1.1" %}

 package:
   name: "{{ name|lower }}"
   version: "{{ version }}"

 source:
   url: "https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz"
-  sha256: d52913eee027746de928bf0e9c031f59a0915cfee4f02cdf81c198a058bd6b21
+  sha256: 6e572dc54179c762abc333fc4c6f68fcd909e800f9519ca1463235d14b9f5c44

 build:
   number: 0
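
If you're reproducing this bump by hand, the new sha256 is just the digest of the 1.1.1 sdist. A minimal sketch (assuming network access to pypi.io; the URL is the recipe's template expanded for version 1.1.1):

import hashlib
import urllib.request

# Expanded from the recipe's url template for version 1.1.1.
URL = (
    "https://pypi.io/packages/source/p/pytorch-metric-learning/"
    "pytorch-metric-learning-1.1.1.tar.gz"
)

with urllib.request.urlopen(URL) as resp:
    digest = hashlib.sha256(resp.read()).hexdigest()
print(digest)  # should match the sha256 field in the recipe above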
src/pytorch_metric_learning/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "1.1.0"
+__version__ = "1.1.1"

src/pytorch_metric_learning/distances/base_distance.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def forward(self, query_emb, ref_emb=None):
         )
         mat = self.compute_mat(query_emb_normalized, ref_emb_normalized)
         if self.power != 1:
-            mat = mat ** self.power
+            mat = mat**self.power
         assert mat.size() == torch.Size((query_emb.size(0), ref_emb.size(0)))
         return mat
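
This one-character change, and the matching ones in the files below, are consistent with Black's power-operator style: `**` is unspaced when both operands are simple (names, literals, attribute chains) and spaced otherwise. A standalone illustration (not code from this repo):

# ** hugs simple operands...
x = 3.0
n = 2
y = x**n
# ...but gets spaces when an operand is a compound expression,
# as in large_margin_softmax_loss.py below.
z = (1 - x**2) ** n
print(y, z)  # 9.0 64.0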

src/pytorch_metric_learning/losses/fast_ap_loss.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
             return self.zero_losses()
         dist_mat = self.distance(embeddings)

-        histogram_max = 2 ** self.distance.power
+        histogram_max = 2**self.distance.power
         histogram_delta = histogram_max / self.num_bins
         mid_points = torch.linspace(
             0.0, histogram_max, steps=self.num_edges, device=device, dtype=dtype
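
For context on the constant being reformatted: FastAP works with L2-normalized embeddings, so pairwise distances are at most 2 and 2**power bounds the distance histogram. A quick sanity check with two hypothetical unit vectors:

import torch

# Antipodal unit vectors attain the maximum pairwise distance of 2,
# so a distance raised to `power` is at most 2**power.
a = torch.tensor([1.0, 0.0])
b = torch.tensor([-1.0, 0.0])
print(torch.dist(a, b))  # tensor(2.)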

src/pytorch_metric_learning/losses/large_margin_softmax_loss.py

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ def get_cos_with_margin(self, cosine):
         cosine = cosine.unsqueeze(1)
         for attr in ["n_range", "margin_choose_n", "cos_powers", "alternating"]:
             setattr(self, attr, c_f.to_device(getattr(self, attr), cosine))
-        cos_powered = cosine ** self.cos_powers
-        sin_powered = (1 - cosine ** 2) ** self.n_range
+        cos_powered = cosine**self.cos_powers
+        sin_powered = (1 - cosine**2) ** self.n_range
         terms = (
             self.alternating * self.margin_choose_n * cos_powered * sin_powered
         )  # Equation 7 in the paper
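
For reference, "Equation 7 in the paper" is the multiple-angle expansion these buffers implement, with n_range holding the index n, margin_choose_n the binomial coefficients, cos_powers the exponents m − 2n, and alternating the signs (−1)^n:

\cos(m\theta) = \sum_{n=0}^{\lfloor m/2 \rfloor} (-1)^n \binom{m}{2n} \cos^{m-2n}(\theta)\,(1 - \cos^2\theta)^n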

src/pytorch_metric_learning/miners/distance_weighted_miner.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def mine(self, embeddings, labels, ref_emb, ref_labels):

         # See the first equation from Section 4 of the paper
         log_weights = (2.0 - d) * torch.log(mat) - ((d - 3) / 2) * torch.log(
-            1.0 - 0.25 * (mat ** 2.0)
+            1.0 - 0.25 * (mat**2.0)
         )

         inf_or_nan = torch.isinf(log_weights) | torch.isnan(log_weights)
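
For context, the "first equation from Section 4" (Wu et al., "Sampling Matters in Deep Embedding Learning") gives the density of pairwise distances d between points uniformly distributed on the unit sphere in n dimensions, q(d) \propto d^{\,n-2}(1 - d^2/4)^{(n-3)/2}, and the miner samples negatives with weights inversely proportional to it:

\log w(d) = (2 - n)\log d - \frac{n-3}{2}\,\log\!\left(1 - \frac{d^2}{4}\right)

(in the code, d is the embedding dimensionality n and mat holds the pairwise distances).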

src/pytorch_metric_learning/regularizers/center_invariant_regularizer.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def compute_loss(self, weights):
         deviations_from_mean = squared_weight_norms - torch.mean(squared_weight_norms)
         return {
             "loss": {
-                "losses": (deviations_from_mean ** 2) / 4,
+                "losses": (deviations_from_mean**2) / 4,
                 "indices": c_f.torch_arange_from_size(weights),
                 "reduction_type": "element",
             }

src/pytorch_metric_learning/regularizers/lp_regularizer.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def __init__(self, p=2, power=1, **kwargs):
     def compute_loss(self, embeddings):
         reg = torch.norm(embeddings, p=self.p, dim=1)
         if self.power != 1:
-            reg = reg ** self.power
+            reg = reg**self.power
         return {
             "loss": {
                 "losses": reg,

src/pytorch_metric_learning/utils/accuracy_calculator.py

Lines changed: 4 additions & 11 deletions
@@ -70,8 +70,7 @@ def r_precision(
     matches_per_row = torch.sum(same_label * relevance_mask, dim=1)
     max_possible_matches_per_row = torch.sum(relevance_mask, dim=1)
     accuracy_per_sample = (
-        c_f.to_dtype(matches_per_row, dtype=torch.float64)
-        / max_possible_matches_per_row
+        matches_per_row.type(torch.float64) / max_possible_matches_per_row
     )
     return maybe_get_avg_of_avgs(
         accuracy_per_sample, gt_labels, avg_of_avgs, return_per_class

@@ -99,9 +98,7 @@ def mean_average_precision(
     equality = is_same_label * relevance_mask
     cumulative_correct = torch.cumsum(equality, dim=1)
     k_idx = torch.arange(1, num_k + 1, device=device).repeat(num_samples, 1)
-    precision_at_ks = (
-        c_f.to_dtype(cumulative_correct * equality, dtype=torch.float64) / k_idx
-    )
+    precision_at_ks = (cumulative_correct * equality).type(torch.float64) / k_idx
     summed_precision_per_row = torch.sum(precision_at_ks * relevance_mask, dim=1)
     if at_r:
         max_possible_matches_per_row = torch.sum(relevance_mask, dim=1)

@@ -172,9 +169,7 @@ def precision_at_k(
 ):
     curr_knn_labels = knn_labels[:, :k]
     same_label = label_comparison_fn(gt_labels, curr_knn_labels)
-    accuracy_per_sample = (
-        c_f.to_dtype(torch.sum(same_label, dim=1), dtype=torch.float64) / k
-    )
+    accuracy_per_sample = torch.sum(same_label, dim=1).type(torch.float64) / k
     return maybe_get_avg_of_avgs(
         accuracy_per_sample, gt_labels, avg_of_avgs, return_per_class
     )

@@ -209,9 +204,7 @@ def get_lone_query_labels(
     unique_labels, match_counts = label_counts
     if embeddings_come_from_same_source:
         label_matches_itself = label_comparison_fn(unique_labels, unique_labels)
-        lone_condition = (
-            match_counts - c_f.to_dtype(label_matches_itself, dtype=torch.long) <= 0
-        )
+        lone_condition = match_counts - label_matches_itself.type(torch.long) <= 0
     else:
         lone_condition = match_counts == 0
     lone_query_labels = unique_labels[lone_condition]
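
All four hunks replace the c_f.to_dtype helper with Tensor.type, keeping the float64 cast while flattening each expression onto one line. A minimal sketch of the precision@k arithmetic with hypothetical tensors:

import torch

# Hypothetical setup: 2 queries, labels of their k=3 nearest neighbors.
gt_labels = torch.tensor([[0], [1]])
knn_labels = torch.tensor([[0, 0, 2],
                           [1, 2, 2]])
same_label = knn_labels == gt_labels  # broadcasted boolean comparison
k = 3

# Cast to float64 before dividing, matching the rewritten line above.
accuracy_per_sample = torch.sum(same_label, dim=1).type(torch.float64) / k
print(accuracy_per_sample)  # tensor([0.6667, 0.3333], dtype=torch.float64)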

src/pytorch_metric_learning/utils/common_functions.py

Lines changed: 7 additions & 4 deletions
@@ -86,11 +86,12 @@ def get_hierarchy_label(batch_labels, hierarchy_level):
 def map_labels(label_map, labels):
     labels = to_numpy(labels)
     if labels.ndim == 2:
+        new_labels = np.zeros(labels.shape, dtype=int)
         for h in range(labels.shape[1]):
-            labels[:, h] = label_map(labels[:, h], h)
+            new_labels[:, h] = label_map(labels[:, h], h)
     else:
-        labels = label_map(labels, 0)
-    return labels
+        new_labels = label_map(labels, 0)
+    return new_labels


 def process_label(labels, hierarchy_level, label_map):
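
The rewritten map_labels writes into a fresh integer array instead of assigning back into labels, whose dtype may not be integer. A minimal sketch of why that matters, using a hypothetical string-valued input and mapper:

import numpy as np

# Hypothetical 2-level hierarchical labels, string-valued.
labels = np.array([["cat", "mammal"], ["dog", "mammal"]])

def label_map(column, h):
    # Hypothetical mapper: index of each label among the column's sorted uniques.
    uniques = sorted(set(column))
    return np.array([uniques.index(x) for x in column])

# Writing into a fresh int array keeps the output dtype correct;
# assigning into the original string array would store strings like "0".
new_labels = np.zeros(labels.shape, dtype=int)
for h in range(labels.shape[1]):
    new_labels[:, h] = label_map(labels[:, h], h)
print(new_labels)  # [[0 0]
                   #  [1 0]]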
@@ -235,7 +236,9 @@ def make_label_to_rank_dict(label_set):
     Returns:
         A dictionary mapping each label to its numeric rank in the original set
     """
-    ranked = scipy.stats.rankdata(label_set) - 1
+    if len(set(label_set)) != len(label_set):
+        raise ValueError("label set must not have duplicates")
+    ranked = scipy.stats.rankdata(label_set).astype(int) - 1
     return {k: v for k, v in zip(label_set, ranked)}
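scipy.stats.rankdata returns float, 1-based ranks, hence the new .astype(int) - 1; duplicate labels would tie and receive averaged, fractional ranks, hence the new guard. A quick illustration with hypothetical labels:

import scipy.stats

label_set = ["b", "a", "c"]
ranked = scipy.stats.rankdata(label_set).astype(int) - 1
print(dict(zip(label_set, ranked.tolist())))  # {'b': 1, 'a': 0, 'c': 2}

# Duplicates receive averaged, non-integer ranks,
# which is why the function now rejects them.
print(scipy.stats.rankdata(["a", "a", "b"]))  # [1.5 1.5 3. ]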