[CHERRY-PICK] MovingMNIST split fix (#7449) (#7451)
Co-authored-by: Shu <[email protected]>
pmeier and Shu-Wan authored Mar 23, 2023
1 parent e872006 commit 2bda93b
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 8 additions & 7 deletions test/test_datasets.py
@@ -1504,29 +1504,30 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
 
     ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
 
+    _NUM_FRAMES = 20
+
     def inject_fake_data(self, tmpdir, config):
         base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
         os.makedirs(base_folder, exist_ok=True)
-        num_samples = 20
+        num_samples = 5
         data = np.concatenate(
             [
                 np.zeros((config["split_ratio"], num_samples, 64, 64)),
-                np.ones((20 - config["split_ratio"], num_samples, 64, 64)),
+                np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)),
             ]
         )
         np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
         return num_samples
 
     @datasets_utils.test_all_configs
     def test_split(self, config):
-        if config["split"] is None:
-            return
-
-        with self.create_dataset(config) as (dataset, info):
+        with self.create_dataset(config) as (dataset, _):
             if config["split"] == "train":
                 assert (dataset.data == 0).all()
-            else:
+            elif config["split"] == "test":
                 assert (dataset.data == 1).all()
+            else:
+                assert dataset.data.size()[1] == self._NUM_FRAMES
 
 
 class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
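For context, a short sketch (not part of the commit) of the fake-data layout the updated test relies on: mnist_test_seq.npy is stored frames-first, so filling the first split_ratio frames with zeros and the rest with ones lets each split be identified by its values, while split=None is now checked by frame count. Variable names below are illustrative.

    import numpy as np

    # Illustrative values; the test grids over split_ratio in (10, 1, 19).
    split_ratio, num_samples, num_frames = 10, 5, 20

    # Same layout as the fake mnist_test_seq.npy: (frame, sequence, H, W),
    # zeros for the would-be "train" frames, ones for the "test" frames.
    data = np.concatenate(
        [
            np.zeros((split_ratio, num_samples, 64, 64)),
            np.ones((num_frames - split_ratio, num_samples, 64, 64)),
        ]
    )
    assert data.shape == (num_frames, num_samples, 64, 64)
    # split="train" -> all zeros, split="test" -> all ones,
    # split=None    -> all num_frames frames are kept (the new else branch).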
2 changes: 1 addition & 1 deletion torchvision/datasets/moving_mnist.py
@@ -58,7 +58,7 @@ def __init__(
         data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
         if self.split == "train":
             data = data[: self.split_ratio]
-        else:
+        elif self.split == "test":
             data = data[self.split_ratio :]
         self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

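With this change, split=None returns the untrimmed 20-frame clips instead of falling through to the test slice. A minimal usage sketch, assuming a torchvision release that ships MovingMNIST and that the dataset (mnist_test_seq.npy, roughly 0.8 GB) can be downloaded to ./data; the path and printed shapes are what we would expect given the storage layout and the transpose in the dataset, not output from the commit itself.

    from torchvision.datasets import MovingMNIST

    root = "./data"  # illustrative path
    full = MovingMNIST(root, split=None, download=True)       # all 20 frames per clip
    train = MovingMNIST(root, split="train", split_ratio=10)  # first 10 frames
    test = MovingMNIST(root, split="test", split_ratio=10)    # remaining 10 frames

    print(full[0].shape)   # expected: torch.Size([20, 1, 64, 64])
    print(train[0].shape)  # expected: torch.Size([10, 1, 64, 64])
    print(test[0].shape)   # expected: torch.Size([10, 1, 64, 64])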
