Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion optimum/habana/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import os
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, torch_utils, unscale_lora_layers

from optimum.habana.diffusers.utils.torch_utils import gaudi_fourier_filter
from ..utils.torch_utils import gaudi_fourier_filter
from .attention_processor import (
AttentionProcessor,
AttnProcessor2_0,
ScaledDotProductAttention,
)


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -357,3 +363,50 @@ def gaudi_unet_2d_condition_model_forward(
return (sample,)

return UNet2DConditionOutput(sample=sample)


def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
"""
Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
Added env PATCH_SDPA for HPU specific handle to use ScaledDotProductAttention.
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)

def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if os.environ.get("PATCH_SDPA") is not None:
setattr(module, "attention_module", ScaledDotProductAttention())
module.set_processor(processor(module.attention_module))
else:
if isinstance(processor, dict):
attention_processor = processor.pop(f"{name}.processor", None)
if attention_processor is not None:
module.set_processor(attention_processor)
else:
module.set_processor(processor)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)


def set_default_attn_processor_hpu(self):
"""
Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
Disables custom attention processors and sets the default attention implementation from HPU.
"""
processor = AttnProcessor2_0()
set_attn_processor_hpu(self, processor)
2 changes: 1 addition & 1 deletion optimum/habana/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from diffusers.utils.torch_utils import is_compiled_module
from huggingface_hub import create_repo

from optimum.habana.utils import to_device_dtype
from optimum.utils import logging

from ...transformers.gaudi_configuration import GaudiConfig
from ...utils import to_device_dtype


logger = logging.get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import time
from dataclasses import dataclass
from math import ceil
Expand All @@ -30,15 +29,11 @@
from diffusers.utils import BaseOutput, deprecate
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.habana.diffusers.models.attention_processor import (
AttentionProcessor,
AttnProcessor2_0,
ScaledDotProductAttention,
)
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ...models.unet_2d_condition import set_default_attn_processor_hpu
from ..pipeline_utils import GaudiDiffusionPipeline


Expand Down Expand Up @@ -101,59 +96,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
"""
Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
Added env PATCH_SDPA for HPU specific handle to use ScaledDotProductAttention.
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.

If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.

"""

count = len(self.attn_processors.keys())

if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)

def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if os.environ.get("PATCH_SDPA") is not None:
setattr(module, "attention_module", ScaledDotProductAttention())
module.set_processor(processor(module.attention_module))
else:
if isinstance(processor, dict):
attention_processor = processor.pop(f"{name}.processor", None)
if attention_processor is not None:
module.set_processor(attention_processor)
else:
module.set_processor(processor)

for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)


def set_default_attn_processor_hpu(self):
"""
Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
Disables custom attention processors and sets the default attention implementation from HPU.
"""

processor = AttnProcessor2_0()
set_attn_processor_hpu(self, processor)


class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeline):
"""
Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L73
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from dataclasses import dataclass
from math import ceil
Expand All @@ -38,6 +37,7 @@

from ....transformers.gaudi_configuration import GaudiConfig
from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
from ...models import set_default_attn_processor_hpu
from ..pipeline_utils import GaudiDiffusionPipeline
from ..stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

Expand Down Expand Up @@ -138,6 +138,8 @@ def __init__(
force_zeros_for_empty_prompt,
)

self.unet.set_default_attn_processor = set_default_attn_processor_hpu

self.to(self._device)

def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def test_no_throughput_regression_bf16(self):
gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"),
torch_dtype=torch.bfloat16,
)
pipeline.unet.set_default_attn_processor(pipeline.unet)
set_seed(27)
outputs = pipeline(
prompt=prompts,
Expand Down Expand Up @@ -1388,6 +1389,7 @@ def _sdxl_generation(self, scheduler: str, batch_size: int, num_images_per_promp
"stabilityai/stable-diffusion-xl-base-1.0",
**kwargs,
)
pipeline.unet.set_default_attn_processor(pipeline.unet)
num_images_per_prompt = num_images_per_prompt
res = {}
outputs = pipeline(
Expand Down
6 changes: 0 additions & 6 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,12 +1027,6 @@ class MultiCardCausalLanguageModelingAdaloraExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
):
TASK_NAME = "adalora"


class MultiCardCausalLanguageModelingLoRACPExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", deepspeed=True
):
TASK_NAME = "tatsu-lab/alpaca_cp"
DATASET_NAME = "tatsu-lab/alpaca"


Expand Down