Skip to content

Commit

Permalink
feat(datasets) Add tests for pacs, cinic10, caltech101, office-home (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Sep 10, 2024
1 parent c5e28d4 commit 75d0243
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 15 deletions.
63 changes: 48 additions & 15 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,30 @@

# Datasets mocked by local recreation: setUp patches the dataset loading for
# these names and returns the result of `_load_mocked_dataset` instead.
mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]

# Datasets mocked by partial download: for these names setUp patches
# `datasets.load_dataset` and returns the DatasetDict built by
# `_load_mocked_dataset_dict_by_partial_download`.
mocked_by_partial_download_datasets = [
"flwrlabs/pacs",
"flwrlabs/cinic10",
"flwrlabs/caltech101",
"flwrlabs/office-home",
"flwrlabs/fed-isic2019",
]

# NOTE(review): presumably datasets with a natural id column that are loaded
# without mocking; the tests consuming this list are not visible here - confirm.
natural_id_datasets = [
"flwrlabs/femnist",
]

# NOTE(review): presumably the mocked counterpart of `natural_id_datasets`;
# verify against the natural-id test class further down the file.
mocked_natural_id_datasets = [
"flwrlabs/ucf101",
"flwrlabs/ambient-acoustic-context",
"LIUM/tedlium",
]


@parameterized_class(
("dataset_name", "test_split", "subset"),
[
# Downloaded
# #Image datasets
# Image
("mnist", "test", ""),
("cifar10", "test", ""),
("fashion_mnist", "test", ""),
Expand All @@ -52,15 +70,22 @@
("scikit-learn/adult-census-income", None, ""),
("jlh/uci-mushrooms", None, ""),
("scikit-learn/iris", None, ""),
# Mocked
# #Image
# Mocked by local recreation
# Image
("cifar100", "test", ""),
# Note: there's also the extra split and full_numbers subset
("svhn", "test", "cropped_digits"),
# Text
("sentiment140", "test", ""), # aka twitter
# Audio
("speech_commands", "test", "v0.01"),
# Mocked by partial download
# Image
("flwrlabs/pacs", None, ""),
("flwrlabs/cinic10", "test", ""),
("flwrlabs/caltech101", None, ""),
("flwrlabs/office-home", None, ""),
("flwrlabs/fed-isic2019", "test", ""),
],
)
class BaseFederatedDatasetsTest(unittest.TestCase):
Expand All @@ -86,10 +111,29 @@ def setUp(self) -> None:
self.mock_load_dataset.return_value = _load_mocked_dataset(
self.dataset_name, [200, 100], ["train", self.test_split], self.subset
)
elif self.dataset_name in mocked_by_partial_download_datasets:
split_names = ["train"]
skip_take_lists = [[(0, 30), (1000, 30), (2000, 40)]]
# If the dataset has split test update the mocking to include it
if self.test_split is not None:
split_names.append(self.test_split)
skip_take_lists.append([(0, 30), (100, 30), (200, 40)])
mock_return_value = _load_mocked_dataset_dict_by_partial_download(
dataset_name=self.dataset_name,
split_names=split_names,
skip_take_lists=skip_take_lists,
subset_name=None if self.subset == "" else self.subset,
)
self.patcher = patch("datasets.load_dataset")
self.mock_load_dataset = self.patcher.start()
self.mock_load_dataset.return_value = mock_return_value

def tearDown(self) -> None:
"""Clean up after the dataset mocking."""
if self.dataset_name in mocked_datasets:
if (
self.dataset_name in mocked_datasets
or self.dataset_name in mocked_by_partial_download_datasets
):
patch.stopall()

@parameterized.expand( # type: ignore
Expand Down Expand Up @@ -403,17 +447,6 @@ def test_mixed_type_partitioners_creates_from_int(self) -> None:
)


natural_id_datasets = [
"flwrlabs/femnist",
]

mocked_natural_id_datasets = [
"flwrlabs/ucf101",
"flwrlabs/ambient-acoustic-context",
"LIUM/tedlium",
]


@parameterized_class(
("dataset_name", "test_split", "subset", "partition_by"),
[
Expand Down
3 changes: 3 additions & 0 deletions datasets/flwr_datasets/mock_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ def _load_mocked_dataset_dict_by_partial_download(
subset_name: Optional[str] = None,
) -> DatasetDict:
"""Like _load_mocked_dataset_by_partial_download but for many splits."""
assert len(split_names) == len(
skip_take_lists
), "The split_names should be the same length as the skip_take_lists."
dataset_dict = {}
for split_name, skip_take_list in zip(split_names, skip_take_lists):
dataset_dict[split_name] = _load_mocked_dataset_by_partial_download(
Expand Down
5 changes: 5 additions & 0 deletions datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
"Mike0307/MNIST-M",
"flwrlabs/usps",
"scikit-learn/iris",
"flwrlabs/pacs",
"flwrlabs/cinic10",
"flwrlabs/caltech101",
"flwrlabs/office-home",
"flwrlabs/fed-isic2019",
]


Expand Down

0 comments on commit 75d0243

Please sign in to comment.