6 changes: 3 additions & 3 deletions .github/workflows/base_test_workflow.yml
@@ -18,8 +18,8 @@ jobs:
           pytorch-version: 1.6
           torchvision-version: 0.7
         - python-version: 3.9
-          pytorch-version: 2.1
-          torchvision-version: 0.16
+          pytorch-version: 2.3
+          torchvision-version: 0.18

     steps:
     - uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install .[with-hooks-cpu]
-        pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
+        pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
         pip install --upgrade protobuf==3.20.1
         pip install six
         pip install packaging
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@
     ],
     python_requires=">=3.0",
     install_requires=[
-        "numpy",
+        "numpy < 2.0",
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
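Both the CI install step and setup.py now cap NumPy below 2.0. A quick way to confirm an environment satisfies the new pin (a minimal sketch, not part of this PR; it assumes the third-party packaging library, which the workflow above already installs):

# Illustrative sanity check, not part of this PR: verify that the installed
# NumPy satisfies the new "numpy < 2.0" constraint from setup.py.
import numpy
from packaging.version import Version

assert Version(numpy.__version__) < Version("2.0"), (
    f"found numpy {numpy.__version__}, but this release requires numpy < 2.0"
)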
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "2.5.0"
+__version__ = "2.6.0"
19 changes: 17 additions & 2 deletions src/pytorch_metric_learning/utils/distributed.py
@@ -1,3 +1,5 @@
+import warnings
+
 import torch

 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,28 @@ def __init__(self, loss, efficient=False):

     def forward(
         self,
-        emb,
+        embeddings,
         labels=None,
         indices_tuple=None,
         ref_emb=None,
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
-        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        common_args = [
+            embeddings,
+            labels,
+            indices_tuple,
+            ref_emb,
+            ref_labels,
+            world_size,
+        ]
         if isinstance(self.loss, CrossBatchMemory):
             return self.forward_cross_batch(*common_args, enqueue_mask)
         return self.forward_regular_loss(*common_args)
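With this change, DistributedLossWrapper no longer assumes torch.distributed has been initialized: outside a distributed run it warns and falls back to calling the wrapped loss directly, and the first positional argument of forward is renamed from emb to embeddings. A usage sketch (the wrapper and ContrastiveLoss come from this library; the tensors are dummy data for illustration):

import torch
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# Wrap any compatible metric loss; the constructor is (loss, efficient=False).
loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

embeddings = torch.randn(32, 128)     # dummy batch of 128-dim embeddings
labels = torch.randint(0, 10, (32,))  # dummy integer class labels

# In a torch.distributed run, the wrapper gathers tensors across ranks before
# computing the loss. After this PR, a non-distributed run no longer fails:
# it emits a warning and returns self.loss(...) on the local batch.
loss = loss_fn(embeddings, labels)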