diff --git a/simple_shapes_dataset/logging.py b/simple_shapes_dataset/logging.py index f74a6e2..fa287a7 100644 --- a/simple_shapes_dataset/logging.py +++ b/simple_shapes_dataset/logging.py @@ -13,11 +13,7 @@ from matplotlib import gridspec from matplotlib.figure import Figure from PIL import Image -from shimmer.modules.global_workspace import ( - GlobalWorkspace, - GlobalWorkspaceBase, - GlobalWorkspaceWithUncertainty, -) +from shimmer.modules.global_workspace import GlobalWorkspaceBase from torchvision.utils import make_grid from simple_shapes_dataset import LOGGER @@ -430,7 +426,7 @@ def on_train_epoch_end( if self.mode != "train": return - if not isinstance(pl_module, (GlobalWorkspace, GlobalWorkspaceWithUncertainty)): + if not isinstance(pl_module, GlobalWorkspaceBase): return if ( @@ -449,7 +445,7 @@ def on_validation_epoch_end( if self.mode != "val": return - if not isinstance(pl_module, (GlobalWorkspace, GlobalWorkspaceWithUncertainty)): + if not isinstance(pl_module, GlobalWorkspaceBase): return if ( @@ -468,7 +464,7 @@ def on_test_epoch_end( if self.mode != "test": return - if not isinstance(pl_module, (GlobalWorkspace, GlobalWorkspaceWithUncertainty)): + if not isinstance(pl_module, GlobalWorkspaceBase): return return self.on_callback(trainer.current_epoch, trainer.loggers, pl_module) @@ -481,7 +477,7 @@ def on_fit_end( if self.mode == "test": return - if not isinstance(pl_module, (GlobalWorkspace, GlobalWorkspaceWithUncertainty)): + if not isinstance(pl_module, GlobalWorkspaceBase): return return self.on_callback(trainer.current_epoch, trainer.loggers, pl_module)