Add half precision testing [1/n] #77

Merged 21 commits on Mar 26, 2021
Merge branch 'master' into half_testing1
SkafteNicki authored Mar 24, 2021

commit 5e262f6eee61b7956e7b4f582b253b8a5bf56e92
23 changes: 22 additions & 1 deletion docs/source/pages/overview.rst
@@ -109,6 +109,28 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
val3 = self.metric3['accuracy'](preds, target)
val4 = self.metric4(preds, target)

Metrics in Dataparallel (DP) mode
=================================

When using metrics in `Dataparallel (DP) <https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel>`_
mode, one should be aware that DP will both create and clean up replicas of Metric objects during a single forward pass.
As a consequence, the metric state of the replicas will by default be destroyed before we can sync
them. It is therefore recommended, when using metrics in DP mode, to initialize them with ``dist_sync_on_step=True``
such that metric states are synchronized between the main process and the replicas before they are destroyed.

Metrics in Distributed Data Parallel (DDP) mode
===============================================

When using metrics in `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_
mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is
not evenly divisible by ``batch_size * num_processors``. The added samples will always be duplicates of datapoints
already in your dataset. This is done to ensure an equal load for all processes. However, this has the consequence
that the calculated metric value will be slightly biased towards those duplicated samples, leading to a wrong result.
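
The effect of the duplicated samples can be sketched with plain Python arithmetic. The example below is a simplified model, not the actual sampler: it assumes the dataset is padded up to a multiple of the number of processes by wrapping around to the first samples, and the helper ``pad_indices`` and the per-sample ``correct`` flags are hypothetical names and data chosen for illustration.

```python
import math


def pad_indices(num_samples, world_size):
    """Pad sample indices to a multiple of world_size by wrapping around.

    Assumption: padding repeats samples from the start of the dataset.
    """
    total = math.ceil(num_samples / world_size) * world_size
    return [i % num_samples for i in range(total)]


# Hypothetical per-sample correctness: only the first of 10 samples is correct.
correct = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Accuracy over the real dataset.
true_acc = sum(correct) / len(correct)

# Accuracy when 4 processes evaluate the padded index list (12 entries).
padded = pad_indices(len(correct), world_size=4)
padded_acc = sum(correct[i] for i in padded) / len(padded)

print(len(padded), true_acc, padded_acc)
```

Here the two extra entries repeat samples 0 and 1, so the correct sample is counted twice and the padded accuracy (2/12) differs from the true accuracy (1/10).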

During training and/or validation this may not be important. However, when evaluating the test dataset it is
highly recommended to run on only a single GPU or to use a `join <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.join>`_
context in conjunction with DDP to prevent this behaviour.

****************************
Metrics and 16-bit precision
****************************
@@ -123,7 +145,6 @@ the following limitations:
but they are also listed below:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`


******************
Metric Arithmetics
32 changes: 5 additions & 27 deletions torchmetrics/utilities/imports.py
@@ -12,33 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities"""
import importlib
import operator
from distutils.version import LooseVersion
from importlib.util import find_spec

from pkg_resources import DistributionNotFound
import torch


def _compare_version(package: str, op, version) -> bool:
    """
    Compare package version with some requirements

    >>> _compare_version("torch", operator.ge, "0.1")
    True
    """
    try:
        pkg = importlib.import_module(package)
    except (ModuleNotFoundError, DistributionNotFound):
        return False
    try:
        pkg_version = LooseVersion(pkg.__version__)
    except AttributeError:
        return False
    if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
        # this is mocked by sphinx, so it shall return True to generate all summaries
        return True
    return op(pkg_version, LooseVersion(version))


_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")
_TORCH_GREATER_EQUAL_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")