**`.github/workflows/base_test_workflow.yml`** (6 changes: 3 additions & 3 deletions)

```diff
@@ -18,8 +18,8 @@ jobs:
             pytorch-version: 1.6
             torchvision-version: 0.7
           - python-version: 3.9
-            pytorch-version: 2.1
-            torchvision-version: 0.16
+            pytorch-version: 2.3
+            torchvision-version: 0.18

     steps:
       - uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install .[with-hooks-cpu]
-          pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
+          pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
           pip install --upgrade protobuf==3.20.1
           pip install six
           pip install packaging
```
**`README.md`** (10 changes: 6 additions & 4 deletions)

```diff
@@ -18,16 +18,17 @@

 ## News

+**April 1**: v2.5.0
+- Improved `get_all_triplets_indices` so that large batch sizes don't trigger the `INT_MAX` error.
+- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.5.0).
+- Thank you [mkmenta](https://github.com/mkmenta).
+
 **December 15**: v2.4.0
 - Added [DynamicSoftMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#dynamicsoftmarginloss).
 - Added [RankedListLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss).
 - See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.4.0).
 - Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0), [Puzer](https://github.com/Puzer), [interestingzhuo](https://github.com/interestingzhuo), and [GaetanLepage](https://github.com/GaetanLepage).

-**July 25**: v2.3.0
-- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
-- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
-
 ## Documentation
 - [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
 - [**View the installation instructions here**](https://github.com/KevinMusgrave/pytorch-metric-learning#installation)
@@ -236,6 +237,7 @@ Thanks to the contributors who made pull requests!
 | [AlenUbuntu](https://github.com/AlenUbuntu) | [CircleLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#circleloss) |
 | [interestingzhuo](https://github.com/interestingzhuo) | [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) |
 | [wconnell](https://github.com/wconnell) | [Learning a scRNAseq Metric Embedding](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/scRNAseq_MetricEmbedding.ipynb) |
+| [mkmenta](https://github.com/mkmenta) | Improved `get_all_triplets_indices` (fixed the `INT_MAX` error) |
 | [AlexSchuy](https://github.com/AlexSchuy) | optimized ```utils.loss_and_miner_utils.get_random_triplet_indices``` |
 | [JohnGiorgi](https://github.com/JohnGiorgi) | ```all_gather``` in [utils.distributed](https://kevinmusgrave.github.io/pytorch-metric-learning/distributed) |
 | [Hummer12007](https://github.com/Hummer12007) | ```utils.key_checker``` |
```
**`setup.py`** (2 changes: 1 addition & 1 deletion)

```diff
@@ -39,7 +39,7 @@
     ],
     python_requires=">=3.0",
     install_requires=[
-        "numpy",
+        "numpy < 2.0",
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
```
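The `numpy < 2.0` pin reflects that NumPy 2.0 changed its API and ABI, breaking packages built or tested against 1.x. As a rough illustration only (a hypothetical helper, not part of this PR), the constraint amounts to rejecting any version whose major component is 2 or higher:

```python
import re

def major_version(version_string):
    """Return the leading major component of a version string like '1.26.4'."""
    match = re.match(r"(\d+)", version_string)
    if match is None:
        raise ValueError(f"unparseable version: {version_string!r}")
    return int(match.group(1))

# NumPy 1.x satisfies the "numpy < 2.0" constraint; 2.x does not.
print(major_version("1.26.4") < 2)  # True
print(major_version("2.0.0") < 2)   # False
```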
**`src/pytorch_metric_learning/__init__.py`** (2 changes: 1 addition & 1 deletion)

```diff
@@ -1 +1 @@
-__version__ = "2.5.0"
+__version__ = "2.6.0"
```
**`src/pytorch_metric_learning/utils/distributed.py`** (12 changes: 10 additions & 2 deletions)

```diff
@@ -1,3 +1,5 @@
+import warnings
+
 import torch

 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,21 @@ def __init__(self, loss, efficient=False):

     def forward(
         self,
-        emb,
+        embeddings,
         labels=None,
         indices_tuple=None,
         ref_emb=None,
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
-        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        common_args = [embeddings, labels, indices_tuple, ref_emb, ref_labels, world_size]
         if isinstance(self.loss, CrossBatchMemory):
             return self.forward_cross_batch(*common_args, enqueue_mask)
         return self.forward_regular_loss(*common_args)
```
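The new fallback in `DistributedLossWrapper.forward` warns and calls the wrapped loss directly when no distributed process group is running, instead of failing. The pattern can be sketched in isolation (a minimal stand-in, not the library class; `torch` and the real loss functions are replaced with plain Python, and `is_distributed` is stubbed to return `False` as it would in a single-process run):

```python
import warnings

def is_distributed():
    # Stand-in for the library's check of torch.distributed; a plain
    # single-process run reports False here.
    return False

class DistributedLossWrapperSketch:
    """Minimal sketch of the fallback: outside a distributed run,
    emit a warning and call the wrapped loss as-is."""

    def __init__(self, loss):
        self.loss = loss

    def forward(self, embeddings, labels):
        if not is_distributed():
            warnings.warn(
                "DistributedLossWrapper is being used in a non-distributed "
                "setting. Returning the loss as is."
            )
            return self.loss(embeddings, labels)
        raise NotImplementedError("distributed path omitted from this sketch")

# A toy "loss" standing in for a real metric-learning loss.
wrapper = DistributedLossWrapperSketch(loss=lambda emb, lab: sum(emb) / len(emb))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = wrapper.forward([1.0, 2.0, 3.0], [0, 1, 0])
print(value)        # 2.0
print(len(caught))  # 1
```

The design choice mirrors the diff above: a warning rather than an exception keeps single-GPU code that accidentally wraps its loss working, while still flagging the misuse.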