Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

enforce random sampling at first call in kornia parallell transforms #351

Merged
merged 7 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class KorniaParallelTransforms(nn.Sequential):

def __init__(self, *args):
super().__init__(*[convert_to_modules(arg) for arg in args])
self._reuse_params: bool = False

@property
def reuse_params(self) -> bool:
return self._reuse_params

def set_reuse_params(self, val: bool):
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
self._reuse_params = val

def forward(self, inputs: Any):
result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
Expand All @@ -88,7 +96,7 @@ def forward(self, inputs: Any):
result[i] = transform(input, params)
else: # case for non random transforms
result[i] = transform(input)
if hasattr(transform, "_params") and bool(transform._params):
if not self.reuse_params and hasattr(transform, "_params") and bool(transform._params):
transform._params = None
return result

Expand Down
4 changes: 4 additions & 0 deletions tests/core/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def test_kornia_parallel_transforms(with_params):
transform_a._params = "test"

parallel_transforms = KorniaParallelTransforms(transform_a, transform_b)
assert parallel_transforms.reuse_params is False

parallel_transforms.set_reuse_params(True)
assert parallel_transforms.reuse_params is True

parallel_transforms(samples)

Expand Down