Skip to content

Commit 0271f52

Browse files
Merge pull request #426 from KevinMusgrave/dev
v.1.1.2
2 parents df53e74 + 81d80de commit 0271f52

File tree

6 files changed

+31
-11
lines changed

6 files changed

+31
-11
lines changed

conda_build/pytorch-metric-learning/meta.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
{% set name = "pytorch-metric-learning" %}
2-
{% set version = "1.1.1" %}
2+
{% set version = "1.1.2" %}
33

44
package:
55
name: "{{ name|lower }}"
66
version: "{{ version }}"
77

88
source:
99
url: "https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz"
10-
sha256: 6e572dc54179c762abc333fc4c6f68fcd909e800f9519ca1463235d14b9f5c44
10+
sha256: aa2a28b7eb6a3a72f2ab14f59073de286832acc5433863de3c8cfc8e8fed38f4
1111

1212
build:
1313
number: 0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.1.1"
1+
__version__ = "1.1.2"

src/pytorch_metric_learning/utils/logging_presets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def end_of_iteration_hook(self, trainer):
5454
trainer.loss_tracker.loss_weights,
5555
{"parent_name": "loss_weights"},
5656
],
57-
[trainer.loss_funcs, {"recursive_types": [torch.nn.Module]}],
57+
[trainer.loss_funcs, {"recursive_types": [torch.nn.Module, dict]}],
5858
[trainer.mining_funcs, {}],
5959
[trainer.models, {}],
6060
[trainer.optimizers, {"custom_attr_func": self.optimizer_custom_attr_func}],

src/pytorch_metric_learning/utils/module_with_records.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55

66
class ModuleWithRecords(torch.nn.Module):
7-
def __init__(self, collect_stats=c_f.COLLECT_STATS):
7+
def __init__(self, collect_stats=None):
88
super().__init__()
9-
self.collect_stats = collect_stats
9+
self.collect_stats = (
10+
c_f.COLLECT_STATS if collect_stats is None else collect_stats
11+
)
1012

1113
def add_to_recordable_attributes(
1214
self, name=None, list_of_names=None, is_stat=False

tests/trainers/test_metric_loss_only.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchvision import datasets, transforms
99

1010
from pytorch_metric_learning.losses import NTXentLoss
11+
from pytorch_metric_learning.reducers import AvgNonZeroReducer
1112
from pytorch_metric_learning.samplers import MPerClassSampler
1213
from pytorch_metric_learning.testers import GlobalEmbeddingSpaceTester
1314
from pytorch_metric_learning.trainers import MetricLossOnly
@@ -36,7 +37,7 @@ def test_metric_loss_only(self):
3637
)
3738
)
3839

39-
loss_fn = NTXentLoss()
40+
loss_fn = NTXentLoss(reducer=AvgNonZeroReducer())
4041

4142
normalize_transform = transforms.Normalize(
4243
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -166,6 +167,18 @@ def test_metric_loss_only(self):
166167

167168
num_epochs = 2
168169
trainer.train(num_epochs=num_epochs)
170+
171+
for record_name in [
172+
"metric_loss_NTXentLoss",
173+
"metric_loss_NTXentLoss__modules_distance_CosineSimilarity",
174+
"metric_loss_NTXentLoss__modules_reducer_AvgNonZeroReducer",
175+
]:
176+
self.assertTrue(record_keeper.table_exists(record_name))
177+
self.assertTrue(
178+
len(record_keeper.query(f"SELECT * FROM {record_name}"))
179+
== num_epochs * iterations_per_epoch / log_freq
180+
)
181+
169182
best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
170183
tester, "val"
171184
)

tests/utils/test_common_functions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,15 @@ def test_torch_standard_scaler(self):
5757

5858
def test_collect_stats_flag(self):
5959
self.assertTrue(c_f.COLLECT_STATS == WITH_COLLECT_STATS)
60-
loss_fn = TripletMarginLoss()
61-
self.assertTrue(loss_fn.collect_stats == WITH_COLLECT_STATS)
62-
self.assertTrue(loss_fn.distance.collect_stats == WITH_COLLECT_STATS)
63-
self.assertTrue(loss_fn.reducer.collect_stats == WITH_COLLECT_STATS)
60+
for x in [True, False, True, False, WITH_COLLECT_STATS]:
61+
c_f.COLLECT_STATS = x
62+
loss_fn = TripletMarginLoss()
63+
self.assertTrue(loss_fn.collect_stats == x)
64+
self.assertTrue(loss_fn.distance.collect_stats == x)
65+
self.assertTrue(loss_fn.reducer.collect_stats == x)
66+
for x in [True, False]:
67+
loss_fn = TripletMarginLoss(collect_stats=x)
68+
self.assertTrue(loss_fn.collect_stats == x)
6469

6570
def test_check_shapes(self):
6671
embeddings = torch.randn(32, 512, 3)

0 commit comments

Comments
 (0)