Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
enforce random sampling at first call in kornia parallell transforms (#…
Browse files Browse the repository at this point in the history
…351)

* add reuse_params in KorniaParallelTransforms

* enforce to sample parameters at first call

* update changelog

* fix test_kornia_parallel_transforms

* fix test_kornia_parallel_transforms

* Update test

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
edgarriba and ethanwharris authored Jun 1, 2021
1 parent 2da2328 commit 55dddea
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed

- Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346))


- Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


- Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))
- Fixed issue with `KorniaParallelTransforms` to assure to share the random state between transforms ([#351](https://github.com/PyTorchLightning/lightning-flash/pull/351))
- Change resize interpolation default mode to nearest ([#352](https://github.com/PyTorchLightning/lightning-flash/pull/352))


## [0.3.0] - 2021-05-20
Expand Down
11 changes: 8 additions & 3 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,21 @@ def forward(self, inputs: Any):
result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
for transform in self.children():
inputs = result

# we enforce the first time to sample random params
result[0] = transform(inputs[0])

if hasattr(transform, "_params") and bool(transform._params):
params = transform._params
else:
params = None

for i, input in enumerate(inputs):
# apply transforms from (1, n)
for i, input in enumerate(inputs[1:]):
if params is not None:
result[i] = transform(input, params)
result[i + 1] = transform(input, params)
else: # case for non random transforms
result[i] = transform(input)
result[i + 1] = transform(input)
if hasattr(transform, "_params") and bool(transform._params):
transform._params = None
return result
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"post_tensor_transform": nn.Sequential(
ApplyToKeys(
[DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='bilinear')),
KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')),
),
),
"collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]),
Expand Down
7 changes: 4 additions & 3 deletions tests/core/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,18 @@ def test_kornia_parallel_transforms(with_params):
transform_b = Mock(spec=torch.nn.Module)

if with_params:
transform_a._params = "test"
transform_a._params = "test" # initialize params with some value

parallel_transforms = KorniaParallelTransforms(transform_a, transform_b)

parallel_transforms(samples)

assert transform_a.call_count == 2
assert transform_b.call_count == 2

if with_params:
assert transform_a.call_args_list[0][0][1] == transform_a.call_args_list[1][0][1] == "test"
assert transform_a.call_args_list[1][0][1] == "test"
# check that after the forward `_params` is set to None
assert transform_a._params == transform_a._params is None

assert torch.allclose(transform_a.call_args_list[0][0][0], samples[0])
assert torch.allclose(transform_a.call_args_list[1][0][0], samples[1])
Expand Down

0 comments on commit 55dddea

Please sign in to comment.