diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index 6dda26f796..94e9962e1e 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -381,7 +381,6 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -392,13 +391,10 @@ def save_lora_weights( unet_lora_layers = to_device_dtype(unet_lora_layers, target_device=torch.device("cpu")) if text_encoder_lora_layers: text_encoder_lora_layers = to_device_dtype(text_encoder_lora_layers, target_device=torch.device("cpu")) - if text_encoder_2_lora_layers: - text_encoder_2_lora_layers = to_device_dtype(text_encoder_2_lora_layers, target_device=torch.device("cpu")) return super().save_lora_weights( save_directory, unet_lora_layers, text_encoder_lora_layers, - text_encoder_2_lora_layers, is_main_process, weight_name, save_function, diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 6a1b74d129..7ef010c210 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import time from dataclasses import dataclass from math import ceil @@ -33,6 +34,7 @@ CLIPVisionModelWithProjection, ) +from optimum.habana.utils import to_device_dtype from optimum.utils import logging from ....transformers.gaudi_configuration import GaudiConfig @@ -142,6 +144,36 @@ def __init__( self.to(self._device) + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + # Move the state dict from HPU to CPU before saving + if unet_lora_layers: + unet_lora_layers = to_device_dtype(unet_lora_layers, target_device=torch.device("cpu")) + if text_encoder_lora_layers: + text_encoder_lora_layers = to_device_dtype(text_encoder_lora_layers, target_device=torch.device("cpu")) + if text_encoder_2_lora_layers: + text_encoder_2_lora_layers = to_device_dtype(text_encoder_2_lora_layers, target_device=torch.device("cpu")) + return StableDiffusionXLPipeline.save_lora_weights( + save_directory, + unet_lora_layers, + text_encoder_lora_layers, + text_encoder_2_lora_layers, + is_main_process, + weight_name, + save_function, + safe_serialization, + ) + def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != num_images: