From 55f7ca855124d9700065d426a97bccbbaca3a530 Mon Sep 17 00:00:00 2001
From: bdvllrs <bdvllrs@gmail.com>
Date: Tue, 5 Mar 2024 16:03:35 +0000
Subject: [PATCH] Update __init__ for documentation order

---
 shimmer/__init__.py         | 49 ++++++++++++++++++-------------------
 shimmer/modules/__init__.py | 43 +++++++++++++++++++++-----------
 2 files changed, 53 insertions(+), 39 deletions(-)

diff --git a/shimmer/__init__.py b/shimmer/__init__.py
index 3f3e7bdb..dcd79e93 100644
--- a/shimmer/__init__.py
+++ b/shimmer/__init__.py
@@ -39,6 +39,7 @@
     LatentsDomainGroupsDT,
     LatentsDomainGroupsT,
     LatentsDomainGroupT,
+    ModelModeT,
     RawDomainGroupDT,
     RawDomainGroupsDT,
     RawDomainGroupsT,
@@ -48,44 +49,42 @@
 
 __all__ = [
     "__version__",
-    "DomainModule",
+    "LatentsDomainGroupDT",
+    "LatentsDomainGroupsDT",
+    "LatentsDomainGroupsT",
+    "LatentsDomainGroupT",
+    "RawDomainGroupDT",
+    "RawDomainGroupsDT",
+    "RawDomainGroupsT",
+    "RawDomainGroupT",
+    "ModelModeT",
+    "SchedulerArgs",
+    "GWPredictions",
+    "GlobalWorkspaceBase",
+    "GlobalWorkspace",
+    "VariationalGlobalWorkspace",
+    "pretrained_global_workspace",
     "LossOutput",
-    "GWInterfaceBase",
-    "GWModule",
+    "DomainModule",
     "GWDecoder",
     "GWEncoder",
-    "GWInterface",
-    "GWModuleBase",
-    "VariationalGWEncoder",
-    "VariationalGWInterface",
-    "VariationalGWModule",
     "VariationalGWEncoder",
+    "GWInterfaceBase",
+    "GWInterface",
     "VariationalGWInterface",
+    "GWModuleBase",
+    "GWModule",
     "VariationalGWModule",
-    "ContrastiveLoss",
     "ContrastiveLossType",
     "VarContrastiveLossType",
-    "ContrastiveLossWithUncertainty",
     "contrastive_loss",
+    "ContrastiveLoss",
     "contrastive_loss_with_uncertainty",
+    "ContrastiveLossWithUncertainty",
     "LossCoefs",
     "VariationalLossCoefs",
-    "GWLosses",
     "GWLossesBase",
+    "GWLosses",
     "VariationalGWLosses",
-    "GlobalWorkspace",
-    "GlobalWorkspaceBase",
-    "VariationalGlobalWorkspace",
-    "SchedulerArgs",
-    "GWPredictions",
-    "pretrained_global_workspace",
     "RepeatedDataset",
-    "LatentsDomainGroupDT",
-    "LatentsDomainGroupsDT",
-    "LatentsDomainGroupsT",
-    "LatentsDomainGroupT",
-    "RawDomainGroupDT",
-    "RawDomainGroupsDT",
-    "RawDomainGroupsT",
-    "RawDomainGroupT",
 ]
diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py
index 42be9ef5..b7b34b5c 100644
--- a/shimmer/modules/__init__.py
+++ b/shimmer/modules/__init__.py
@@ -1,3 +1,4 @@
+from shimmer.dataset import RepeatedDataset
 from shimmer.modules.contrastive_loss import (
     ContrastiveLoss,
     ContrastiveLossType,
@@ -33,35 +34,49 @@
     VariationalGWLosses,
     VariationalLossCoefs,
 )
+from shimmer.modules.vae import (
+    VAE,
+    VAEDecoder,
+    VAEEncoder,
+    gaussian_nll,
+    kl_divergence_loss,
+    reparameterize,
+)
 
 __all__ = [
-    "DomainModule",
+    "SchedulerArgs",
+    "GWPredictions",
+    "GlobalWorkspaceBase",
+    "GlobalWorkspace",
+    "VariationalGlobalWorkspace",
+    "pretrained_global_workspace",
     "LossOutput",
-    "GWInterfaceBase",
-    "GWModule",
+    "DomainModule",
     "GWDecoder",
     "GWEncoder",
+    "VariationalGWEncoder",
+    "GWInterfaceBase",
     "GWInterface",
+    "VariationalGWInterface",
     "GWModuleBase",
     "GWModule",
-    "VariationalGWEncoder",
-    "VariationalGWInterface",
     "VariationalGWModule",
-    "ContrastiveLoss",
     "ContrastiveLossType",
     "VarContrastiveLossType",
-    "ContrastiveLossWithUncertainty",
     "contrastive_loss",
+    "ContrastiveLoss",
     "contrastive_loss_with_uncertainty",
+    "ContrastiveLossWithUncertainty",
     "LossCoefs",
     "VariationalLossCoefs",
-    "GWLosses",
     "GWLossesBase",
+    "GWLosses",
     "VariationalGWLosses",
-    "GlobalWorkspace",
-    "GlobalWorkspaceBase",
-    "VariationalGlobalWorkspace",
-    "SchedulerArgs",
-    "GWPredictions",
-    "pretrained_global_workspace",
+    "RepeatedDataset",
+    "reparameterize",
+    "kl_divergence_loss",
+    "gaussian_nll",
+    "VAEEncoder",
+    "VAEDecoder",
+    "VAE",
 ]