Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix metadata dict support #393

Merged
merged 4 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393))

## [0.3.2] - 2021-06-08

Expand Down
12 changes: 11 additions & 1 deletion flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def __init__(
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

@staticmethod
def _extract_metadata(
self,
samples: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
Expand Down Expand Up @@ -229,8 +229,18 @@ def __init__(
self.save_fn = convert_to_modules(save_fn)
self.save_per_sample = convert_to_modules(save_per_sample)

@staticmethod
def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
if isinstance(batch, Mapping):
return batch, batch.get(DefaultDataKeys.METADATA, None)
return batch, None

def forward(self, batch: Sequence[Any]):
batch, metadata = self._extract_metadata(batch)
uncollated = self.uncollate_fn(self.per_batch_transform(batch))
if metadata:
for sample, sample_metadata in zip(uncollated, metadata):
sample[DefaultDataKeys.METADATA] = sample_metadata

final_preds = type(uncollated)([self.serializer(self.per_sample_transform(sample)) for sample in uncollated])

Expand Down