Skip to content

Commit

Permalink
fix vae not downloading for pipelines othger than base
Browse files Browse the repository at this point in the history
  • Loading branch information
painebenjamin committed Jan 19, 2024
1 parent 47380df commit 87fb5dc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/python/enfugue/diffusion/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,6 @@ def refiner_vae(
"""
Sets a new refiner vae.
"""
pretrained_path = self.get_vae_path(new_vae)
existing_vae = getattr(self, "_refiner_vae", None)

if (
Expand All @@ -686,16 +685,17 @@ def refiner_vae(
self._refiner_vae = None
self.unload_refiner("VAE resetting to default")
else:
vae_path = self.check_download_model(self.engine_vae_dir, new_vae)
self._refiner_vae_name = new_vae
self._refiner_vae = self.get_vae(pretrained_path)
self._refiner_vae = self.get_vae(vae_path)
if self.refiner_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES:
self.unload_refiner("VAE changing")
elif hasattr(self, "_refiner_pipeline"):
logger.debug(f"Hot-swapping refiner pipeline VAE to {new_vae}")
self._refiner_pipeline.vae = self._refiner_vae # type: ignore[assignment]
if self.refiner_is_sdxl:
self._refiner_pipeline.register_to_config( # type: ignore[attr-defined]
force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"] or (new_vae.endswith("sdxl_vae.safetensors") and "16" not in new_vae)
force_full_precision_vae = "xl" in new_vae and "16" not in new_vae
)

@property
Expand Down Expand Up @@ -724,7 +724,6 @@ def inpainter_vae(
"""
Sets a new inpainter vae.
"""
pretrained_path = self.get_vae_path(new_vae)
existing_vae = getattr(self, "_inpainter_vae", None)

if (
Expand All @@ -737,16 +736,17 @@ def inpainter_vae(
self._inpainter_vae = None
self.unload_inpainter("VAE resetting to default")
else:
vae_path = self.check_download_model(self.engine_vae_dir, new_vae)
self._inpainter_vae_name = new_vae
self._inpainter_vae = self.get_vae(pretrained_path)
self._inpainter_vae = self.get_vae(vae_path)
if self.inpainter_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES:
self.unload_inpainter("VAE changing")
elif hasattr(self, "_inpainter_pipeline"):
logger.debug(f"Hot-swapping inpainter pipeline VAE to {new_vae}")
self._inpainter_pipeline.vae = self._inpainter_vae # type: ignore[assignment]
if self.inpainter_is_sdxl:
self._inpainter_pipeline.register_to_config( # type: ignore[attr-defined]
force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"] or (new_vae.endswith("sdxl_vae.safetensors") and "16" not in new_vae)
force_full_precision_vae = "xl" in new_vae and "16" not in new_vae
)

@property
Expand Down Expand Up @@ -775,7 +775,6 @@ def animator_vae(
"""
Sets a new animator vae.
"""
pretrained_path = self.get_vae_path(new_vae)
existing_vae = getattr(self, "_animator_vae", None)

if (
Expand All @@ -788,16 +787,17 @@ def animator_vae(
self._animator_vae = None
self.unload_animator("VAE resetting to default")
else:
vae_path = self.check_download_model(self.engine_vae_dir, new_vae)
self._animator_vae_name = new_vae
self._animator_vae = self.get_vae(pretrained_path)
self._animator_vae = self.get_vae(vae_path)
if self.animator_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES:
self.unload_animator("VAE changing")
elif hasattr(self, "_animator_pipeline"):
logger.debug(f"Hot-swapping animator pipeline VAE to {new_vae}")
self._animator_pipeline.vae = self._animator_vae # type: ignore [assignment]
if self.animator_is_sdxl:
self._animator_pipeline.register_to_config( # type: ignore[attr-defined]
force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"]
force_full_precision_vae = "xl" in new_vae and "16" not in new_vae
)

@property
Expand Down

0 comments on commit 87fb5dc

Please sign in to comment.