Skip to content

Commit

Permalink
Added has_custom to MixedCut
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Aug 8, 2024
1 parent bf37599 commit 1054836
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lhotse/cut/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ def __getattr__(self, name: str) -> Any:
f"when a MixedCut consists of more than one MonoCut with that attribute)."
)

def has_custom(self, name: str) -> bool:
(
non_padding_idx,
mono_cut,
) = self._assert_one_data_cut_with_attr_and_return_it_with_track_index(name)

return hasattr(mono_cut, name)

def load_custom(self, name: str) -> np.ndarray:
"""
Load custom data as numpy array. The custom data is expected to have
Expand Down
29 changes: 29 additions & 0 deletions test/cut/test_custom_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,32 @@ def test_multi_cut_custom_multi_recording_channel_selector():
audio = two_channel_out.load_target_recording()
assert audio.shape == (2, 16000)
np.testing.assert_allclose(ref_tgt_audio[::3], audio)


def test_padded_cut_custom_recording():
original_duration = 1.0 # seconds
padded_duration = 2.0 # seconds

# prepare cut
cut = dummy_cut(unique_id=0, with_data=True, duration=original_duration)
cut.target_recording = dummy_recording(
unique_id=1, duration=cut.duration, with_data=True
)
target_recording = cut.load_target_recording()

# prepare padded cut (MixedCut)
padded_cut = cut.pad(duration=padded_duration)

# check the padded cut (MixedCut) has the custom attribute
assert padded_cut.has_custom("target_recording")

# load the audio from the padded cut
padded_target_recording = padded_cut.load_target_recording()

# check the non-padded component is matching
np.testing.assert_allclose(
padded_target_recording[:, : cut.num_samples], target_recording
)

# check the padded component is zero
assert np.all(padded_target_recording[:, cut.num_samples :] == 0)

0 comments on commit 1054836

Please sign in to comment.