
The example code in the Hugging Face documentation has an issue. #10287

Open · Zhiyuan-Fan opened this issue Dec 18, 2024 · 2 comments
Labels: bug (Something isn't working)

@Zhiyuan-Fan
Describe the bug

https://huggingface.co/docs/diffusers/en/api/pipelines/ledits_pp

The examples on that page cannot be run without hitting errors.

For LEditsPPPipelineStableDiffusion:

Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.69it/s]
This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. The scheduler has been changed to DPMSolverMultistepScheduler.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 18.51it/s]
/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
0%| | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/data/zhiyuan/prjs/unrealistic-images/diffusion/examples/leditspp.py", line 22, in
edited_image = pipe(
File "/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py", line 1022, in call
out = self.attention_store.aggregate_attention(
File "/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py", line 128, in aggregate_attention
out = torch.stack([torch.cat(x, dim=0) for x in out])
File "/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py", line 128, in
out = torch.stack([torch.cat(x, dim=0) for x in out])
RuntimeError: torch.cat(): expected a non-empty list of Tensors
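
To illustrate what the error means (a minimal illustration, not part of the documented example): torch.cat raises this whenever it is handed an empty list, so the edit step fails because the attention store collected no attention maps at the expected resolution.

import torch

# An empty list of tensors raises the same RuntimeError as in the traceback
# above, which is why an empty attention store surfaces this way.
torch.cat([], dim=0)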

For LEditsPPPipelineStableDiffusionXL:

Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 4.55it/s]
This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. The scheduler has been changed to DPMSolverMultistepScheduler.
Your input images far exceed the default resolution of the underlying diffusion model. The output images may contain severe artifacts! Consider down-sampling the input using the height and width parameters

File "/home/zhiyuan/anaconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
return F.conv2d(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 136.51 GiB. GPU 0 has a total capacity of 95.10 GiB of which 41.99 GiB is free. Including non-PyTorch memory, this process has 53.10 GiB memory in use. Of the allocated memory 48.46 GiB is allocated by PyTorch, and 4.20 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Reproduction

Script for LEditsPPPipelineStableDiffusion:

import PIL
import requests
import torch
from io import BytesIO

from diffusers import LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
image = load_image(img_url).convert("RGB")

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

edited_image = pipe(
    editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75
).images[0]

Script for LEditsPPPipelineStableDiffusionXL:

import torch
import PIL
import requests
from io import BytesIO

from diffusers import LEditsPPPipelineStableDiffusionXL

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
image = download_image(img_url)

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

edited_image = pipe(
    editing_prompt=["tennis ball", "tomato"],
    reverse_editing_direction=[True, False],
    edit_guidance_scale=[5.0, 10.0],
    edit_threshold=[0.9, 0.85],
).images[0]
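
For the SDXL case, the OOM comes from running inversion on the full-resolution tennis.jpg, which is what the pipeline's down-sampling warning refers to. A possible workaround sketch, not taken from the docs: resize the input with PIL before calling invert. The 1024x1024 target and the use of load_image are assumptions, and the edit step may still hit the attention-store issue discussed in the comments below.

import torch
from PIL import Image
from diffusers import LEditsPPPipelineStableDiffusionXL
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg").convert("RGB")

# Down-sample to SDXL's default 1024x1024 before inversion (aspect ratio
# ignored for brevity); the raw tennis.jpg is far larger, which is what
# triggers the 136 GiB allocation in conv2d.
image = image.resize((1024, 1024), Image.LANCZOS)

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

edited_image = pipe(
    editing_prompt=["tennis ball", "tomato"],
    reverse_editing_direction=[True, False],
    edit_guidance_scale=[5.0, 10.0],
    edit_threshold=[0.9, 0.85],
).images[0]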

Logs

No response

System Info

  • 🤗 Diffusers version: 0.31.0
  • Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.10.16
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.5
  • Transformers version: 4.47.0
  • Accelerate version: 1.2.1
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: 0.0.28.post3
  • Accelerator: NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
    NVIDIA H20, 97871 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @asomoza @sayakpaul @stevhliu

Zhiyuan-Fan added the bug (Something isn't working) label on Dec 18, 2024
@sayakpaul (Member) commented:

cc: @linoytsaban

@hlky (Collaborator) commented Dec 20, 2024

It looks like the store ends up empty due to max_size.

If we remove / 4 from max_size/att_res it seems to work, but I don't know the context of why / 4 was added.

@@ -911,12 +911,12 @@ class LEditsPPPipelineStableDiffusion(
             self.attention_store = LeditsAttentionStore(
                 average=store_averaged_over_steps,
                 batch_size=batch_size,
-                max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0),
+                max_size=latents.shape[-2] * latents.shape[-1],
                 max_resolution=None,
             )
             self.prepare_unet(self.attention_store, PnP=False)
             resolution = latents.shape[-2:]
-            att_res = (int(resolution[0] / 4), int(resolution[1] / 4))
+            att_res = (int(resolution[0]), int(resolution[1]))
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
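
With that change applied, the example from the report seems to run without the error: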
import torch
from diffusers import LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
image = load_image(img_url).convert("RGB")

_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

edited_image = pipe(
    editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75
).images[0]

[Output image: cherry blossom]
