Skip to content

Commit

Permalink
Update arg defaults and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Nov 27, 2024
1 parent e9e76af commit 784935e
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions datasets/flwr_datasets/partitioner/image_semantic_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,34 @@ class ImageSemanticPartitioner(Partitioner):
efficient_net_type: int
The type of pretrained EfficientNet model.
Options: [0, 1, 2, 3, 4, 5, 6, 7], corresponding to EfficientNet B0-B7 models.
Defaults to 7.
batch_size: int
The batch size used by EfficientNet when extracting embeddings.
pca_components: int
Defaults to 32.
pca_components: Optional[int]
The number of PCA components for dimensionality reduction.
Defaults to None.
gmm_max_iter: int
The maximum number of iterations for the GMM algorithm.
Defaults to 100.
gmm_init_params: str
The initialization method for the GMM algorithm.
Options: ["random", "kmeans", "k-means++"]
Defaults to "random".
use_cuda: bool
Whether to use CUDA for computation acceleration.
Defaults to False.
image_column_name: Optional[str]
The name of the image column in the dataset. If not set, the first image column
is used.
Defaults to None.
kl_pairwise_batch_size: int
The batch size for computing pairwise KL-divergence of two label clusters.
Defaults to 32.
Defaults to 10.
shuffle: bool
Whether to randomize the order of samples. Shuffling is applied after the
samples are assigned to partitions.
Defaults to True.
rng_seed: Optional[int]
Seed used for the numpy random number generator,
which is used throughout the process. Defaults to None.
Expand Down Expand Up @@ -133,14 +141,14 @@ def __init__( # pylint: disable=R0913
self,
num_partitions: int,
partition_by: str,
efficient_net_type: int = 3,
efficient_net_type: int = 7,
batch_size: int = 32,
pca_components: int = 256,
pca_components: Optional[int] = None,
gmm_max_iter: int = 100,
gmm_init_params: str = "random",
use_cuda: bool = False,
image_column_name: Optional[str] = None,
kl_pairwise_batch_size: int = 32,
kl_pairwise_batch_size: int = 10,
shuffle: bool = True,
rng_seed: Optional[int] = None,
pca_seed: Optional[int] = None,
Expand Down Expand Up @@ -215,11 +223,11 @@ def _determine_partition_id_to_indices_if_needed(self) -> None:
from sklearn.preprocessing import StandardScaler
from torch.distributions import MultivariateNormal, kl_divergence
from torchvision import models
except ImportError:
except ImportError as err:
raise ImportError(
"ImageSemanticPartitioner requires scikit-learn, torch, "
"torchvision, scipy, and numpy."
) from None
) from err
efficient_nets_dict = [
(models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT),
(models.efficientnet_b1, models.EfficientNet_B1_Weights.DEFAULT),
Expand Down Expand Up @@ -310,7 +318,10 @@ def _subsample(embeddings: NDArrayFloat, num_samples: int) -> NDArrayFloat:
embedding_list
)

if 0 < self._pca_components < embeddings_scaled.shape[1]:
if self._pca_components is None or (
isinstance(self._pca_components, int)
and 0 < self._pca_components < embeddings_scaled.shape[1]
):
pca = PCA(n_components=self._pca_components, random_state=self._pca_seed)
# The subsample size of 100000 follows the official implementation
pca.fit(_subsample(embeddings_scaled, 100000))
Expand Down Expand Up @@ -528,7 +539,7 @@ def _check_variable_validation(self) -> None:
)
if self._gmm_max_iter <= 0:
raise ValueError("The gmm max iter needs to be greater than zero.")
if self._pca_components <= 0:
if self._pca_components is not None and self._pca_components <= 0:
raise ValueError("The pca components needs to be greater than zero.")
if self._rng_seed and not isinstance(self._rng_seed, int):
raise TypeError("The rng seed needs to be an integer.")
Expand Down

0 comments on commit 784935e

Please sign in to comment.