From 55f7ca855124d9700065d426a97bccbbaca3a530 Mon Sep 17 00:00:00 2001 From: bdvllrs <bdvllrs@gmail.com> Date: Tue, 5 Mar 2024 16:03:35 +0000 Subject: [PATCH] Update __init__ for documentation order --- shimmer/__init__.py | 49 ++++++++++++++++++------------------- shimmer/modules/__init__.py | 43 +++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 3f3e7bdb..dcd79e93 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -39,6 +39,7 @@ LatentsDomainGroupsDT, LatentsDomainGroupsT, LatentsDomainGroupT, + ModelModeT, RawDomainGroupDT, RawDomainGroupsDT, RawDomainGroupsT, @@ -48,44 +49,42 @@ __all__ = [ "__version__", - "DomainModule", + "LatentsDomainGroupDT", + "LatentsDomainGroupsDT", + "LatentsDomainGroupsT", + "LatentsDomainGroupT", + "RawDomainGroupDT", + "RawDomainGroupsDT", + "RawDomainGroupsT", + "RawDomainGroupT", + "ModelModeT", + "SchedulerArgs", + "GWPredictions", + "GlobalWorkspaceBase", + "GlobalWorkspace", + "VariationalGlobalWorkspace", + "pretrained_global_workspace", "LossOutput", - "GWInterfaceBase", - "GWModule", + "DomainModule", "GWDecoder", "GWEncoder", - "GWInterface", - "GWModuleBase", - "VariationalGWEncoder", - "VariationalGWInterface", - "VariationalGWModule", "VariationalGWEncoder", + "GWInterfaceBase", + "GWInterface", "VariationalGWInterface", + "GWModuleBase", + "GWModule", "VariationalGWModule", - "ContrastiveLoss", "ContrastiveLossType", "VarContrastiveLossType", - "ContrastiveLossWithUncertainty", "contrastive_loss", + "ContrastiveLoss", "contrastive_loss_with_uncertainty", + "ContrastiveLossWithUncertainty", "LossCoefs", "VariationalLossCoefs", - "GWLosses", "GWLossesBase", + "GWLosses", "VariationalGWLosses", - "GlobalWorkspace", - "GlobalWorkspaceBase", - "VariationalGlobalWorkspace", - "SchedulerArgs", - "GWPredictions", - "pretrained_global_workspace", "RepeatedDataset", - "LatentsDomainGroupDT", - "LatentsDomainGroupsDT", - "LatentsDomainGroupsT", - "LatentsDomainGroupT", - "RawDomainGroupDT", - "RawDomainGroupsDT", - "RawDomainGroupsT", - "RawDomainGroupT", ] diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 42be9ef5..b7b34b5c 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -1,3 +1,4 @@ +from shimmer.dataset import RepeatedDataset from shimmer.modules.contrastive_loss import ( ContrastiveLoss, ContrastiveLossType, @@ -33,35 +34,49 @@ VariationalGWLosses, VariationalLossCoefs, ) +from shimmer.modules.vae import ( + VAE, + VAEDecoder, + VAEEncoder, + gaussian_nll, + kl_divergence_loss, + reparameterize, +) __all__ = [ - "DomainModule", + "SchedulerArgs", + "GWPredictions", + "GlobalWorkspaceBase", + "GlobalWorkspace", + "VariationalGlobalWorkspace", + "pretrained_global_workspace", "LossOutput", - "GWInterfaceBase", - "GWModule", + "DomainModule", "GWDecoder", "GWEncoder", + "VariationalGWEncoder", + "GWInterfaceBase", "GWInterface", + "VariationalGWInterface", "GWModuleBase", "GWModule", - "VariationalGWEncoder", - "VariationalGWInterface", "VariationalGWModule", - "ContrastiveLoss", "ContrastiveLossType", "VarContrastiveLossType", - "ContrastiveLossWithUncertainty", "contrastive_loss", + "ContrastiveLoss", "contrastive_loss_with_uncertainty", + "ContrastiveLossWithUncertainty", "LossCoefs", "VariationalLossCoefs", - "GWLosses", "GWLossesBase", + "GWLosses", "VariationalGWLosses", - "GlobalWorkspace", - "GlobalWorkspaceBase", - "VariationalGlobalWorkspace", - "SchedulerArgs", - "GWPredictions", - "pretrained_global_workspace", + "RepeatedDataset", + "reparameterize", + "kl_divergence_loss", + "gaussian_nll", + "VAEEncoder", + "VAEDecoder", + "VAE", ]