Skip to content

Commit

Permalink
update DALIClassificationLoader to not use deprecated arguments (#4925)
Browse files Browse the repository at this point in the history
* update DALIClassificationLoader to not use deprecated arguments

* fix line length

* dali version check added and changed args accordingly

* versions

* checking version using disutils.version.LooseVersion now

* .

* ver

* import

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2020
1 parent 81070be commit 16e819e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
3 changes: 1 addition & 2 deletions .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ steps:
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
- pip install git+https://${AUTH_TOKEN}@github.com/PyTorchLightning/[email protected] -v --no-cache-dir
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
# todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
- pip list
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
# Running special tests
Expand Down
23 changes: 18 additions & 5 deletions pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from argparse import ArgumentParser
from random import shuffle
from warnings import warn
from distutils.version import LooseVersion

import numpy as np
import torch
Expand All @@ -31,12 +32,17 @@
from tests.base.datasets import MNIST

if DALI_AVAILABLE:
import nvidia.dali.ops as ops
from nvidia.dali import ops
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali import __version__ as dali_version

NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0')
if NEW_DALI_API:
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
else:
warn('NVIDIA DALI is not available')
ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC
ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC


class ExternalMNISTInputIterator(object):
Expand Down Expand Up @@ -97,11 +103,18 @@ def __init__(
dynamic_shape=False,
last_batch_padded=False,
):
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)
if NEW_DALI_API:
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
else:
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
dynamic_shape, last_batch_padded)
self._fill_last_batch = fill_last_batch

def __len__(self):
batch_count = self._size // (self._num_gpus * self.batch_size)
last_batch = 1 if self._fill_last_batch else 0
last_batch = 1 if self._fill_last_batch else 1
return batch_count + last_batch


Expand Down Expand Up @@ -178,7 +191,7 @@ def cli_main():
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)

pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True)

pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)
Expand Down

0 comments on commit 16e819e

Please sign in to comment.