Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Audio data sources + Numpy file support #651

Merged
merged 23 commits into from
Aug 13, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixes
  • Loading branch information
ethanwharris committed Aug 13, 2021
commit a4665129f5b4c6621e06b04f3c45296490781e3e
2 changes: 1 addition & 1 deletion flash/core/data/transforms.py
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
"""
for sample in samples:
for key in sample.keys():
if torch.is_tensor(sample[key]) and sample[key].ndim() == 4:
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
sample[key] = sample[key].squeeze(0)
return default_collate(samples)

4 changes: 2 additions & 2 deletions tests/core/data/test_sampler.py
Original file line number Diff line number Diff line change
@@ -20,13 +20,13 @@
@mock.patch("flash.core.data.data_module.DataLoader")
def test_dataloaders_with_sampler(mock_dataloader):
train_ds = val_ds = test_ds = "dataset"
mock_sampler = "sampler"
mock_sampler = mock.MagicMock()
dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler)
assert dm.sampler is mock_sampler
dl = dm.train_dataloader()
kwargs = mock_dataloader.call_args[1]
assert "sampler" in kwargs
assert kwargs["sampler"] is mock_sampler
assert kwargs["sampler"] is mock_sampler.return_value
for dl in [dm.val_dataloader(), dm.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
assert "sampler" not in kwargs