Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PostprocessingDataset: pass along seq_tag #1623

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions returnn/datasets/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

from returnn.datasets.basic import DatasetSeq
from returnn.datasets.util.strings import str_to_numpy_array
from returnn.datasets.util.vocabulary import Vocabulary
from returnn.tensor import Tensor, TensorDict
from returnn.tensor.dim import Dim
Expand Down Expand Up @@ -139,7 +140,9 @@ def __init__(
self._out_tensor_dict_template = TensorDict()
self._out_tensor_dict_template.update(self._map_outputs, auto_convert=True)
else:
self._out_tensor_dict_template = self._in_tensor_dict_template
self._out_tensor_dict_template = self._in_tensor_dict_template.copy_template()
# update only after _out_tensor_dict_template has been created from _in_tensor_dict_template
self._in_tensor_dict_template.update({"seq_tag": {"dims": (), "dtype": "string"}}, auto_convert=True)
self.num_outputs = {
k: (t.sparse_dim.size if t.sparse_dim else t.shape[-1] if len(t.shape) > 0 else 1, t.ndim)
for k, t in self._out_tensor_dict_template.data.items()
Expand Down Expand Up @@ -199,7 +202,11 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
assert loaded_seq_idx <= seq_idx, "_collect_single_seq must be done monotonically"
if loaded_seq_idx != seq_idx:
continue
seq = DatasetSeq(features={k: t.raw_tensor for k, t in tensor_dict.data.items()}, seq_idx=seq_idx)
seq = DatasetSeq(
features={k: t.raw_tensor for k, t in tensor_dict.data.items() if k != "seq_tag"},
seq_idx=seq_idx,
seq_tag=str(tensor_dict["seq_tag"].raw_tensor),
)
return seq

def _build_mapping_iter(self) -> Iterator[TensorDict]:
Expand All @@ -209,6 +216,7 @@ def _build_mapping_iter(self) -> Iterator[TensorDict]:

def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:
for t_dict in inner:
assert "seq_tag" in t_dict.data, "seq_tag dropped from TensorDict in postprocessing pipeline"
for data_key, out_t in self._out_tensor_dict_template.data.items():
in_t = t_dict.data[data_key]
assert (
Expand Down Expand Up @@ -237,16 +245,25 @@ def _iterate_dataset(self) -> Iterator[TensorDict]:
seq_index = 0
while self._dataset.is_less_than_num_seqs(seq_index):
self._dataset.load_seqs(seq_index, seq_index + 1)

tensor_dict = self._in_tensor_dict_template.copy_template()
for data_key in data_keys:
tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
albertz marked this conversation as resolved.
Show resolved Hide resolved

if self._map_seq is not None:
tensor_dict = self._map_seq(
tensor_dict, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None}
)
assert isinstance(
tensor_dict, TensorDict
), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"

# If map_seq dropped the seq tag, re-add it here. This is safe because the non-iterator
# postprocessing function maps each sequence one-to-one and never adds or drops segments.
if "seq_tag" not in tensor_dict.data:
tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))

yield tensor_dict
seq_index += 1

Expand Down
3 changes: 3 additions & 0 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,9 @@ def _add_1337_to_classes(tdict: TensorDict, **kwargs) -> TensorDict:
classes = dataset.get_data(0, "classes")
assert all(c - 1337 >= 0 for c in classes)

# Assert that the default seq tags have been replaced with the ones from the oggzip dataset.
assert not dataset.get_tag(0).startswith("seq-")

count = 0

def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]:
Expand Down
Loading